From f27dae2ab76a9081ac72b4d52f2f0303fed8e6de Mon Sep 17 00:00:00 2001 From: Alisdair McDiarmid Date: Thu, 10 Dec 2020 16:56:06 -0500 Subject: [PATCH] lang: Improved robustness of sum function Fixes error when calling sum with values not known until apply time. Also allows sum to cope with numbers too large to represent in float64, along with correctly handling errors when trying to sum opposing infinities. --- lang/funcs/collection.go | 42 +++++++++++++------ lang/funcs/collection_test.go | 77 ++++++++++++++++++++++++++--------- 2 files changed, 87 insertions(+), 32 deletions(-) diff --git a/lang/funcs/collection.go b/lang/funcs/collection.go index 4324e589d..33073d747 100644 --- a/lang/funcs/collection.go +++ b/lang/funcs/collection.go @@ -3,13 +3,13 @@ package funcs import ( "errors" "fmt" + "math/big" "sort" "github.com/zclconf/go-cty/cty" "github.com/zclconf/go-cty/cty/convert" "github.com/zclconf/go-cty/cty/function" "github.com/zclconf/go-cty/cty/function/stdlib" - "github.com/zclconf/go-cty/cty/gocty" ) var LengthFunc = function.New(&function.Spec{ @@ -404,27 +404,45 @@ var SumFunc = function.New(&function.Spec{ arg := args[0].AsValueSlice() ty := args[0].Type() - var i float64 - var s float64 - if !ty.IsListType() && !ty.IsSetType() && !ty.IsTupleType() { return cty.NilVal, function.NewArgErrorf(0, fmt.Sprintf("argument must be list, set, or tuple. Received %s", ty.FriendlyName())) } - if !args[0].IsKnown() { + if !args[0].IsWhollyKnown() { return cty.UnknownVal(cty.Number), nil } - for _, v := range arg { - - if err := gocty.FromCtyValue(v, &i); err != nil { - return cty.UnknownVal(cty.Number), function.NewArgErrorf(0, "argument must be list, set, or tuple of number values") - } else { - s += i + // big.Float.Add can panic if the input values are opposing infinities, + // so we must catch that here in order to remain within + // the cty Function abstraction. + defer func() { + if r := recover(); r != nil { + if _, ok := r.(big.ErrNaN); ok { + ret = cty.NilVal + err = fmt.Errorf("can't compute sum of opposing infinities") + } else { + // not a panic we recognize + panic(r) + } } + }() + + s := arg[0] + if s.IsNull() { + return cty.NilVal, function.NewArgErrorf(0, "argument must be list, set, or tuple of number values") + } + for _, v := range arg[1:] { + if v.IsNull() { + return cty.NilVal, function.NewArgErrorf(0, "argument must be list, set, or tuple of number values") + } + v, err = convert.Convert(v, cty.Number) + if err != nil { + return cty.NilVal, function.NewArgErrorf(0, "argument must be list, set, or tuple of number values") + } + s = s.Add(v) } - return cty.NumberFloatVal(s), nil + return s, nil }, }) diff --git a/lang/funcs/collection_test.go b/lang/funcs/collection_test.go index 457ef7f73..0b61738ac 100644 --- a/lang/funcs/collection_test.go +++ b/lang/funcs/collection_test.go @@ -2,6 +2,7 @@ package funcs import ( "fmt" + "math" "testing" "github.com/zclconf/go-cty/cty" @@ -996,7 +997,7 @@ func TestSum(t *testing.T) { tests := []struct { List cty.Value Want cty.Value - Err bool + Err string }{ { cty.ListVal([]cty.Value{ @@ -1005,7 +1006,7 @@ func TestSum(t *testing.T) { cty.NumberIntVal(3), }), cty.NumberIntVal(6), - false, + "", }, { cty.ListVal([]cty.Value{ @@ -1016,7 +1017,7 @@ func TestSum(t *testing.T) { cty.NumberIntVal(234), }), cty.NumberIntVal(66685532), - false, + "", }, { cty.ListVal([]cty.Value{ @@ -1025,7 +1026,7 @@ func TestSum(t *testing.T) { cty.StringVal("c"), }), cty.UnknownVal(cty.String), - true, + "argument must be list, set, or tuple of number values", }, { cty.ListVal([]cty.Value{ @@ -1034,7 +1035,7 @@ func TestSum(t *testing.T) { cty.NumberIntVal(5), }), cty.NumberIntVal(-4), - false, + "", }, { cty.ListVal([]cty.Value{ @@ -1043,7 +1044,7 @@ func TestSum(t *testing.T) { cty.NumberFloatVal(5.7), }), cty.NumberFloatVal(35.3), - false, + "", }, { cty.ListVal([]cty.Value{ @@ -1052,12 +1053,20 @@ func TestSum(t *testing.T) { cty.NumberFloatVal(-5.7), }), cty.NumberFloatVal(-35.3), - false, + "", }, { cty.ListVal([]cty.Value{cty.NullVal(cty.Number)}), cty.NilVal, - true, + "argument must be list, set, or tuple of number values", + }, + { + cty.ListVal([]cty.Value{ + cty.NumberIntVal(5), + cty.NullVal(cty.Number), + }), + cty.NilVal, + "argument must be list, set, or tuple of number values", }, { cty.SetVal([]cty.Value{ @@ -1066,7 +1075,7 @@ func TestSum(t *testing.T) { cty.StringVal("c"), }), cty.UnknownVal(cty.String), - true, + "argument must be list, set, or tuple of number values", }, { cty.SetVal([]cty.Value{ @@ -1075,7 +1084,7 @@ func TestSum(t *testing.T) { cty.NumberIntVal(5), }), cty.NumberIntVal(-4), - false, + "", }, { cty.SetVal([]cty.Value{ @@ -1084,7 +1093,7 @@ func TestSum(t *testing.T) { cty.NumberIntVal(30), }), cty.NumberIntVal(65), - false, + "", }, { cty.SetVal([]cty.Value{ @@ -1093,14 +1102,14 @@ func TestSum(t *testing.T) { cty.NumberFloatVal(3), }), cty.NumberFloatVal(2354), - false, + "", }, { cty.SetVal([]cty.Value{ cty.NumberFloatVal(2), }), cty.NumberFloatVal(2), - false, + "", }, { cty.SetVal([]cty.Value{ @@ -1111,7 +1120,7 @@ func TestSum(t *testing.T) { cty.NumberFloatVal(-4), }), cty.NumberFloatVal(-199), - false, + "", }, { cty.TupleVal([]cty.Value{ @@ -1120,27 +1129,53 @@ func TestSum(t *testing.T) { cty.NumberIntVal(38), }), cty.UnknownVal(cty.String), - true, + "argument must be list, set, or tuple of number values", }, { cty.NumberIntVal(12), cty.NilVal, - true, + "cannot sum noniterable", }, { cty.ListValEmpty(cty.Number), cty.NilVal, - true, + "cannot sum an empty list", }, { cty.MapVal(map[string]cty.Value{"hello": cty.True}), cty.NilVal, - true, + "argument must be list, set, or tuple. Received map of bool", }, { cty.UnknownVal(cty.Number), cty.UnknownVal(cty.Number), - false, + "", + }, + { + cty.UnknownVal(cty.List(cty.Number)), + cty.UnknownVal(cty.Number), + "", + }, + { // known list containing unknown values + cty.ListVal([]cty.Value{cty.UnknownVal(cty.Number)}), + cty.UnknownVal(cty.Number), + "", + }, + { // numbers too large to represent as float64 + cty.ListVal([]cty.Value{ + cty.MustParseNumberVal("1e+500"), + cty.MustParseNumberVal("1e+500"), + }), + cty.MustParseNumberVal("2e+500"), + "", + }, + { // edge case we have a special error handler for + cty.ListVal([]cty.Value{ + cty.NumberFloatVal(math.Inf(1)), + cty.NumberFloatVal(math.Inf(-1)), + }), + cty.NilVal, + "can't compute sum of opposing infinities", }, } @@ -1148,9 +1183,11 @@ func TestSum(t *testing.T) { t.Run(fmt.Sprintf("sum(%#v)", test.List), func(t *testing.T) { got, err := Sum(test.List) - if test.Err { + if test.Err != "" { if err == nil { t.Fatal("succeeded; want error") + } else if got, want := err.Error(), test.Err; got != want { + t.Fatalf("wrong error\n got: %s\nwant: %s", got, want) } return } else if err != nil {