1
1
package plugin
2
2
3
3
import (
4
- "bytes"
5
4
"context"
6
5
"encoding/base64"
7
6
"encoding/json"
8
- "net/http"
9
- "net/url"
10
7
8
+ "github.com/carlmjohnson/requests"
11
9
"github.com/corazawaf/libinjection-go"
12
- tf "github.com/galeone/tensorflow/tensorflow/go"
13
10
sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act"
14
11
"github.com/gatewayd-io/gatewayd-plugin-sdk/databases/postgres"
15
12
sdkPlugin "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin"
@@ -25,11 +22,11 @@ type Plugin struct {
25
22
goplugin.GRPCPlugin
26
23
v1.GatewayDPluginServiceServer
27
24
Logger hclog.Logger
28
- Model * tf.SavedModel
29
25
Threshold float32
30
26
EnableLibinjection bool
31
27
LibinjectionPermissiveMode bool
32
- APIAddress string
28
+ TokenizerAPIAddress string
29
+ ServingAPIAddress string
33
30
}
34
31
35
32
type InjectionDetectionPlugin struct {
@@ -105,87 +102,43 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
105
102
}
106
103
queryString := cast .ToString (queryMap ["String" ])
107
104
108
- // Create a JSON body for the request.
109
- body , err := json .Marshal (map [string ]interface {}{
110
- "query" : queryString ,
111
- })
112
- if err != nil {
113
- p .Logger .Error ("Failed to marshal body" , "error" , err )
114
- if p .isSQLi (queryString ) && ! p .LibinjectionPermissiveMode {
115
- return p .errorResponse (req , queryString ), nil
116
- }
117
- return req , nil
118
- }
119
- // Make an HTTP POST request to the tokenize service.
120
- tokenizeEndpoint , err := url .JoinPath (p .APIAddress , "/tokenize_and_sequence" )
121
- if err != nil {
122
- p .Logger .Error ("Failed to join API address and path" , "error" , err )
123
- if p .isSQLi (queryString ) && ! p .LibinjectionPermissiveMode {
124
- return p .errorResponse (req , queryString ), nil
125
- }
126
- return req , nil
127
- }
128
- resp , err := http .Post (tokenizeEndpoint , "application/json" , bytes .NewBuffer (body ))
105
+ var tokens map [string ]interface {}
106
+ err = requests .
107
+ URL (p .TokenizerAPIAddress ).
108
+ Path ("/tokenize_and_sequence" ).
109
+ BodyJSON (map [string ]interface {}{
110
+ "query" : queryString ,
111
+ }).
112
+ ToJSON (& tokens ).
113
+ Fetch (context .Background ())
129
114
if err != nil {
130
- p .Logger .Error ("Failed to make GET request" , "error" , err )
115
+ p .Logger .Error ("Failed to make POST request" , "error" , err )
131
116
if p .isSQLi (queryString ) && ! p .LibinjectionPermissiveMode {
132
117
return p .errorResponse (req , queryString ), nil
133
118
}
134
119
return req , nil
135
120
}
136
121
137
- // Read the response body.
138
- defer resp .Body .Close ()
139
- var data map [string ]interface {}
140
- if err := json .NewDecoder (resp .Body ).Decode (& data ); err != nil {
141
- p .Logger .Error ("Failed to decode response body" , "error" , err )
142
- if p .isSQLi (queryString ) && ! p .LibinjectionPermissiveMode {
143
- return p .errorResponse (req , queryString ), nil
144
- }
145
- return req , nil
146
- }
147
-
148
- // Get the tokens from the response.
149
- var tokens []float32
150
- for _ , v := range data ["tokens" ].([]interface {}) {
151
- tokens = append (tokens , cast .ToFloat32 (v ))
152
- }
153
-
154
- // Convert []float32 to a [][]float32.
155
- allTokens := make ([][]float32 , 1 )
156
- allTokens [0 ] = tokens
157
-
158
- p .Logger .Trace ("Tokens" , "tokens" , allTokens )
159
-
160
- // Create a tensor from the tokens.
161
- inputTensor , err := tf .NewTensor (allTokens )
122
+ var output map [string ]interface {}
123
+ err = requests .
124
+ URL (p .ServingAPIAddress ).
125
+ Path ("/v1/models/sqli_model:predict" ).
126
+ BodyJSON (map [string ]interface {}{
127
+ "inputs" : []interface {}{cast .ToSlice (tokens ["tokens" ])},
128
+ }).
129
+ ToJSON (& output ).
130
+ Fetch (context .Background ())
162
131
if err != nil {
163
- p .Logger .Error ("Failed to create input tensor " , "error" , err )
132
+ p .Logger .Error ("Failed to make POST request " , "error" , err )
164
133
if p .isSQLi (queryString ) && ! p .LibinjectionPermissiveMode {
165
134
return p .errorResponse (req , queryString ), nil
166
135
}
167
136
return req , nil
168
137
}
169
138
170
- // Run the model to predict if the query is malicious or not.
171
- output , err := p .Model .Session .Run (
172
- map [tf.Output ]* tf.Tensor {
173
- p .Model .Graph .Operation ("serving_default_embedding_input" ).Output (0 ): inputTensor ,
174
- },
175
- []tf.Output {
176
- p .Model .Graph .Operation ("StatefulPartitionedCall" ).Output (0 ),
177
- },
178
- nil ,
179
- )
180
- if err != nil {
181
- p .Logger .Error ("Failed to run model" , "error" , err )
182
- if p .isSQLi (queryString ) && ! p .LibinjectionPermissiveMode {
183
- return p .errorResponse (req , queryString ), nil
184
- }
185
- return req , nil
186
- }
187
- predictions := output [0 ].Value ().([][]float32 )
188
- score := predictions [0 ][0 ]
139
+ predictions := cast .ToSlice (output ["outputs" ])
140
+ scores := cast .ToSlice (predictions [0 ])
141
+ score := cast .ToFloat32 (scores [0 ])
189
142
p .Logger .Trace ("Deep learning model prediction" , "score" , score )
190
143
191
144
// Check the prediction against the threshold,
0 commit comments