
Commit 6c3e421

Arm backend: Add GELU operator (#10109)
- Add GELU decomposition pass
- BI handled by table op pass
- Add tests
- xfail partitioning test with task to revisit

Signed-off-by: Iliyan Georgiev <Iliyan.Georgiev@arm.com>
1 parent c2e5d23 commit 6c3e421

8 files changed (+300 −1 lines)

backends/arm/_passes/__init__.py (+1)

@@ -20,6 +20,7 @@
 from .convert_to_clamp import ConvertToClampPass  # noqa
 from .decompose_batchnorm_pass import DecomposeBatchNormPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
+from .decompose_gelu_pass import DecomposeGeluPass  # noqa
 from .decompose_layernorm_pass import DecomposeLayerNormPass  # noqa
 from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass  # noqa
 from .decompose_linear_pass import DecomposeLinearPass  # noqa

backends/arm/_passes/arm_pass_manager.py (+2)

@@ -25,6 +25,7 @@
     ConvertToClampPass,
     DecomposeBatchNormPass,
     DecomposeDivPass,
+    DecomposeGeluPass,
     DecomposeLayerNormPass,
     DecomposeLeakyReLUPass,
     DecomposeLinearPass,
@@ -132,6 +133,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertMeanDimToAveragePoolPass())
         self.add_pass(DecomposeDivPass())
         self.add_pass(DecomposeSoftmaxPass())
+        self.add_pass(DecomposeGeluPass())
         self.add_pass(ConvertFullLikeToFullPass())
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
backends/arm/_passes/decompose_gelu_pass.py (new file, +149)

# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

torch_gelu = (torch.ops.aten.gelu.default,)

edge_gelu = (exir_ops.edge.aten.gelu.default,)


def _get_gelu_ops(op) -> tuple:
    """
    Returns the operators needed to decompose GELU
    """

    if op in edge_gelu:
        return (
            exir_ops.edge.aten.full.default,
            exir_ops.edge.aten.add.Tensor,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten.tanh.default,
            exir_ops.edge.aten.erf.default,
        )
    if op in torch_gelu:
        return (
            torch.ops.aten.full.default,
            torch.ops.aten.add.Tensor,
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.tanh.default,
            torch.ops.aten.erf.default,
        )
    raise RuntimeError(f"Can't get GeLU decomposition ops for op {op}")


class DecomposeGeluPass(ExportPass):
    """
    This pass decomposes the GELU operator into primitive ops,
    aiming to adhere closely to the reference implementations built into
    ExecuTorch, including using the same pre-calculated constants.

    This operator has two formulae depending on the value of the
    approximate argument. The examples below include the full
    operators added to initialize the constants used in
    each respective formula.

    aten.gelu(x, approximate="none") becomes:
        %FULL_0_5 = full()
        %FULL_1 = full()
        %FULL_SQRT1_2 = full()
        %op1 = mul(x, %FULL_SQRT1_2)
        %op2 = erf(%op1)
        %op3 = add(%op2, %FULL_1)
        %op4 = mul(%op3, %FULL_0_5)
        %op5 = mul(%x, %op4)

    aten.gelu(x, approximate="tanh") becomes:
        %FULL_0_5 = full()
        %FULL_1 = full()
        %FULL_SQRT2 = full()
        %FULL_2_SQRTPI = full()
        %FULL_CUBE_COEFF = full()
        %SQRT_MUL = mul(%FULL_SQRT2, %FULL_2_SQRTPI)
        %SQRT_2_PI = mul(%SQRT_MUL, %FULL_0_5)
        %sqr_x = mul(x, x)
        %cube_x = mul(sqr_x, x)
        %op1 = mul(%cube_x, %FULL_CUBE_COEFF)
        %op2 = add(%x, %op1)
        %op3 = mul(%op2, %SQRT_2_PI)
        %op4 = tanh(%op3)
        %op5 = add(%op4, %FULL_1)
        %op6 = mul(%x, %op5)
        %op7 = mul(%op6, %FULL_0_5)
    """

    def call_operator(self, op, args, kwargs, meta):
        if op not in torch_gelu + edge_gelu:
            return super().call_operator(op, args, kwargs, meta)

        full_op, add_op, mul_op, tanh_op, erf_op = _get_gelu_ops(op)

        input = get_node_arg(args, 0)
        # If approximate is default (none) it does not appear in kwargs
        approximate = get_node_arg(kwargs, "approximate", "none")

        shape = meta["val"].size()
        dtype = meta["val"].dtype

        FULL_0_5 = super().call_operator(
            full_op, ([1] * len(shape), 0.5), {"dtype": dtype}, meta
        )
        FULL_1 = super().call_operator(
            full_op, ([1] * len(shape), 1), {"dtype": dtype}, meta
        )

        if approximate == "none":
            # Constant mirrors ExecuTorch implementation for parity.
            FULL_SQRT1_2 = super().call_operator(
                full_op, ([1] * len(shape), 0.70710678118654752440), {}, meta
            )

            op1 = super().call_operator(mul_op, (input, FULL_SQRT1_2), {}, meta)
            op2 = super().call_operator(erf_op, (op1,), {}, meta)
            op3 = super().call_operator(add_op, (op2, FULL_1), {}, meta)
            op4 = super().call_operator(mul_op, (op3, FULL_0_5), {}, meta)
            return super().call_operator(mul_op, (input, op4), {}, meta)

        elif approximate == "tanh":
            # Constants mirror ExecuTorch implementation for parity.
            FULL_SQRT2 = super().call_operator(
                full_op,
                ([1] * len(shape), 1.41421356237309504880),
                {"dtype": dtype},
                meta,
            )
            FULL_2_SQRTPI = super().call_operator(
                full_op,
                ([1] * len(shape), 1.12837916709551257390),
                {"dtype": dtype},
                meta,
            )
            FULL_CUBE_COEFF = super().call_operator(
                full_op, ([1] * len(shape), 0.044715), {"dtype": dtype}, meta
            )

            # Mirrors ExecuTorch implementations for calculating this value
            SQRT_MUL = super().call_operator(
                mul_op, (FULL_SQRT2, FULL_2_SQRTPI), {}, meta
            )
            SQRT_2_PI = super().call_operator(mul_op, (SQRT_MUL, FULL_0_5), {}, meta)

            # Avoid using POW in order to reduce pass ordering reliance.
            sqr_x = super().call_operator(mul_op, (input, input), {}, meta)
            cube_x = super().call_operator(mul_op, (sqr_x, input), {}, meta)
            op1 = super().call_operator(mul_op, (cube_x, FULL_CUBE_COEFF), {}, meta)
            op2 = super().call_operator(add_op, (input, op1), {}, meta)
            op3 = super().call_operator(mul_op, (op2, SQRT_2_PI), {}, meta)
            op4 = super().call_operator(tanh_op, (op3,), {}, meta)
            op5 = super().call_operator(add_op, (op4, FULL_1), {}, meta)
            op6 = super().call_operator(mul_op, (input, op5), {}, meta)
            return super().call_operator(mul_op, (op6, FULL_0_5), {}, meta)
        else:
            raise RuntimeError(
                f"approximate argument expected 'none' or 'tanh' but got {approximate}"
            )

backends/arm/_passes/insert_table_ops.py (+14)

@@ -56,6 +56,7 @@ class TableOps:
     # Targets that must be treated explicitly
     special_table_ops: Set[EdgeOpOverload] = {
         exir_ops.edge.aten.pow.Tensor_Scalar,
+        exir_ops.edge.aten.gelu.default,
     }

     def __init__(self, exported_program: ExportedProgram):
@@ -76,6 +77,19 @@ def __getitem__(self, node: Node):
                 # Exponent is a constant. Embed it into a lambda.
                 exp = cast(int, node.args[1])
                 return lambda x: torch.pow(x, exp).flatten()
+            case exir_ops.edge.aten.gelu.default:
+                # If kwargs not present it is default "none"
+                approximate = cast(
+                    str,
+                    (
+                        node.kwargs["approximate"]
+                        if "approximate" in node.kwargs
+                        else "none"
+                    ),
+                )
+                return lambda x: torch.nn.functional.gelu(
+                    x, approximate=approximate
+                ).flatten()
             case _:
                 # Op must be handled if it's inside self.special_ops
                 raise AssertionError("Unhandled table operation")
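
The BI path does not run the decomposition above; instead the gelu node becomes a table lookup. As a rough, hypothetical illustration of why returning torch.nn.functional.gelu(x, approximate=approximate).flatten() is sufficient there, the sketch below (our own helper with simplified quantization-parameter handling; the real table plumbing lives in InsertTableOpsPass) evaluates the operator once for every possible int8 input:

import torch


def build_gelu_table(in_scale: float, in_zero_point: int, approximate: str = "none") -> torch.Tensor:
    # Dequantize every representable int8 value, then apply the reference op.
    q = torch.arange(-128, 128, dtype=torch.int32)
    x = (q - in_zero_point).to(torch.float32) * in_scale
    return torch.nn.functional.gelu(x, approximate=approximate).flatten()


# 256 float outputs, one per int8 input; these would then be requantized.
table = build_gelu_table(in_scale=0.1, in_zero_point=0, approximate="tanh")
assert table.shape == (256,)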

backends/arm/operator_support/tosa_supported_operators.py (+2)

@@ -225,6 +225,7 @@ def is_node_supported(
             exir_ops.edge.aten.bitwise_left_shift.Tensor,
             exir_ops.edge.aten.__lshift__.Scalar,
             torch.ops.aten.scalar_tensor.default,
+            exir_ops.edge.aten.gelu.default,
         ]

         return supported
@@ -361,6 +362,7 @@ def is_node_supported(
             exir_ops.edge.aten.sub.Tensor,
             exir_ops.edge.aten.tanh.default,
             exir_ops.edge.aten.upsample_nearest2d.vec,
+            exir_ops.edge.aten.gelu.default,
         ):
             return True
         elif node.target in (

backends/arm/quantizer/quantization_annotator.py (+1)

@@ -178,6 +178,7 @@ def _match_pattern(
     torch.ops.aten.hardswish_.default,
     torch.ops.aten.full_like.default,
     torch.ops.aten.pow.Tensor_Scalar,
+    torch.ops.aten.gelu.default,
 ]

 _one_to_one_shared_input_qspec = [

backends/arm/test/misc/test_partition_decomposed_quantized_ops.py (+6 −1)

@@ -117,7 +117,12 @@ def test_softplus_tosa_BI(test_data: input_t1):
 # Since GELU will not be quantized by TosaQuantizer, the Dropout's input will not be quantized either.
 # If so, the Dropout should not be partitioned by TosaPartitioner for TOSA BI profile. This test tests that the
 # partitioner indeed does not partition the Dropout (clone) for TOSA BI.
-@common.parametrize("test_data", test_data)
+@common.parametrize(
+    "test_data",
+    test_data,
+    {"3d_rand": "MLETORCH-909: Partition test to not rely on unsupported ops"},
+    strict=False,
+)
 def test_linear_residaul_tosa_MI(test_data: input_t1):
     pipeline = TosaPipelineMI[input_t1](
         LinearResidualModule(),

backends/arm/test/ops/test_gelu.py (new file, +125)

# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
    EthosU55PipelineBI,
    EthosU85PipelineBI,
    TosaPipelineBI,
    TosaPipelineMI,
)

input_t1 = Tuple[torch.Tensor]


class Gelu(torch.nn.Module):
    aten_op = "torch.ops.aten.gelu.default"
    exir_op = "executorch_exir_dialects_edge__ops_aten_gelu_default"

    test_data: dict[str, Tuple[str, input_t1]] = {
        "zeros_none": (
            "none",
            torch.zeros(1, 10, 10, 10),
        ),
        "ones_none": (
            "none",
            torch.ones(10, 10, 10),
        ),
        "rand_none": (
            "none",
            (torch.rand(10, 10) - 0.5),
        ),
        "randn_pos_none": (
            "none",
            (torch.randn(1, 4, 4, 4) + 10),
        ),
        "randn_neg_none": (
            "none",
            (torch.randn(1, 4, 4, 4) - 10),
        ),
        "ramp_none": (
            "none",
            torch.arange(-16, 16, 0.2),
        ),
        "zeros_tanh": (
            "tanh",
            torch.zeros(1, 10, 10, 10),
        ),
        "ones_tanh": (
            "tanh",
            torch.ones(10, 10, 10),
        ),
        "rand_tanh": (
            "tanh",
            (torch.rand(10, 10) - 0.5),
        ),
        "randn_pos_tanh": (
            "tanh",
            (torch.randn(1, 4, 4, 4) + 10),
        ),
        "randn_neg_tanh": (
            "tanh",
            (torch.randn(1, 4, 4, 4) - 10),
        ),
        "ramp_tanh": (
            "tanh",
            torch.arange(-16, 16, 0.2),
        ),
    }

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.gelu = torch.nn.GELU(approximate)

    def forward(self, x: torch.Tensor):
        return self.gelu(x)


@common.parametrize("test_data", Gelu.test_data)
def test_gelu_tosa_MI(test_data: input_t1):
    approximate = test_data[0]
    TosaPipelineMI[input_t1](
        Gelu(approximate),
        (test_data[1],),
        Gelu.aten_op,
        Gelu.exir_op,
        use_to_edge_transform_and_lower=False,
    ).run()


@common.parametrize("test_data", Gelu.test_data)
def test_gelu_tosa_BI(test_data: input_t1):
    approximate = test_data[0]
    TosaPipelineBI[input_t1](
        Gelu(approximate),
        (test_data[1],),
        Gelu.aten_op,
        Gelu.exir_op,
    ).run()


@common.parametrize("test_data", Gelu.test_data)
def test_gelu_u55_BI(test_data: input_t1):
    approximate = test_data[0]
    EthosU55PipelineBI[input_t1](
        Gelu(approximate),
        (test_data[1],),
        Gelu.aten_op,
        Gelu.exir_op,
    ).run()


@common.parametrize("test_data", Gelu.test_data)
def test_gelu_u85_BI(test_data: input_t1):
    approximate = test_data[0]
    EthosU85PipelineBI[input_t1](
        Gelu(approximate),
        (test_data[1],),
        Gelu.aten_op,
        Gelu.exir_op,
    ).run()
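
For context on what these pipelines consume, here is a small sketch (assuming a recent torch.export API; not part of the test file) that exports the module and confirms the graph contains the single aten.gelu.default node targeted by the decomposition and table passes:

import torch


class Gelu(torch.nn.Module):
    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.gelu = torch.nn.GELU(approximate)

    def forward(self, x: torch.Tensor):
        return self.gelu(x)


exported = torch.export.export(Gelu("tanh"), (torch.rand(10, 10) - 0.5,))
gelu_nodes = [
    n for n in exported.graph.nodes if n.target == torch.ops.aten.gelu.default
]
assert len(gelu_nodes) == 1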
