Skip to content

Commit ce1d6ab

Browse files
committed
Refactor plugin to use Tokenizer and Serving APIs instead of loading the model directly
1 parent c68b279 commit ce1d6ab

File tree

6 files changed

+38
-94
lines changed

6 files changed

+38
-94
lines changed

gatewayd_plugin.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,14 @@ plugins:
2828
- METRICS_ENABLED=True
2929
- METRICS_UNIX_DOMAIN_SOCKET=/tmp/gatewayd-plugin-sql-ids-ips.sock
3030
- METRICS_PATH=/metrics
31-
- API_ADDRESS=http://localhost:5000
31+
- TOKENIZER_API_ADDRESS=http://localhost:8000
32+
- SERVING_API_ADDRESS=http://localhost:8501
3233
# Threshold determines the minimum prediction confidence
3334
# required to detect an SQL injection attack. Any value
3435
# between 0 and 1 (inclusive) is valid.
3536
# Anything below 0.8 is not recommended,
3637
# but it is dependent on the application and testing.
3738
- THRESHOLD=0.8
38-
- MODEL_PATH=sqli_model
3939
# The following env-vars disable the verbose logging of Tensorflow.
4040
- KMP_AFFINITY=noverbose
4141
- TF_CPP_MIN_LOG_LEVEL=3

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ module github.com/gatewayd-io/gatewayd-plugin-sql-ids-ips
33
go 1.22
44

55
require (
6+
github.com/carlmjohnson/requests v0.23.5
67
github.com/corazawaf/libinjection-go v0.1.3
7-
github.com/galeone/tensorflow/tensorflow/go v0.0.0-20240119075110-6ad3cf65adfe
88
github.com/gatewayd-io/gatewayd-plugin-sdk v0.2.5
99
github.com/hashicorp/go-hclog v1.6.2
1010
github.com/hashicorp/go-plugin v1.6.0

go.sum

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
22
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
33
github.com/bufbuild/protocompile v0.4.0 h1:LbFKd2XowZvQ/kajzguUp2DC9UEIQhIq77fZZlaQsNA=
44
github.com/bufbuild/protocompile v0.4.0/go.mod h1:3v93+mbWn/v3xzN+31nwkJfrEpAUwp+BagBSZWx+TP8=
5+
github.com/carlmjohnson/requests v0.23.5 h1:NPANcAofwwSuC6SIMwlgmHry2V3pLrSqRiSBKYbNHHA=
6+
github.com/carlmjohnson/requests v0.23.5/go.mod h1:zG9P28thdRnN61aD7iECFhH5iGGKX2jIjKQD9kqYH+o=
57
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
68
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
79
github.com/corazawaf/libinjection-go v0.1.3 h1:PUplAYho1BBl0tIVbhDsNRuVGIeUYSiCEc9oQpb2rJU=
@@ -17,8 +19,6 @@ github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM=
1719
github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE=
1820
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
1921
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
20-
github.com/galeone/tensorflow/tensorflow/go v0.0.0-20240119075110-6ad3cf65adfe h1:7yELf1NFEwECpXMGowkoftcInMlVtLTCdwWLmxKgzNM=
21-
github.com/galeone/tensorflow/tensorflow/go v0.0.0-20240119075110-6ad3cf65adfe/go.mod h1:TelZuq26kz2jysARBwOrTv16629hyUsHmIoj54QqyFo=
2222
github.com/gatewayd-io/gatewayd-plugin-sdk v0.2.5 h1:H1S4CKS4IfezxlvgBLtSJ/3s85wznxgxJEnwLys+kIM=
2323
github.com/gatewayd-io/gatewayd-plugin-sdk v0.2.5/go.mod h1:1XS2ufw+8VRTHAbDf18Y7rSPlOczeQ/baUWPqJrDkeE=
2424
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=

main.go

+2-13
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"flag"
55
"os"
66

7-
tf "github.com/galeone/tensorflow/tensorflow/go"
87
sdkConfig "github.com/gatewayd-io/gatewayd-plugin-sdk/config"
98
"github.com/gatewayd-io/gatewayd-plugin-sdk/logging"
109
"github.com/gatewayd-io/gatewayd-plugin-sdk/metrics"
@@ -40,21 +39,11 @@ func main() {
4039
}
4140

4241
pluginInstance.Impl.Threshold = cast.ToFloat32(cfg["threshold"])
43-
44-
modelPath := cast.ToString(cfg["modelPath"])
45-
// Load the model from the file system
46-
model, err := tf.LoadSavedModel(modelPath, []string{"serve"}, nil)
47-
if err != nil {
48-
logger.Error("Failed to load model", "error", err)
49-
panic(err)
50-
}
51-
defer model.Session.Close()
52-
53-
pluginInstance.Impl.Model = model
5442
pluginInstance.Impl.EnableLibinjection = cast.ToBool(cfg["enableLibinjection"])
5543
pluginInstance.Impl.LibinjectionPermissiveMode = cast.ToBool(
5644
cfg["libinjectionPermissiveMode"])
57-
pluginInstance.Impl.APIAddress = cast.ToString(cfg["apiAddress"])
45+
pluginInstance.Impl.TokenizerAPIAddress = cast.ToString(cfg["tokenizerAPIAddress"])
46+
pluginInstance.Impl.ServingAPIAddress = cast.ToString(cfg["servingAPIAddress"])
5847
}
5948

