1
1
package graphql .kickstart .servlet ;
2
2
3
3
import static java .util .Arrays .asList ;
4
+ import static java .util .Collections .emptyList ;
4
5
import static java .util .Collections .singletonList ;
5
6
import static java .util .stream .Collectors .toList ;
6
7
@@ -65,6 +66,7 @@ public class GraphQLWebsocketServlet extends Endpoint {
65
66
private final AtomicBoolean isShuttingDown = new AtomicBoolean (false );
66
67
private final AtomicBoolean isShutDown = new AtomicBoolean (false );
67
68
private final Object cacheLock = new Object ();
69
+ private final List <String > allowedOrigins ;
68
70
69
71
public GraphQLWebsocketServlet (GraphQLConfiguration configuration ) {
70
72
this (configuration , null );
@@ -77,21 +79,23 @@ public GraphQLWebsocketServlet(
77
79
configuration .getGraphQLInvoker (),
78
80
configuration .getInvocationInputFactory (),
79
81
configuration .getObjectMapper (),
80
- connectionListeners );
82
+ connectionListeners ,
83
+ configuration .getAllowedOrigins ());
81
84
}
82
85
83
86
public GraphQLWebsocketServlet (
84
87
GraphQLInvoker graphQLInvoker ,
85
88
GraphQLSubscriptionInvocationInputFactory invocationInputFactory ,
86
89
GraphQLObjectMapper graphQLObjectMapper ) {
87
- this (graphQLInvoker , invocationInputFactory , graphQLObjectMapper , null );
90
+ this (graphQLInvoker , invocationInputFactory , graphQLObjectMapper , null , emptyList () );
88
91
}
89
92
90
93
public GraphQLWebsocketServlet (
91
94
GraphQLInvoker graphQLInvoker ,
92
95
GraphQLSubscriptionInvocationInputFactory invocationInputFactory ,
93
96
GraphQLObjectMapper graphQLObjectMapper ,
94
- Collection <SubscriptionConnectionListener > connectionListeners ) {
97
+ Collection <SubscriptionConnectionListener > connectionListeners ,
98
+ List <String > allowedOrigins ) {
95
99
List <ApolloSubscriptionConnectionListener > listeners = new ArrayList <>();
96
100
if (connectionListeners != null ) {
97
101
connectionListeners .stream ()
@@ -114,12 +118,10 @@ public GraphQLWebsocketServlet(
114
118
Stream .of (fallbackSubscriptionProtocolFactory ))
115
119
.map (SubscriptionProtocolFactory ::getProtocol )
116
120
.collect (toList ());
121
+ this .allowedOrigins = allowedOrigins ;
117
122
}
118
123
119
124
public GraphQLWebsocketServlet (
120
- GraphQLInvoker graphQLInvoker ,
121
- GraphQLSubscriptionInvocationInputFactory invocationInputFactory ,
122
- GraphQLObjectMapper graphQLObjectMapper ,
123
125
List <SubscriptionProtocolFactory > subscriptionProtocolFactory ,
124
126
SubscriptionProtocolFactory fallbackSubscriptionProtocolFactory ) {
125
127
@@ -132,6 +134,8 @@ public GraphQLWebsocketServlet(
132
134
Stream .of (fallbackSubscriptionProtocolFactory ))
133
135
.map (SubscriptionProtocolFactory ::getProtocol )
134
136
.collect (toList ());
137
+
138
+ this .allowedOrigins = emptyList ();
135
139
}
136
140
137
141
@ Override
@@ -202,6 +206,26 @@ private void closeUnexpectedly(Session session, Throwable t) {
202
206
}
203
207
}
204
208
209
+ public boolean checkOrigin (String originHeaderValue ) {
210
+ if (originHeaderValue == null || originHeaderValue .isBlank ()) {
211
+ return allowedOrigins .isEmpty ();
212
+ }
213
+ String originToCheck = trimTrailingSlash (originHeaderValue );
214
+ if (!allowedOrigins .isEmpty ()) {
215
+ if (allowedOrigins .contains ("*" )) {
216
+ return true ;
217
+ }
218
+ return allowedOrigins .stream ()
219
+ .map (this ::trimTrailingSlash )
220
+ .anyMatch (originToCheck ::equalsIgnoreCase );
221
+ }
222
+ return true ;
223
+ }
224
+
225
+ private String trimTrailingSlash (String origin ) {
226
+ return (origin .endsWith ("/" ) ? origin .substring (0 , origin .length () - 1 ) : origin );
227
+ }
228
+
205
229
public void modifyHandshake (
206
230
ServerEndpointConfig sec , HandshakeRequest request , HandshakeResponse response ) {
207
231
sec .getUserProperties ().put (HANDSHAKE_REQUEST_KEY , request );
0 commit comments