Skip to content

Commit bb2b4ee

Browse files
committed
Include more fields in the Act log
Refactor constants
1 parent bd86ad3 commit bb2b4ee

File tree

2 files changed

+98
-45
lines changed

2 files changed

+98
-45
lines changed

plugin/metrics.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ var (
2020
Name: "on_traffic_from_client_total",
2121
Help: "The total number of calls to the onTrafficFromClient method",
2222
})
23-
Detections = promauto.NewCounter(prometheus.CounterOpts{
23+
Detections = promauto.NewCounterVec(prometheus.CounterOpts{
2424
Namespace: metrics.Namespace,
2525
Name: "detections_total",
2626
Help: "The total number of malicious requests detected",
27-
})
27+
}, []string{"detector"})
2828
Preventions = promauto.NewCounter(prometheus.CounterOpts{
2929
Namespace: metrics.Namespace,
3030
Name: "preventions_total",

plugin/plugin.go

+96-43
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,32 @@ import (
1919
"google.golang.org/grpc"
2020
)
2121

22+
const (
23+
DecodedQueryField string = "decodedQuery"
24+
DetectorField string = "detector"
25+
ScoreField string = "score"
26+
QueryField string = "query"
27+
ErrorField string = "error"
28+
IsInjectionField string = "is_injection"
29+
ResponseField string = "response"
30+
OutputsField string = "outputs"
31+
TokensField string = "tokens"
32+
RequestField string = "request"
33+
StringField string = "String"
34+
35+
DeepLearningModel string = "deep_learning_model"
36+
Libinjection string = "libinjection"
37+
38+
ErrorLevel string = "error"
39+
ExceptionLevel string = "EXCEPTION"
40+
ErrorNumber string = "42000"
41+
DetectionMessage string = "SQL injection detected"
42+
ErrorResponseMessage string = "Back off, you're not welcome here."
43+
44+
TokenizeAndSequencePath string = "/tokenize_and_sequence"
45+
PredictPath string = "/v1/models/%s/versions/%s:predict"
46+
)
47+
2248
type Plugin struct {
2349
goplugin.GRPCPlugin
2450
v1.GatewayDPluginServiceServer
@@ -44,7 +70,9 @@ func (p *InjectionDetectionPlugin) GRPCServer(b *goplugin.GRPCBroker, s *grpc.Se
4470
}
4571

4672
// GRPCClient returns the plugin client.
47-
func (p *InjectionDetectionPlugin) GRPCClient(ctx context.Context, b *goplugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
73+
func (p *InjectionDetectionPlugin) GRPCClient(
74+
ctx context.Context, b *goplugin.GRPCBroker, c *grpc.ClientConn,
75+
) (any, error) {
4876
return v1.NewGatewayDPluginServiceClient(c), nil
4977
}
5078

@@ -73,92 +101,119 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
73101
// Handle the client message.
74102
req, err := postgres.HandleClientMessage(req, p.Logger)
75103
if err != nil {
76-
p.Logger.Debug("Failed to handle client message", "error", err)
104+
p.Logger.Debug("Failed to handle client message", ErrorField, err)
77105
}
78106

79107
// Get the client request from the GatewayD request.
80-
request := cast.ToString(sdkPlugin.GetAttr(req, "request", ""))
108+
request := cast.ToString(sdkPlugin.GetAttr(req, RequestField, ""))
81109
if request == "" {
82110
return req, nil
83111
}
84112

85113
// Get the query from the request.
86-
query := cast.ToString(sdkPlugin.GetAttr(req, "query", ""))
114+
query := cast.ToString(sdkPlugin.GetAttr(req, QueryField, ""))
87115
if query == "" {
88116
p.Logger.Debug("Failed to get query from request, possibly not a SQL query request")
89117
return req, nil
90118
}
91-
p.Logger.Trace("Query", "query", query)
119+
p.Logger.Trace("Query", QueryField, query)
92120

93121
// Decode the query.
94122
decodedQuery, err := base64.StdEncoding.DecodeString(query)
95123
if err != nil {
96124
return req, err
97125
}
98-
p.Logger.Trace("Decoded Query", "decodedQuery", decodedQuery)
126+
p.Logger.Trace("Decoded Query", DecodedQueryField, decodedQuery)
99127

100128
// Unmarshal query into a map.
101-
var queryMap map[string]interface{}
129+
var queryMap map[string]any
102130
if err := json.Unmarshal(decodedQuery, &queryMap); err != nil {
103-
p.Logger.Error("Failed to unmarshal query", "error", err)
131+
p.Logger.Error("Failed to unmarshal query", ErrorField, err)
104132
return req, nil
105133
}
106-
queryString := cast.ToString(queryMap["String"])
134+
queryString := cast.ToString(queryMap[StringField])
107135

108-
var tokens map[string]interface{}
136+
var tokens map[string]any
109137
err = requests.
110138
URL(p.TokenizerAPIAddress).
111-
Path("/tokenize_and_sequence").
112-
BodyJSON(map[string]interface{}{
113-
"query": queryString,
139+
Path(TokenizeAndSequencePath).
140+
BodyJSON(map[string]any{
141+
QueryField: queryString,
114142
}).
115143
ToJSON(&tokens).
116144
Fetch(context.Background())
117145
if err != nil {
118-
p.Logger.Error("Failed to make POST request", "error", err)
146+
p.Logger.Error("Failed to make POST request", ErrorField, err)
119147
if p.isSQLi(queryString) && !p.LibinjectionPermissiveMode {
120-
return p.errorResponse(req, queryString), nil
148+
return p.errorResponse(
149+
req,
150+
map[string]any{
151+
QueryField: queryString,
152+
DetectorField: Libinjection,
153+
ErrorField: "Failed to make POST request to tokenizer API",
154+
},
155+
), nil
121156
}
122157
return req, nil
123158
}
124159

125-
var output map[string]interface{}
160+
var output map[string]any
126161
err = requests.
127162
URL(p.ServingAPIAddress).
128-
Path(fmt.Sprintf("/v1/models/%s/versions/%s:predict", p.ModelName, p.ModelVersion)).
129-
BodyJSON(map[string]interface{}{
130-
"inputs": []interface{}{cast.ToSlice(tokens["tokens"])},
163+
Path(fmt.Sprintf(PredictPath, p.ModelName, p.ModelVersion)).
164+
BodyJSON(map[string]any{
165+
"inputs": []any{cast.ToSlice(tokens[TokensField])},
131166
}).
132167
ToJSON(&output).
133168
Fetch(context.Background())
134169
if err != nil {
135-
p.Logger.Error("Failed to make POST request", "error", err)
170+
p.Logger.Error("Failed to make POST request", ErrorField, err)
136171
if p.isSQLi(queryString) && !p.LibinjectionPermissiveMode {
137-
return p.errorResponse(req, queryString), nil
172+
return p.errorResponse(
173+
req,
174+
map[string]any{
175+
QueryField: queryString,
176+
DetectorField: Libinjection,
177+
ErrorField: "Failed to make POST request to serving API",
178+
},
179+
), nil
138180
}
139181
return req, nil
140182
}
141183

142-
predictions := cast.ToSlice(output["outputs"])
184+
predictions := cast.ToSlice(output[OutputsField])
143185
scores := cast.ToSlice(predictions[0])
144186
score := cast.ToFloat32(scores[0])
145-
p.Logger.Trace("Deep learning model prediction", "score", score)
187+
p.Logger.Trace("Deep learning model prediction", ScoreField, score)
146188

147189
// Check the prediction against the threshold,
148190
// otherwise check if the query is an SQL injection using libinjection.
149191
injection := p.isSQLi(queryString)
150192
if score >= p.Threshold {
151193
if p.EnableLibinjection && !injection {
152-
p.Logger.Debug("False positive detected by libinjection")
194+
p.Logger.Debug("False positive detected", DetectorField, Libinjection)
153195
}
154196

155-
Detections.Inc()
156-
p.Logger.Warn("SQL injection detected by deep learning model", "score", score)
157-
return p.errorResponse(req, queryString), nil
197+
Detections.With(map[string]string{DetectorField: DeepLearningModel}).Inc()
198+
p.Logger.Warn(DetectionMessage, ScoreField, score, DetectorField, DeepLearningModel)
199+
return p.errorResponse(
200+
req,
201+
map[string]any{
202+
QueryField: queryString,
203+
ScoreField: score,
204+
DetectorField: DeepLearningModel,
205+
},
206+
), nil
158207
} else if p.EnableLibinjection && injection && !p.LibinjectionPermissiveMode {
159-
Detections.Inc()
160-
p.Logger.Warn("SQL injection detected by libinjection")
161-
return p.errorResponse(req, queryString), nil
208+
Detections.With(map[string]string{DetectorField: Libinjection}).Inc()
209+
p.Logger.Warn(DetectionMessage, DetectorField, Libinjection)
210+
return p.errorResponse(
211+
req,
212+
map[string]any{
213+
QueryField: queryString,
214+
DetectorField: Libinjection,
215+
},
216+
), nil
162217
} else {
163218
p.Logger.Trace("No SQL injection detected")
164219
}
@@ -175,45 +230,43 @@ func (p *Plugin) isSQLi(query string) bool {
175230
// Check if the query is an SQL injection using libinjection.
176231
injection, _ := libinjection.IsSQLi(query)
177232
if injection {
178-
p.Logger.Warn("SQL injection detected by libinjection")
233+
p.Logger.Warn(DetectionMessage, DetectorField, Libinjection)
179234
}
180-
p.Logger.Trace("SQLInjection", "is_injection", cast.ToString(injection))
235+
p.Logger.Trace("SQLInjection", IsInjectionField, cast.ToString(injection))
181236
return injection
182237
}
183238

184-
func (p *Plugin) errorResponse(req *v1.Struct, queryString string) *v1.Struct {
239+
func (p *Plugin) errorResponse(req *v1.Struct, fields map[string]any) *v1.Struct {
185240
Preventions.Inc()
186241

187242
// Create a PostgreSQL error response.
188243
errResp := postgres.ErrorResponse(
189-
"SQL injection detected",
190-
"EXCEPTION",
191-
"42000",
192-
"Back off, you're not welcome here.",
244+
DetectionMessage,
245+
ExceptionLevel,
246+
ErrorNumber,
247+
ErrorResponseMessage,
193248
)
194249

195250
// Create a ready for query response.
196251
readyForQuery := &pgproto3.ReadyForQuery{TxStatus: 'I'}
197252
// TODO: Decide whether to terminate the connection.
198253
response, err := readyForQuery.Encode(errResp)
199254
if err != nil {
200-
p.Logger.Error("Failed to encode ready for query response", "error", err)
255+
p.Logger.Error("Failed to encode ready for query response", ErrorField, err)
201256
return req
202257
}
203258

204259
signals, err := v1.NewList([]any{
205260
sdkAct.Terminate().ToMap(),
206-
sdkAct.Log("error", "SQL injection detected", map[string]any{
207-
"query": queryString,
208-
}).ToMap(),
261+
sdkAct.Log(ErrorLevel, DetectionMessage, fields).ToMap(),
209262
})
210263
if err != nil {
211-
p.Logger.Error("Failed to create signals", "error", err)
264+
p.Logger.Error("Failed to create signals", ErrorField, err)
212265
return req
213266
}
214267

215268
// Create a response to send back to the client.
216269
req.Fields[sdkAct.Signals] = v1.NewListValue(signals)
217-
req.Fields["response"] = v1.NewBytesValue(response)
270+
req.Fields[ResponseField] = v1.NewBytesValue(response)
218271
return req
219272
}

0 commit comments

Comments
 (0)