6049
goplugin.Serve(&goplugin.ServeConfig{

plugin/module.go

+5-3
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ var (
3434
"metricsEnabled": sdkConfig.GetEnv("METRICS_ENABLED", "true"),
3535
"metricsUnixDomainSocket": sdkConfig.GetEnv(
3636
"METRICS_UNIX_DOMAIN_SOCKET", "/tmp/gatewayd-plugin-sql-ids-ips.sock"),
37-
"metricsEndpoint": sdkConfig.GetEnv("METRICS_ENDPOINT", "/metrics"),
38-
"apiAddress": sdkConfig.GetEnv("API_ADDRESS", "http://localhost:5000"),
37+
"metricsEndpoint": sdkConfig.GetEnv("METRICS_ENDPOINT", "/metrics"),
38+
"tokenizerAPIAddress": sdkConfig.GetEnv(
39+
"TOKENIZER_API_ADDRESS", "http://localhost:8000"),
40+
"servingAPIAddress": sdkConfig.GetEnv(
41+
"SERVING_API_ADDRESS", "http://localhost:8501"),
3942
"threshold": sdkConfig.GetEnv("THRESHOLD", "0.8"),
40-
"modelPath": sdkConfig.GetEnv("MODEL_PATH", "sqli_model"),
4143
"enableLibinjection": sdkConfig.GetEnv("ENABLE_LIBINJECTION", "true"),
4244
"libinjectionPermissiveMode": sdkConfig.GetEnv("LIBINJECTION_MODE", "true"),
4345
},

plugin/plugin.go

+26-73
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
package plugin
22

33
import (
4-
"bytes"
54
"context"
65
"encoding/base64"
76
"encoding/json"
8-
"net/http"
9-
"net/url"
107

8+
"github.com/carlmjohnson/requests"
119
"github.com/corazawaf/libinjection-go"
12-
tf "github.com/galeone/tensorflow/tensorflow/go"
1310
sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act"
1411
"github.com/gatewayd-io/gatewayd-plugin-sdk/databases/postgres"
1512
sdkPlugin "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin"
@@ -25,11 +22,11 @@ type Plugin struct {
2522
goplugin.GRPCPlugin
2623
v1.GatewayDPluginServiceServer
2724
Logger hclog.Logger
28-
Model *tf.SavedModel
2925
Threshold float32
3026
EnableLibinjection bool
3127
LibinjectionPermissiveMode bool
32-
APIAddress string
28+
TokenizerAPIAddress string
29+
ServingAPIAddress string
3330
}
3431

3532
type InjectionDetectionPlugin struct {
@@ -105,87 +102,43 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
105102
}
106103
queryString := cast.ToString(queryMap["String"])
107104

108-
// Create a JSON body for the request.
109-
body, err := json.Marshal(map[string]interface{}{
110-
"query": queryString,
111-
})
112-
if err != nil {
113-
p.Logger.Error("Failed to marshal body", "error", err)
114-
if p.isSQLi(queryString) && !p.LibinjectionPermissiveMode {
115-
return p.errorResponse(req, queryString), nil
116-
}
117-
return req, nil
118-
}
119-
// Make an HTTP POST request to the tokenize service.
120-
tokenizeEndpoint, err := url.JoinPath(p.APIAddress, "/tokenize_and_sequence")
121-
if err != nil {
122-
p.Logger.Error("Failed to join API address and path", "error", err)
123-
if p.isSQLi(queryString) && !p.LibinjectionPermissiveMode {
124-
return p.errorResponse(req, queryString), nil
125-
}
126-
return req, nil
127-
}
128-
resp, err := http.Post(tokenizeEndpoint, "application/json", bytes.NewBuffer(body))
105+
var tokens map[string]interface{}
106+
err = requests.
107+
URL(p.TokenizerAPIAddress).
108+
Path("/tokenize_and_sequence").
109+
BodyJSON(map[string]interface{}{
110+
"query": queryString,
111+
}).
112+
ToJSON(&tokens).
113+
Fetch(context.Background())
129114
if err != nil {
130-
p.Logger.Error("Failed to make GET request", "error", err)
115+
p.Logger.Error("Failed to make POST request", "error", err)
131116
if p.isSQLi(queryString) && !p.LibinjectionPermissiveMode {
132117
return p.errorResponse(req, queryString), nil
133118
}
134119
return req, nil
135120
}
136121

137-
// Read the response body.
138-
defer resp.Body.Close()
139-
var data map[string]interface{}
140-
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
141-
p.Logger.Error("Failed to decode response body", "error", err)
142-
if p.isSQLi(queryString) && !p.LibinjectionPermissiveMode {
143-
return p.errorResponse(req, queryString), nil
144-
}
145-
return req, nil
146-
}
147-
148-
// Get the tokens from the response.
149-
var tokens []float32
150-
for _, v := range data["tokens"].([]interface{}) {
151-
tokens = append(tokens, cast.ToFloat32(v))
152-
}
153-
154-
// Convert []float32 to a [][]float32.
155-
allTokens := make([][]float32, 1)
156-
allTokens[0] = tokens
157-
158-
p.Logger.Trace("Tokens", "tokens", allTokens)
159-
160-
// Create a tensor from the tokens.
161-
inputTensor, err := tf.NewTensor(allTokens)
122+
var output map[string]interface{}
123+
err = requests.
124+
URL(p.ServingAPIAddress).
125+
Path("/v1/models/sqli_model:predict").
126+
BodyJSON(map[string]interface{}{
127+
"inputs": []interface{}{cast.ToSlice(tokens["tokens"])},
128+
}).
129+
ToJSON(&output).
130+
Fetch(context.Background())
162131
if err != nil {
163-
p.Logger.Error("Failed to create input tensor", "error", err)
132+
p.Logger.Error("Failed to make POST request", "error", err)
164133
if p.isSQLi(queryString) && !p.LibinjectionPermissiveMode {
165134
return p.errorResponse(req, queryString), nil
166135
}
167136
return req, nil
168137
}
169138

170-
// Run the model to predict if the query is malicious or not.
171-
output, err := p.Model.Session.Run(
172-
map[tf.Output]*tf.Tensor{
173-
p.Model.Graph.Operation("serving_default_embedding_input").Output(0): inputTensor,
174-
},
175-
[]tf.Output{
176-
p.Model.Graph.Operation("StatefulPartitionedCall").Output(0),
177-
},
178-
nil,
179-
)
180-
if err != nil {
181-
p.Logger.Error("Failed to run model", "error", err)
182-
if p.isSQLi(queryString) && !p.LibinjectionPermissiveMode {
183-
return p.errorResponse(req, queryString), nil
184-
}
185-
return req, nil
186-
}
187-
predictions := output[0].Value().([][]float32)
188-
score := predictions[0][0]
139+
predictions := cast.ToSlice(output["outputs"])
140+
scores := cast.ToSlice(predictions[0])
141+
score := cast.ToFloat32(scores[0])
189142
p.Logger.Trace("Deep learning model prediction", "score", score)
190143

191144
// Check the prediction against the threshold,

0 commit comments

Comments
 (0)