Skip to content

Commit a938d66

Browse files
committed
fix(#516): add origin check to websockets
1 parent 6a0b786 commit a938d66

File tree

2 files changed

+44
-8
lines changed

2 files changed

+44
-8
lines changed

graphql-java-servlet/src/main/java/graphql/kickstart/servlet/GraphQLConfiguration.java

+14-2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ public class GraphQLConfiguration {
3636
private final ContextSetting contextSetting;
3737
private final GraphQLResponseCacheManager responseCacheManager;
3838
@Getter private final Executor asyncExecutor;
39+
@Getter private final List<String> allowedOrigins;
3940
private HttpRequestHandler requestHandler;
4041

4142
private GraphQLConfiguration(
@@ -49,9 +50,11 @@ private GraphQLConfiguration(
4950
ContextSetting contextSetting,
5051
Supplier<BatchInputPreProcessor> batchInputPreProcessor,
5152
GraphQLResponseCacheManager responseCacheManager,
52-
Executor asyncExecutor) {
53+
Executor asyncExecutor,
54+
List<String> allowedOrigins) {
5355
this.invocationInputFactory = invocationInputFactory;
5456
this.asyncExecutor = asyncExecutor;
57+
this.allowedOrigins = allowedOrigins;
5558
this.graphQLInvoker = graphQLInvoker != null ? graphQLInvoker : queryInvoker.toGraphQLInvoker();
5659
this.objectMapper = objectMapper;
5760
this.listeners = listeners;
@@ -148,6 +151,7 @@ public static class Builder {
148151
private int asyncMaxPoolSize = 200;
149152
private Executor asyncExecutor;
150153
private AsyncTaskDecorator asyncTaskDecorator;
154+
private List<String> allowedOrigins = new ArrayList<>();
151155

152156
private Builder(GraphQLInvocationInputFactory.Builder invocationInputFactoryBuilder) {
153157
this.invocationInputFactoryBuilder = invocationInputFactoryBuilder;
@@ -249,6 +253,13 @@ public Builder with(AsyncTaskDecorator asyncTaskDecorator) {
249253
return this;
250254
}
251255

256+
public Builder allowedOrigins(List<String> allowedOrigins) {
257+
if (allowedOrigins != null) {
258+
this.allowedOrigins.addAll(allowedOrigins);
259+
}
260+
return this;
261+
}
262+
252263
private Executor getAsyncExecutor() {
253264
if (asyncExecutor != null) {
254265
return asyncExecutor;
@@ -279,7 +290,8 @@ public GraphQLConfiguration build() {
279290
contextSetting,
280291
batchInputPreProcessorSupplier,
281292
responseCacheManager,
282-
getAsyncTaskExecutor());
293+
getAsyncTaskExecutor(),
294+
allowedOrigins);
283295
}
284296
}
285297
}

graphql-java-servlet/src/main/java/graphql/kickstart/servlet/GraphQLWebsocketServlet.java

+30-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package graphql.kickstart.servlet;
22

33
import static java.util.Arrays.asList;
4+
import static java.util.Collections.emptyList;
45
import static java.util.Collections.singletonList;
56
import static java.util.stream.Collectors.toList;
67

@@ -65,6 +66,7 @@ public class GraphQLWebsocketServlet extends Endpoint {
6566
private final AtomicBoolean isShuttingDown = new AtomicBoolean(false);
6667
private final AtomicBoolean isShutDown = new AtomicBoolean(false);
6768
private final Object cacheLock = new Object();
69+
private final List<String> allowedOrigins;
6870

6971
public GraphQLWebsocketServlet(GraphQLConfiguration configuration) {
7072
this(configuration, null);
@@ -77,21 +79,23 @@ public GraphQLWebsocketServlet(
7779
configuration.getGraphQLInvoker(),
7880
configuration.getInvocationInputFactory(),
7981
configuration.getObjectMapper(),
80-
connectionListeners);
82+
connectionListeners,
83+
configuration.getAllowedOrigins());
8184
}
8285

8386
public GraphQLWebsocketServlet(
8487
GraphQLInvoker graphQLInvoker,
8588
GraphQLSubscriptionInvocationInputFactory invocationInputFactory,
8689
GraphQLObjectMapper graphQLObjectMapper) {
87-
this(graphQLInvoker, invocationInputFactory, graphQLObjectMapper, null);
90+
this(graphQLInvoker, invocationInputFactory, graphQLObjectMapper, null, emptyList());
8891
}
8992

9093
public GraphQLWebsocketServlet(
9194
GraphQLInvoker graphQLInvoker,
9295
GraphQLSubscriptionInvocationInputFactory invocationInputFactory,
9396
GraphQLObjectMapper graphQLObjectMapper,
94-
Collection<SubscriptionConnectionListener> connectionListeners) {
97+
Collection<SubscriptionConnectionListener> connectionListeners,
98+
List<String> allowedOrigins) {
9599
List<ApolloSubscriptionConnectionListener> listeners = new ArrayList<>();
96100
if (connectionListeners != null) {
97101
connectionListeners.stream()
@@ -114,12 +118,10 @@ public GraphQLWebsocketServlet(
114118
Stream.of(fallbackSubscriptionProtocolFactory))
115119
.map(SubscriptionProtocolFactory::getProtocol)
116120
.collect(toList());
121+
this.allowedOrigins = allowedOrigins;
117122
}
118123

119124
public GraphQLWebsocketServlet(
120-
GraphQLInvoker graphQLInvoker,
121-
GraphQLSubscriptionInvocationInputFactory invocationInputFactory,
122-
GraphQLObjectMapper graphQLObjectMapper,
123125
List<SubscriptionProtocolFactory> subscriptionProtocolFactory,
124126
SubscriptionProtocolFactory fallbackSubscriptionProtocolFactory) {
125127

@@ -132,6 +134,8 @@ public GraphQLWebsocketServlet(
132134
Stream.of(fallbackSubscriptionProtocolFactory))
133135
.map(SubscriptionProtocolFactory::getProtocol)
134136
.collect(toList());
137+
138+
this.allowedOrigins = emptyList();
135139
}
136140

137141
@Override
@@ -202,6 +206,26 @@ private void closeUnexpectedly(Session session, Throwable t) {
202206
}
203207
}
204208

209+
public boolean checkOrigin(String originHeaderValue) {
210+
if (originHeaderValue == null || originHeaderValue.isBlank()) {
211+
return allowedOrigins.isEmpty();
212+
}
213+
String originToCheck = trimTrailingSlash(originHeaderValue);
214+
if (!allowedOrigins.isEmpty()) {
215+
if (allowedOrigins.contains("*")) {
216+
return true;
217+
}
218+
return allowedOrigins.stream()
219+
.map(this::trimTrailingSlash)
220+
.anyMatch(originToCheck::equalsIgnoreCase);
221+
}
222+
return true;
223+
}
224+
225+
private String trimTrailingSlash(String origin) {
226+
return (origin.endsWith("/") ? origin.substring(0, origin.length() - 1) : origin);
227+
}
228+
205229
public void modifyHandshake(
206230
ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
207231
sec.getUserProperties().put(HANDSHAKE_REQUEST_KEY, request);

0 commit comments

Comments
 (0)