diff --git a/lang/funcs/collection.go b/lang/funcs/collection.go index 4324e589d..33073d747 100644 --- a/lang/funcs/collection.go +++ b/lang/funcs/collection.go @@ -3,13 +3,13 @@ package funcs import ( "errors" "fmt" + "math/big" "sort" "github.com/zclconf/go-cty/cty" "github.com/zclconf/go-cty/cty/convert" "github.com/zclconf/go-cty/cty/function" "github.com/zclconf/go-cty/cty/function/stdlib" - "github.com/zclconf/go-cty/cty/gocty" ) var LengthFunc = function.New(&function.Spec{ @@ -404,27 +404,45 @@ var SumFunc = function.New(&function.Spec{ arg := args[0].AsValueSlice() ty := args[0].Type() - var i float64 - var s float64 - if !ty.IsListType() && !ty.IsSetType() && !ty.IsTupleType() { return cty.NilVal, function.NewArgErrorf(0, fmt.Sprintf("argument must be list, set, or tuple. Received %s", ty.FriendlyName())) } - if !args[0].IsKnown() { + if !args[0].IsWhollyKnown() { return cty.UnknownVal(cty.Number), nil } - for _, v := range arg { - - if err := gocty.FromCtyValue(v, &i); err != nil { - return cty.UnknownVal(cty.Number), function.NewArgErrorf(0, "argument must be list, set, or tuple of number values") - } else { - s += i + // big.Float.Add can panic if the input values are opposing infinities, + // so we must catch that here in order to remain within + // the cty Function abstraction. + defer func() { + if r := recover(); r != nil { + if _, ok := r.(big.ErrNaN); ok { + ret = cty.NilVal + err = fmt.Errorf("can't compute sum of opposing infinities") + } else { + // not a panic we recognize + panic(r) + } } + }() + + s := arg[0] + if s.IsNull() { + return cty.NilVal, function.NewArgErrorf(0, "argument must be list, set, or tuple of number values") + } + for _, v := range arg[1:] { + if v.IsNull() { + return cty.NilVal, function.NewArgErrorf(0, "argument must be list, set, or tuple of number values") + } + v, err = convert.Convert(v, cty.Number) + if err != nil { + return cty.NilVal, function.NewArgErrorf(0, "argument must be list, set, or tuple of number values") + } + s = s.Add(v) } - return cty.NumberFloatVal(s), nil + return s, nil }, }) diff --git a/lang/funcs/collection_test.go b/lang/funcs/collection_test.go index 457ef7f73..0b61738ac 100644 --- a/lang/funcs/collection_test.go +++ b/lang/funcs/collection_test.go @@ -2,6 +2,7 @@ package funcs import ( "fmt" + "math" "testing" "github.com/zclconf/go-cty/cty" @@ -996,7 +997,7 @@ func TestSum(t *testing.T) { tests := []struct { List cty.Value Want cty.Value - Err bool + Err string }{ { cty.ListVal([]cty.Value{ @@ -1005,7 +1006,7 @@ func TestSum(t *testing.T) { cty.NumberIntVal(3), }), cty.NumberIntVal(6), - false, + "", }, { cty.ListVal([]cty.Value{ @@ -1016,7 +1017,7 @@ func TestSum(t *testing.T) { cty.NumberIntVal(234), }), cty.NumberIntVal(66685532), - false, + "", }, { cty.ListVal([]cty.Value{ @@ -1025,7 +1026,7 @@ func TestSum(t *testing.T) { cty.StringVal("c"), }), cty.UnknownVal(cty.String), - true, + "argument must be list, set, or tuple of number values", }, { cty.ListVal([]cty.Value{ @@ -1034,7 +1035,7 @@ func TestSum(t *testing.T) { cty.NumberIntVal(5), }), cty.NumberIntVal(-4), - false, + "", }, { cty.ListVal([]cty.Value{ @@ -1043,7 +1044,7 @@ func TestSum(t *testing.T) { cty.NumberFloatVal(5.7), }), cty.NumberFloatVal(35.3), - false, + "", }, { cty.ListVal([]cty.Value{ @@ -1052,12 +1053,20 @@ func TestSum(t *testing.T) { cty.NumberFloatVal(-5.7), }), cty.NumberFloatVal(-35.3), - false, + "", }, { cty.ListVal([]cty.Value{cty.NullVal(cty.Number)}), cty.NilVal, - true, + "argument must be list, set, or tuple of number values", + }, + { + cty.ListVal([]cty.Value{ + cty.NumberIntVal(5), + cty.NullVal(cty.Number), + }), + cty.NilVal, + "argument must be list, set, or tuple of number values", }, { cty.SetVal([]cty.Value{ @@ -1066,7 +1075,7 @@ func TestSum(t *testing.T) { cty.StringVal("c"), }), cty.UnknownVal(cty.String), - true, + "argument must be list, set, or tuple of number values", }, { cty.SetVal([]cty.Value{ @@ -1075,7 +1084,7 @@ func TestSum(t *testing.T) { cty.NumberIntVal(5), }), cty.NumberIntVal(-4), - false, + "", }, { cty.SetVal([]cty.Value{ @@ -1084,7 +1093,7 @@ func TestSum(t *testing.T) { cty.NumberIntVal(30), }), cty.NumberIntVal(65), - false, + "", }, { cty.SetVal([]cty.Value{ @@ -1093,14 +1102,14 @@ func TestSum(t *testing.T) { cty.NumberFloatVal(3), }), cty.NumberFloatVal(2354), - false, + "", }, { cty.SetVal([]cty.Value{ cty.NumberFloatVal(2), }), cty.NumberFloatVal(2), - false, + "", }, { cty.SetVal([]cty.Value{ @@ -1111,7 +1120,7 @@ func TestSum(t *testing.T) { cty.NumberFloatVal(-4), }), cty.NumberFloatVal(-199), - false, + "", }, { cty.TupleVal([]cty.Value{ @@ -1120,27 +1129,53 @@ func TestSum(t *testing.T) { cty.NumberIntVal(38), }), cty.UnknownVal(cty.String), - true, + "argument must be list, set, or tuple of number values", }, { cty.NumberIntVal(12), cty.NilVal, - true, + "cannot sum noniterable", }, { cty.ListValEmpty(cty.Number), cty.NilVal, - true, + "cannot sum an empty list", }, { cty.MapVal(map[string]cty.Value{"hello": cty.True}), cty.NilVal, - true, + "argument must be list, set, or tuple. Received map of bool", }, { cty.UnknownVal(cty.Number), cty.UnknownVal(cty.Number), - false, + "", + }, + { + cty.UnknownVal(cty.List(cty.Number)), + cty.UnknownVal(cty.Number), + "", + }, + { // known list containing unknown values + cty.ListVal([]cty.Value{cty.UnknownVal(cty.Number)}), + cty.UnknownVal(cty.Number), + "", + }, + { // numbers too large to represent as float64 + cty.ListVal([]cty.Value{ + cty.MustParseNumberVal("1e+500"), + cty.MustParseNumberVal("1e+500"), + }), + cty.MustParseNumberVal("2e+500"), + "", + }, + { // edge case we have a special error handler for + cty.ListVal([]cty.Value{ + cty.NumberFloatVal(math.Inf(1)), + cty.NumberFloatVal(math.Inf(-1)), + }), + cty.NilVal, + "can't compute sum of opposing infinities", }, } @@ -1148,9 +1183,11 @@ func TestSum(t *testing.T) { t.Run(fmt.Sprintf("sum(%#v)", test.List), func(t *testing.T) { got, err := Sum(test.List) - if test.Err { + if test.Err != "" { if err == nil { t.Fatal("succeeded; want error") + } else if got, want := err.Error(), test.Err; got != want { + t.Fatalf("wrong error\n got: %s\nwant: %s", got, want) } return } else if err != nil {