@@ -19,6 +19,32 @@ import (
19
19
"google.golang.org/grpc"
20
20
)
21
21
22
+ const (
23
+ DecodedQueryField string = "decodedQuery"
24
+ DetectorField string = "detector"
25
+ ScoreField string = "score"
26
+ QueryField string = "query"
27
+ ErrorField string = "error"
28
+ IsInjectionField string = "is_injection"
29
+ ResponseField string = "response"
30
+ OutputsField string = "outputs"
31
+ TokensField string = "tokens"
32
+ RequestField string = "request"
33
+ StringField string = "String"
34
+
35
+ DeepLearningModel string = "deep_learning_model"
36
+ Libinjection string = "libinjection"
37
+
38
+ ErrorLevel string = "error"
39
+ ExceptionLevel string = "EXCEPTION"
40
+ ErrorNumber string = "42000"
41
+ DetectionMessage string = "SQL injection detected"
42
+ ErrorResponseMessage string = "Back off, you're not welcome here."
43
+
44
+ TokenizeAndSequencePath string = "/tokenize_and_sequence"
45
+ PredictPath string = "/v1/models/%s/versions/%s:predict"
46
+ )
47
+
22
48
type Plugin struct {
23
49
goplugin.GRPCPlugin
24
50
v1.GatewayDPluginServiceServer
@@ -44,7 +70,9 @@ func (p *InjectionDetectionPlugin) GRPCServer(b *goplugin.GRPCBroker, s *grpc.Se
44
70
}
45
71
46
72
// GRPCClient returns the plugin client.
47
- func (p * InjectionDetectionPlugin ) GRPCClient (ctx context.Context , b * goplugin.GRPCBroker , c * grpc.ClientConn ) (interface {}, error ) {
73
+ func (p * InjectionDetectionPlugin ) GRPCClient (
74
+ ctx context.Context , b * goplugin.GRPCBroker , c * grpc.ClientConn ,
75
+ ) (any , error ) {
48
76
return v1 .NewGatewayDPluginServiceClient (c ), nil
49
77
}
50
78
@@ -73,92 +101,119 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
73
101
// Handle the client message.
74
102
req , err := postgres .HandleClientMessage (req , p .Logger )
75
103
if err != nil {
76
- p .Logger .Debug ("Failed to handle client message" , "error" , err )
104
+ p .Logger .Debug ("Failed to handle client message" , ErrorField , err )
77
105
}
78
106
79
107
// Get the client request from the GatewayD request.
80
- request := cast .ToString (sdkPlugin .GetAttr (req , "request" , "" ))
108
+ request := cast .ToString (sdkPlugin .GetAttr (req , RequestField , "" ))
81
109
if request == "" {
82
110
return req , nil
83
111
}
84
112
85
113
// Get the query from the request.
86
- query := cast .ToString (sdkPlugin .GetAttr (req , "query" , "" ))
114
+ query := cast .ToString (sdkPlugin .GetAttr (req , QueryField , "" ))
87
115
if query == "" {
88
116
p .Logger .Debug ("Failed to get query from request, possibly not a SQL query request" )
89
117
return req , nil
90
118
}
91
- p .Logger .Trace ("Query" , "query" , query )
119
+ p .Logger .Trace ("Query" , QueryField , query )
92
120
93
121
// Decode the query.
94
122
decodedQuery , err := base64 .StdEncoding .DecodeString (query )
95
123
if err != nil {
96
124
return req , err
97
125
}
98
- p .Logger .Trace ("Decoded Query" , "decodedQuery" , decodedQuery )
126
+ p .Logger .Trace ("Decoded Query" , DecodedQueryField , decodedQuery )
99
127
100
128
// Unmarshal query into a map.
101
- var queryMap map [string ]interface {}
129
+ var queryMap map [string ]any
102
130
if err := json .Unmarshal (decodedQuery , & queryMap ); err != nil {
103
- p .Logger .Error ("Failed to unmarshal query" , "error" , err )
131
+ p .Logger .Error ("Failed to unmarshal query" , ErrorField , err )
104
132
return req , nil
105
133
}
106
- queryString := cast .ToString (queryMap ["String" ])
134
+ queryString := cast .ToString (queryMap [StringField ])
107
135
108
- var tokens map [string ]interface {}
136
+ var tokens map [string ]any
109
137
err = requests .
110
138
URL (p .TokenizerAPIAddress ).
111
- Path ("/tokenize_and_sequence" ).
112
- BodyJSON (map [string ]interface {} {
113
- "query" : queryString ,
139
+ Path (TokenizeAndSequencePath ).
140
+ BodyJSON (map [string ]any {
141
+ QueryField : queryString ,
114
142
}).
115
143
ToJSON (& tokens ).
116
144
Fetch (context .Background ())
117
145
if err != nil {
118
- p .Logger .Error ("Failed to make POST request" , "error" , err )
146
+ p .Logger .Error ("Failed to make POST request" , ErrorField , err )
119
147
if p .isSQLi (queryString ) && ! p .LibinjectionPermissiveMode {
120
- return p .errorResponse (req , queryString ), nil
148
+ return p .errorResponse (
149
+ req ,
150
+ map [string ]any {
151
+ QueryField : queryString ,
152
+ DetectorField : Libinjection ,
153
+ ErrorField : "Failed to make POST request to tokenizer API" ,
154
+ },
155
+ ), nil
121
156
}
122
157
return req , nil
123
158
}
124
159
125
- var output map [string ]interface {}
160
+ var output map [string ]any
126
161
err = requests .
127
162
URL (p .ServingAPIAddress ).
128
- Path (fmt .Sprintf ("/v1/models/%s/versions/%s:predict" , p .ModelName , p .ModelVersion )).
129
- BodyJSON (map [string ]interface {} {
130
- "inputs" : []interface {}{ cast .ToSlice (tokens ["tokens" ])},
163
+ Path (fmt .Sprintf (PredictPath , p .ModelName , p .ModelVersion )).
164
+ BodyJSON (map [string ]any {
165
+ "inputs" : []any { cast .ToSlice (tokens [TokensField ])},
131
166
}).
132
167
ToJSON (& output ).
133
168
Fetch (context .Background ())
134
169
if err != nil {
135
- p .Logger .Error ("Failed to make POST request" , "error" , err )
170
+ p .Logger .Error ("Failed to make POST request" , ErrorField , err )
136
171
if p .isSQLi (queryString ) && ! p .LibinjectionPermissiveMode {
137
- return p .errorResponse (req , queryString ), nil
172
+ return p .errorResponse (
173
+ req ,
174
+ map [string ]any {
175
+ QueryField : queryString ,
176
+ DetectorField : Libinjection ,
177
+ ErrorField : "Failed to make POST request to serving API" ,
178
+ },
179
+ ), nil
138
180
}
139
181
return req , nil
140
182
}
141
183
142
- predictions := cast .ToSlice (output ["outputs" ])
184
+ predictions := cast .ToSlice (output [OutputsField ])
143
185
scores := cast .ToSlice (predictions [0 ])
144
186
score := cast .ToFloat32 (scores [0 ])
145
- p .Logger .Trace ("Deep learning model prediction" , "score" , score )
187
+ p .Logger .Trace ("Deep learning model prediction" , ScoreField , score )
146
188
147
189
// Check the prediction against the threshold,
148
190
// otherwise check if the query is an SQL injection using libinjection.
149
191
injection := p .isSQLi (queryString )
150
192
if score >= p .Threshold {
151
193
if p .EnableLibinjection && ! injection {
152
- p .Logger .Debug ("False positive detected by libinjection" )
194
+ p .Logger .Debug ("False positive detected" , DetectorField , Libinjection )
153
195
}
154
196
155
- Detections .Inc ()
156
- p .Logger .Warn ("SQL injection detected by deep learning model" , "score" , score )
157
- return p .errorResponse (req , queryString ), nil
197
+ Detections .With (map [string ]string {DetectorField : DeepLearningModel }).Inc ()
198
+ p .Logger .Warn (DetectionMessage , ScoreField , score , DetectorField , DeepLearningModel )
199
+ return p .errorResponse (
200
+ req ,
201
+ map [string ]any {
202
+ QueryField : queryString ,
203
+ ScoreField : score ,
204
+ DetectorField : DeepLearningModel ,
205
+ },
206
+ ), nil
158
207
} else if p .EnableLibinjection && injection && ! p .LibinjectionPermissiveMode {
159
- Detections .Inc ()
160
- p .Logger .Warn ("SQL injection detected by libinjection" )
161
- return p .errorResponse (req , queryString ), nil
208
+ Detections .With (map [string ]string {DetectorField : Libinjection }).Inc ()
209
+ p .Logger .Warn (DetectionMessage , DetectorField , Libinjection )
210
+ return p .errorResponse (
211
+ req ,
212
+ map [string ]any {
213
+ QueryField : queryString ,
214
+ DetectorField : Libinjection ,
215
+ },
216
+ ), nil
162
217
} else {
163
218
p .Logger .Trace ("No SQL injection detected" )
164
219
}
@@ -175,45 +230,43 @@ func (p *Plugin) isSQLi(query string) bool {
175
230
// Check if the query is an SQL injection using libinjection.
176
231
injection , _ := libinjection .IsSQLi (query )
177
232
if injection {
178
- p .Logger .Warn ("SQL injection detected by libinjection" )
233
+ p .Logger .Warn (DetectionMessage , DetectorField , Libinjection )
179
234
}
180
- p .Logger .Trace ("SQLInjection" , "is_injection" , cast .ToString (injection ))
235
+ p .Logger .Trace ("SQLInjection" , IsInjectionField , cast .ToString (injection ))
181
236
return injection
182
237
}
183
238
184
- func (p * Plugin ) errorResponse (req * v1.Struct , queryString string ) * v1.Struct {
239
+ func (p * Plugin ) errorResponse (req * v1.Struct , fields map [ string ] any ) * v1.Struct {
185
240
Preventions .Inc ()
186
241
187
242
// Create a PostgreSQL error response.
188
243
errResp := postgres .ErrorResponse (
189
- "SQL injection detected" ,
190
- "EXCEPTION" ,
191
- "42000" ,
192
- "Back off, you're not welcome here." ,
244
+ DetectionMessage ,
245
+ ExceptionLevel ,
246
+ ErrorNumber ,
247
+ ErrorResponseMessage ,
193
248
)
194
249
195
250
// Create a ready for query response.
196
251
readyForQuery := & pgproto3.ReadyForQuery {TxStatus : 'I' }
197
252
// TODO: Decide whether to terminate the connection.
198
253
response , err := readyForQuery .Encode (errResp )
199
254
if err != nil {
200
- p .Logger .Error ("Failed to encode ready for query response" , "error" , err )
255
+ p .Logger .Error ("Failed to encode ready for query response" , ErrorField , err )
201
256
return req
202
257
}
203
258
204
259
signals , err := v1 .NewList ([]any {
205
260
sdkAct .Terminate ().ToMap (),
206
- sdkAct .Log ("error" , "SQL injection detected" , map [string ]any {
207
- "query" : queryString ,
208
- }).ToMap (),
261
+ sdkAct .Log (ErrorLevel , DetectionMessage , fields ).ToMap (),
209
262
})
210
263
if err != nil {
211
- p .Logger .Error ("Failed to create signals" , "error" , err )
264
+ p .Logger .Error ("Failed to create signals" , ErrorField , err )
212
265
return req
213
266
}
214
267
215
268
// Create a response to send back to the client.
216
269
req .Fields [sdkAct .Signals ] = v1 .NewListValue (signals )
217
- req .Fields ["response" ] = v1 .NewBytesValue (response )
270
+ req .Fields [ResponseField ] = v1 .NewBytesValue (response )
218
271
return req
219
272
}
0 commit comments