generated from gatewayd-io/plugin-template-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplugin.go
180 lines (152 loc) · 5.21 KB
/
plugin.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
package plugin
import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"time"

	tf "github.com/galeone/tensorflow/tensorflow/go"
	"github.com/gatewayd-io/gatewayd-plugin-sdk/databases/postgres"
	sdkPlugin "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin"
	v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1"
	"github.com/hashicorp/go-hclog"
	goplugin "github.com/hashicorp/go-plugin"
	"github.com/jackc/pgx/pgproto3"
	"github.com/spf13/cast"
	"google.golang.org/grpc"
	"google.golang.org/protobuf/types/known/structpb"
)
// Plugin is the GatewayD plugin implementation. Its OnTrafficFromClient hook
// scores client SQL queries with a TensorFlow model and rejects queries whose
// score reaches Threshold.
type Plugin struct {
	goplugin.GRPCPlugin
	// Embedded so that *Plugin satisfies v1.GatewayDPluginServiceServer; hooks
	// not defined on Plugin are promoted from this field (presumably left nil
	// or set to an unimplemented server elsewhere — TODO confirm).
	v1.GatewayDPluginServiceServer
	// Logger receives the plugin's log output.
	Logger hclog.Logger
	// Model is the loaded TensorFlow SavedModel used to score tokenized queries.
	Model *tf.SavedModel
	// Threshold is the minimum model score (compared with >=) at which a query
	// is treated as SQL injection and answered with an error response.
	Threshold float32
}
// TemplatePlugin is the hashicorp/go-plugin wrapper around Plugin. It serves
// Impl over gRPC; the embedded NetRPCUnsupportedPlugin marks net/rpc as
// unsupported.
type TemplatePlugin struct {
	goplugin.NetRPCUnsupportedPlugin
	// Impl is the implementation registered with the gRPC server in GRPCServer.
	Impl Plugin
}
// GRPCServer exposes the wrapped Plugin implementation as the GatewayD plugin
// service on s. The broker is not used, and registration never fails.
func (p *TemplatePlugin) GRPCServer(b *goplugin.GRPCBroker, s *grpc.Server) error {
	impl := &p.Impl
	v1.RegisterGatewayDPluginServiceServer(s, impl)
	return nil
}
// GRPCClient builds the host-side view of the plugin: a GatewayD plugin
// service client bound to the connection c. The context and broker are
// unused, and the returned error is always nil.
func (p *TemplatePlugin) GRPCClient(ctx context.Context, b *goplugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
	client := v1.NewGatewayDPluginServiceClient(c)
	return client, nil
}
// NewTemplatePlugin wraps impl in a TemplatePlugin so that it can be served
// through hashicorp/go-plugin over gRPC. The net/rpc transport is left
// unsupported.
func NewTemplatePlugin(impl Plugin) *TemplatePlugin {
	wrapper := new(TemplatePlugin)
	wrapper.Impl = impl
	return wrapper
}
// GetPluginConfig is called by GatewayD when the plugin is loaded. It returns
// the package-level PluginConfig converted to a protobuf Struct, or the
// conversion error. Each call first invokes GetPluginConfig.Inc() (presumably
// a metrics counter declared elsewhere in the package).
func (p *Plugin) GetPluginConfig(
	ctx context.Context, _ *structpb.Struct) (*structpb.Struct, error) {
	GetPluginConfig.Inc()
	config, err := structpb.NewStruct(PluginConfig)
	return config, err
}
// OnTrafficFromClient is called when a request is received by GatewayD from the client.
// This can be used to modify the request or terminate the connection by returning an error
// or a response.
//
// The SQL text of the query (the "String" field of the decoded query JSON) is
// tokenized by the tokenizer service on localhost:5000 and scored by p.Model.
// If the score is at least p.Threshold, the request's "response" field is set
// to a PostgreSQL ErrorResponse followed by ReadyForQuery, and its "terminate"
// field is set to true.
//
// Detection is best-effort. After the query has been base64-decoded, any
// failure (malformed query JSON, missing SQL text, tokenizer errors, unexpected
// model output) is logged and the request is returned unchanged with a nil
// error. A query that is not valid base64 is returned together with the
// decode error.
func (p *Plugin) OnTrafficFromClient(
	ctx context.Context, req *structpb.Struct) (*structpb.Struct, error) {
	OnTrafficFromClient.Inc()
	req, err := postgres.HandleClientMessage(req, p.Logger)
	if err != nil {
		p.Logger.Debug("Failed to handle client message", "error", err)
	}

	// Get the client request from the GatewayD request.
	request := cast.ToString(sdkPlugin.GetAttr(req, "request", ""))
	if request == "" {
		return req, nil
	}

	// Get the query from the request.
	query := cast.ToString(sdkPlugin.GetAttr(req, "query", ""))
	if query == "" {
		p.Logger.Debug("Failed to get query from request, possibly not a SQL query request")
		return req, nil
	}
	p.Logger.Trace("Query", "query", query)

	decodedQuery, err := base64.StdEncoding.DecodeString(query)
	if err != nil {
		return req, err
	}
	// Log as text rather than as a byte slice so the trace is readable.
	p.Logger.Trace("Decoded Query", "decodedQuery", string(decodedQuery))

	var queryMap map[string]interface{}
	if err := json.Unmarshal(decodedQuery, &queryMap); err != nil {
		p.Logger.Error("Failed to unmarshal query", "error", err)
		return req, nil
	}

	// Only the "String" field is scored. If it is missing, not a string, or
	// empty, there is nothing to check. Previously a placeholder such as
	// "%!s(<nil>)" was sent to the tokenizer in that case.
	sqlQuery, ok := queryMap["String"].(string)
	if !ok || sqlQuery == "" {
		p.Logger.Debug("Query has no SQL text to check")
		return req, nil
	}

	tokens, err := p.tokenizeQuery(ctx, sqlQuery)
	if err != nil {
		p.Logger.Error("Failed to tokenize query", "error", err)
		return req, nil
	}
	p.Logger.Trace("Tokens", "tokens", tokens)

	prediction, err := p.predictInjection(tokens)
	if err != nil {
		p.Logger.Error("Failed to run model", "error", err)
		return req, nil
	}

	p.Logger.Debug("Prediction", "prediction", prediction)
	if prediction >= p.Threshold {
		p.Logger.Warn("SQL Injection Detected", "prediction", prediction)
		// Reply with an ErrorResponse followed by ReadyForQuery with
		// transaction status 'I' (idle).
		errResp := &pgproto3.ErrorResponse{
			Severity: "ERROR",
			Message:  "SQL Injection Detected",
			Detail:   "Back off, you're not welcome here.",
		}
		readyForQuery := &pgproto3.ReadyForQuery{TxStatus: 'I'}
		response := errResp.Encode(nil)
		// TODO: Decide whether to terminate the connection.
		response = readyForQuery.Encode(response)
		req.Fields["response"] = structpb.NewStringValue(
			base64.StdEncoding.EncodeToString(response))
		req.Fields["terminate"] = structpb.NewBoolValue(true)
	}
	return req, nil
}

