Skip to content

Commit ded019d

Browse files
ckganesanantonmedv
authored andcommitted
Enable Support for Arrays in Sum, Mean, and Median Functions (#580)
1 parent 355fb28 commit ded019d

File tree

4 files changed

+206
-200
lines changed

4 files changed

+206
-200
lines changed

builtin/builtin.go

+29-172
Original file line numberDiff line numberDiff line change
@@ -135,42 +135,21 @@ var Builtins = []*Function{
135135
Name: "ceil",
136136
Fast: Ceil,
137137
Validate: func(args []reflect.Type) (reflect.Type, error) {
138-
if len(args) != 1 {
139-
return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
140-
}
141-
switch kind(args[0]) {
142-
case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface:
143-
return floatType, nil
144-
}
145-
return anyType, fmt.Errorf("invalid argument for ceil (type %s)", args[0])
138+
return validateRoundFunc("ceil", args)
146139
},
147140
},
148141
{
149142
Name: "floor",
150143
Fast: Floor,
151144
Validate: func(args []reflect.Type) (reflect.Type, error) {
152-
if len(args) != 1 {
153-
return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
154-
}
155-
switch kind(args[0]) {
156-
case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface:
157-
return floatType, nil
158-
}
159-
return anyType, fmt.Errorf("invalid argument for floor (type %s)", args[0])
145+
return validateRoundFunc("floor", args)
160146
},
161147
},
162148
{
163149
Name: "round",
164150
Fast: Round,
165151
Validate: func(args []reflect.Type) (reflect.Type, error) {
166-
if len(args) != 1 {
167-
return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
168-
}
169-
switch kind(args[0]) {
170-
case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface:
171-
return floatType, nil
172-
}
173-
return anyType, fmt.Errorf("invalid argument for floor (type %s)", args[0])
152+
return validateRoundFunc("round", args)
174153
},
175154
},
176155
{
@@ -392,185 +371,63 @@ var Builtins = []*Function{
392371
},
393372
{
394373
Name: "max",
395-
Func: Max,
374+
Func: func(args ...any) (any, error) {
375+
return minMax("max", runtime.Less, args...)
376+
},
396377
Validate: func(args []reflect.Type) (reflect.Type, error) {
397-
switch len(args) {
398-
case 0:
399-
return anyType, fmt.Errorf("not enough arguments to call max")
400-
case 1:
401-
if kindName := kind(args[0]); kindName == reflect.Array || kindName == reflect.Slice {
402-
return anyType, nil
403-
}
404-
fallthrough
405-
default:
406-
for _, arg := range args {
407-
switch kind(arg) {
408-
case reflect.Interface, reflect.Array, reflect.Slice:
409-
return anyType, nil
410-
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64:
411-
default:
412-
return anyType, fmt.Errorf("invalid argument for max (type %s)", arg)
413-
}
414-
}
415-
return args[0], nil
416-
}
378+
return validateAggregateFunc("max", args)
417379
},
418380
},
419381
{
420382
Name: "min",
421-
Func: Min,
383+
Func: func(args ...any) (any, error) {
384+
return minMax("min", runtime.More, args...)
385+
},
422386
Validate: func(args []reflect.Type) (reflect.Type, error) {
423-
switch len(args) {
424-
case 0:
425-
return anyType, fmt.Errorf("not enough arguments to call min")
426-
case 1:
427-
if kindName := kind(args[0]); kindName == reflect.Array || kindName == reflect.Slice {
428-
return anyType, nil
429-
}
430-
fallthrough
431-
default:
432-
for _, arg := range args {
433-
switch kind(arg) {
434-
case reflect.Interface, reflect.Array, reflect.Slice:
435-
return anyType, nil
436-
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64:
437-
default:
438-
return anyType, fmt.Errorf("invalid argument for min (type %s)", arg)
439-
}
440-
}
441-
return args[0], nil
442-
443-
}
387+
return validateAggregateFunc("min", args)
444388
},
445389
},
446390
{
447391
Name: "sum",
448-
Func: func(args ...any) (any, error) {
449-
if len(args) != 1 {
450-
return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
451-
}
452-
v := reflect.ValueOf(args[0])
453-
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
454-
return nil, fmt.Errorf("cannot sum %s", v.Kind())
455-
}
456-
sum := int64(0)
457-
i := 0
458-
for ; i < v.Len(); i++ {
459-
it := deref.Value(v.Index(i))
460-
if it.CanInt() {
461-
sum += it.Int()
462-
} else if it.CanFloat() {
463-
goto float
464-
} else {
465-
return nil, fmt.Errorf("cannot sum %s", it.Kind())
466-
}
467-
}
468-
return int(sum), nil
469-
float:
470-
fSum := float64(sum)
471-
for ; i < v.Len(); i++ {
472-
it := deref.Value(v.Index(i))
473-
if it.CanInt() {
474-
fSum += float64(it.Int())
475-
} else if it.CanFloat() {
476-
fSum += it.Float()
477-
} else {
478-
return nil, fmt.Errorf("cannot sum %s", it.Kind())
479-
}
480-
}
481-
return fSum, nil
482-
},
392+
Func: sum,
483393
Validate: func(args []reflect.Type) (reflect.Type, error) {
484-
if len(args) != 1 {
485-
return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
486-
}
487-
switch kind(args[0]) {
488-
case reflect.Interface, reflect.Slice, reflect.Array:
489-
default:
490-
return anyType, fmt.Errorf("cannot sum %s", args[0])
491-
}
492-
return anyType, nil
394+
return validateAggregateFunc("sum", args)
493395
},
494396
},
495397
{
496398
Name: "mean",
497399
Func: func(args ...any) (any, error) {
498-
if len(args) != 1 {
499-
return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
500-
}
501-
v := reflect.ValueOf(args[0])
502-
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
503-
return nil, fmt.Errorf("cannot mean %s", v.Kind())
400+
count, sum, err := mean(args...)
401+
if err != nil {
402+
return nil, err
504403
}
505-
if v.Len() == 0 {
404+
if count == 0 {
506405
return 0.0, nil
507406
}
508-
sum := float64(0)
509-
i := 0
510-
for ; i < v.Len(); i++ {
511-
it := deref.Value(v.Index(i))
512-
if it.CanInt() {
513-
sum += float64(it.Int())
514-
} else if it.CanFloat() {
515-
sum += it.Float()
516-
} else {
517-
return nil, fmt.Errorf("cannot mean %s", it.Kind())
518-
}
519-
}
520-
return sum / float64(i), nil
407+
return sum / float64(count), nil
521408
},
522409
Validate: func(args []reflect.Type) (reflect.Type, error) {
523-
if len(args) != 1 {
524-
return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
525-
}
526-
switch kind(args[0]) {
527-
case reflect.Interface, reflect.Slice, reflect.Array:
528-
default:
529-
return anyType, fmt.Errorf("cannot avg %s", args[0])
530-
}
531-
return floatType, nil
410+
return validateAggregateFunc("mean", args)
532411
},
533412
},
534413
{
535414
Name: "median",
536415
Func: func(args ...any) (any, error) {
537-
if len(args) != 1 {
538-
return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
539-
}
540-
v := reflect.ValueOf(args[0])
541-
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
542-
return nil, fmt.Errorf("cannot median %s", v.Kind())
543-
}
544-
if v.Len() == 0 {
545-
return 0.0, nil
416+
values, err := median(args...)
417+
if err != nil {
418+
return nil, err
546419
}
547-
s := make([]float64, v.Len())
548-
for i := 0; i < v.Len(); i++ {
549-
it := deref.Value(v.Index(i))
550-
if it.CanInt() {
551-
s[i] = float64(it.Int())
552-
} else if it.CanFloat() {
553-
s[i] = it.Float()
554-
} else {
555-
return nil, fmt.Errorf("cannot median %s", it.Kind())
420+
if n := len(values); n > 0 {
421+
sort.Float64s(values)
422+
if n%2 == 1 {
423+
return values[n/2], nil
556424
}
425+
return (values[n/2-1] + values[n/2]) / 2, nil
557426
}
558-
sort.Float64s(s)
559-
if len(s)%2 == 0 {
560-
return (s[len(s)/2-1] + s[len(s)/2]) / 2, nil
561-
}
562-
return s[len(s)/2], nil
427+
return 0.0, nil
563428
},
564429
Validate: func(args []reflect.Type) (reflect.Type, error) {
565-
if len(args) != 1 {
566-
return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
567-
}
568-
switch kind(args[0]) {
569-
case reflect.Interface, reflect.Slice, reflect.Array:
570-
default:
571-
return anyType, fmt.Errorf("cannot median %s", args[0])
572-
}
573-
return floatType, nil
430+
return validateAggregateFunc("median", args)
574431
},
575432
},
576433
{

builtin/builtin_test.go

+13
Original file line numberDiff line numberDiff line change
@@ -85,19 +85,29 @@ func TestBuiltin(t *testing.T) {
8585
{`min(1.5, 2.5, 3.5)`, 1.5},
8686
{`min([1, 2, 3])`, 1},
8787
{`min([1.5, 2.5, 3.5])`, 1.5},
88+
{`min(-1, [1.5, 2.5, 3.5])`, -1},
8889
{`sum(1..9)`, 45},
8990
{`sum([.5, 1.5, 2.5])`, 4.5},
9091
{`sum([])`, 0},
9192
{`sum([1, 2, 3.0, 4])`, 10.0},
93+
{`sum(10, [1, 2, 3], 1..9)`, 61},
94+
{`sum(-10, [1, 2, 3, 4])`, 0},
95+
{`sum(-10.9, [1, 2, 3, 4, 9])`, 8.1},
9296
{`mean(1..9)`, 5.0},
9397
{`mean([.5, 1.5, 2.5])`, 1.5},
9498
{`mean([])`, 0.0},
9599
{`mean([1, 2, 3.0, 4])`, 2.5},
100+
{`mean(10, [1, 2, 3], 1..9)`, 4.6923076923076925},
101+
{`mean(-10, [1, 2, 3, 4])`, 0.0},
102+
{`mean(10.9, 1..9)`, 5.59},
96103
{`median(1..9)`, 5.0},
97104
{`median([.5, 1.5, 2.5])`, 1.5},
98105
{`median([])`, 0.0},
99106
{`median([1, 2, 3])`, 2.0},
100107
{`median([1, 2, 3, 4])`, 2.5},
108+
{`median(10, [1, 2, 3], 1..9)`, 4.0},
109+
{`median(-10, [1, 2, 3, 4])`, 2.0},
110+
{`median(1..5, 4.9)`, 3.5},
101111
{`toJSON({foo: 1, bar: 2})`, "{\n \"bar\": 2,\n \"foo\": 1\n}"},
102112
{`fromJSON("[1, 2, 3]")`, []any{1.0, 2.0, 3.0}},
103113
{`toBase64("hello")`, "aGVsbG8="},
@@ -207,6 +217,9 @@ func TestBuiltin_errors(t *testing.T) {
207217
{`min()`, `not enough arguments to call min`},
208218
{`min(1, "2")`, `invalid argument for min (type string)`},
209219
{`min([1, "2"])`, `invalid argument for min (type string)`},
220+
{`median(1..9, "t")`, "invalid argument for median (type string)"},
221+
{`mean("s", 1..9)`, "invalid argument for mean (type string)"},
222+
{`sum("s", "h")`, "invalid argument for sum (type string)"},
210223
{`duration("error")`, `invalid duration`},
211224
{`date("error")`, `invalid date`},
212225
{`get()`, `invalid number of arguments (expected 2, got 0)`},

0 commit comments

Comments
 (0)