-
Notifications
You must be signed in to change notification settings - Fork 362
/
Copy pathkernelone.py
47 lines (37 loc) · 1.16 KB
/
kernelone.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
import onnx
from onnx import TensorProto
from onnx.checker import check_model
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
def create_custom_operator(output_path="custom_kernel_one_model.onnx"):
    """Build, validate, print, and save a one-node ONNX model using ``CustomOpOne``.

    The graph has two float inputs (``X`` and ``Y``) and one float output
    (``Z``), each declared with shape ``[3]``, connected by a single
    ``CustomOpOne`` node that lives in the custom operator domain ``v1``.

    Args:
        output_path: Destination passed to ``onnx.save``. Defaults to
            ``"custom_kernel_one_model.onnx"`` (relative to the current
            working directory), which is the path the script always used.

    Returns:
        The constructed ``onnx.ModelProto``.

    Raises:
        onnx.checker.ValidationError: If the assembled model does not pass
            ``onnx.checker.check_model``.
    """
    custom_domain = "v1"  # Custom domain name
    # NOTE(review): version 1 is assumed to be the version under which the
    # runtime registers the "v1" custom-op domain -- confirm against the
    # custom-op library that implements CustomOpOne.
    custom_domain_version = 1

    # Define input and output names
    input_names = ["X", "Y"]
    output_names = ["Z"]

    # Create a custom operator node
    custom_op_node = onnx.helper.make_node(
        "CustomOpOne",  # Custom operator name
        input_names,
        output_names,
        domain=custom_domain,
    )

    # Create an ONNX graph; every input and output is a float tensor of shape [3]
    graph = onnx.helper.make_graph(
        [custom_op_node],
        "custom_opone_model",
        [
            onnx.helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT, [3])
            for name in input_names
        ],
        [
            onnx.helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT, [3])
            for name in output_names
        ],
    )

    # Create the ONNX model. make_model() only declares the default ONNX
    # opset, so the custom domain used by the node has to be declared
    # explicitly; a model must list every operator set it relies on,
    # otherwise the checker rejects it with "No opset import for domain".
    model = onnx.helper.make_model(graph)
    model.opset_import.extend(
        [onnx.helper.make_opsetid(custom_domain, custom_domain_version)]
    )

    # With the custom domain declared the model is structurally valid, so
    # validate it before it is written out.
    onnx.checker.check_model(model)

    print(model)

    # Save the model to a file
    onnx.save(model, output_path)
    return model
if __name__ == "__main__":
    # Script entry point: build the custom-operator model, print it, and
    # save it to disk. Nothing runs when this module is merely imported.
    create_custom_operator()