// tokenizeQuery sends sqlQuery to the local tokenizer service and returns the
// numbers from the "tokens" array of its JSON response as float32s.
func (p *Plugin) tokenizeQuery(ctx context.Context, sqlQuery string) ([]float32, error) {
	const (
		tokenizerURL = "http://localhost:5000/tokenize_and_sequence/"
		// Bound the call so that an unresponsive tokenizer cannot stall client
		// traffic indefinitely. The value is a judgment call — tune as needed.
		tokenizerTimeout = 5 * time.Second
	)

	ctx, cancel := context.WithTimeout(ctx, tokenizerTimeout)
	defer cancel()

	// PathEscape stops characters such as '/', '?', '#' and '%' (common in
	// injection payloads) from being parsed as URL syntax, which would
	// truncate or alter the query the tokenizer receives.
	httpReq, err := http.NewRequestWithContext(
		ctx, http.MethodGet, tokenizerURL+url.PathEscape(sqlQuery), nil)
	if err != nil {
		return nil, fmt.Errorf("building tokenizer request: %w", err)
	}
	resp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("calling tokenizer: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("tokenizer returned status %d", resp.StatusCode)
	}

	var data map[string]interface{}
	if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
		return nil, fmt.Errorf("decoding tokenizer response: %w", err)
	}
	rawTokens, ok := data["tokens"].([]interface{})
	if !ok {
		return nil, fmt.Errorf("tokenizer response has no %q array", "tokens")
	}
	tokens := make([]float32, 0, len(rawTokens))
	for _, v := range rawTokens {
		tokens = append(tokens, cast.ToFloat32(v))
	}
	return tokens, nil
}

// predictInjection feeds one token sequence (as a batch of size 1) to p.Model
// and returns the first value of the first output row. The caller compares
// this value against p.Threshold.
func (p *Plugin) predictInjection(tokens []float32) (float32, error) {
	inputTensor, err := tf.NewTensor([][]float32{tokens})
	if err != nil {
		return 0, fmt.Errorf("creating input tensor: %w", err)
	}
	output, err := p.Model.Session.Run(
		map[tf.Output]*tf.Tensor{
			p.Model.Graph.Operation("serving_default_embedding_input").Output(0): inputTensor,
		},
		[]tf.Output{
			p.Model.Graph.Operation("StatefulPartitionedCall").Output(0),
		},
		nil,
	)
	if err != nil {
		return 0, fmt.Errorf("running model: %w", err)
	}
	if len(output) == 0 {
		return 0, fmt.Errorf("model returned no output tensors")
	}
	// Checked assertion and bounds: an unexpected output shape is reported as
	// an error instead of panicking the plugin process.
	predictions, ok := output[0].Value().([][]float32)
	if !ok || len(predictions) == 0 || len(predictions[0]) == 0 {
		return 0, fmt.Errorf("unexpected model output of type %T", output[0].Value())
	}
	return predictions[0][0], nil
}