Skip to content

Commit d71dab7

Browse files
authored
fix ConstExpr to handle custom functions returned error (#239)
1 parent fba6d31 commit d71dab7

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

expr_test.go

+26-1
Original file line numberDiff line numberDiff line change
@@ -1160,7 +1160,7 @@ func TestExpr_call_floatarg_func_with_int(t *testing.T) {
11601160
}
11611161
}
11621162

1163-
func TestConstExpr_error(t *testing.T) {
1163+
func TestConstExpr_error_panic(t *testing.T) {
11641164
env := map[string]interface{}{
11651165
"divide": func(a, b int) int { return a / b },
11661166
}
@@ -1174,6 +1174,31 @@ func TestConstExpr_error(t *testing.T) {
11741174
require.Equal(t, "compile error: integer divide by zero (1:5)\n | 1 + divide(1, 0)\n | ....^", err.Error())
11751175
}
11761176

1177+
type divideError struct{ Message string }
1178+
1179+
func (e divideError) Error() string {
1180+
return e.Message
1181+
}
1182+
func TestConstExpr_error_as_error(t *testing.T) {
1183+
env := map[string]interface{}{
1184+
"divide": func(a, b int) (int, error) {
1185+
if b == 0 {
1186+
return 0, divideError{"integer divide by zero"}
1187+
}
1188+
return a / b, nil
1189+
},
1190+
}
1191+
1192+
_, err := expr.Compile(
1193+
`1 + divide(1, 0)`,
1194+
expr.Env(env),
1195+
expr.ConstExpr("divide"),
1196+
)
1197+
require.Error(t, err)
1198+
require.Equal(t, "integer divide by zero", err.Error())
1199+
require.IsType(t, divideError{}, err)
1200+
}
1201+
11771202
func TestConstExpr_error_wrong_type(t *testing.T) {
11781203
env := map[string]interface{}{
11791204
"divide": 0,

optimizer/const_expr.go

+11-3
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@ package optimizer
22

33
import (
44
"fmt"
5-
. "github.com/antonmedv/expr/ast"
6-
"github.com/antonmedv/expr/file"
75
"reflect"
86
"strings"
7+
8+
. "github.com/antonmedv/expr/ast"
9+
"github.com/antonmedv/expr/file"
910
)
1011

12+
var errorType = reflect.TypeOf((*error)(nil)).Elem()
13+
1114
type constExpr struct {
1215
applied bool
1316
err error
@@ -70,7 +73,12 @@ func (c *constExpr) Exit(node *Node) {
7073
}
7174

7275
out := fn.Call(in)
73-
constNode := &ConstantNode{Value: out[0].Interface()}
76+
value := out[0].Interface()
77+
if len(out) == 2 && out[1].Type() == errorType && !out[1].IsNil() {
78+
c.err = out[1].Interface().(error)
79+
return
80+
}
81+
constNode := &ConstantNode{Value: value}
7482
patch(constNode)
7583
}
7684
}

0 commit comments

Comments
 (0)