Merge pull request #27249 from hashicorp/alisdair/sum-func-robustness

lang: Improved robustness of sum function
This commit is contained in:
Alisdair McDiarmid 2020-12-11 09:29:14 -05:00 committed by GitHub
commit 9b0af78f24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 87 additions and 32 deletions

View File

@ -3,13 +3,13 @@ package funcs
import ( import (
"errors" "errors"
"fmt" "fmt"
"math/big"
"sort" "sort"
"github.com/zclconf/go-cty/cty" "github.com/zclconf/go-cty/cty"
"github.com/zclconf/go-cty/cty/convert" "github.com/zclconf/go-cty/cty/convert"
"github.com/zclconf/go-cty/cty/function" "github.com/zclconf/go-cty/cty/function"
"github.com/zclconf/go-cty/cty/function/stdlib" "github.com/zclconf/go-cty/cty/function/stdlib"
"github.com/zclconf/go-cty/cty/gocty"
) )
var LengthFunc = function.New(&function.Spec{ var LengthFunc = function.New(&function.Spec{
@ -404,27 +404,45 @@ var SumFunc = function.New(&function.Spec{
arg := args[0].AsValueSlice() arg := args[0].AsValueSlice()
ty := args[0].Type() ty := args[0].Type()
var i float64
var s float64
if !ty.IsListType() && !ty.IsSetType() && !ty.IsTupleType() { if !ty.IsListType() && !ty.IsSetType() && !ty.IsTupleType() {
return cty.NilVal, function.NewArgErrorf(0, fmt.Sprintf("argument must be list, set, or tuple. Received %s", ty.FriendlyName())) return cty.NilVal, function.NewArgErrorf(0, fmt.Sprintf("argument must be list, set, or tuple. Received %s", ty.FriendlyName()))
} }
if !args[0].IsKnown() { if !args[0].IsWhollyKnown() {
return cty.UnknownVal(cty.Number), nil return cty.UnknownVal(cty.Number), nil
} }
for _, v := range arg { // big.Float.Add can panic if the input values are opposing infinities,
// so we must catch that here in order to remain within
if err := gocty.FromCtyValue(v, &i); err != nil { // the cty Function abstraction.
return cty.UnknownVal(cty.Number), function.NewArgErrorf(0, "argument must be list, set, or tuple of number values") defer func() {
if r := recover(); r != nil {
if _, ok := r.(big.ErrNaN); ok {
ret = cty.NilVal
err = fmt.Errorf("can't compute sum of opposing infinities")
} else { } else {
s += i // not a panic we recognize
panic(r)
} }
} }
}()
s := arg[0]
if s.IsNull() {
return cty.NilVal, function.NewArgErrorf(0, "argument must be list, set, or tuple of number values")
}
for _, v := range arg[1:] {
if v.IsNull() {
return cty.NilVal, function.NewArgErrorf(0, "argument must be list, set, or tuple of number values")
}
v, err = convert.Convert(v, cty.Number)
if err != nil {
return cty.NilVal, function.NewArgErrorf(0, "argument must be list, set, or tuple of number values")
}
s = s.Add(v)
}
return cty.NumberFloatVal(s), nil return s, nil
}, },
}) })

View File

@ -2,6 +2,7 @@ package funcs
import ( import (
"fmt" "fmt"
"math"
"testing" "testing"
"github.com/zclconf/go-cty/cty" "github.com/zclconf/go-cty/cty"
@ -996,7 +997,7 @@ func TestSum(t *testing.T) {
tests := []struct { tests := []struct {
List cty.Value List cty.Value
Want cty.Value Want cty.Value
Err bool Err string
}{ }{
{ {
cty.ListVal([]cty.Value{ cty.ListVal([]cty.Value{
@ -1005,7 +1006,7 @@ func TestSum(t *testing.T) {
cty.NumberIntVal(3), cty.NumberIntVal(3),
}), }),
cty.NumberIntVal(6), cty.NumberIntVal(6),
false, "",
}, },
{ {
cty.ListVal([]cty.Value{ cty.ListVal([]cty.Value{
@ -1016,7 +1017,7 @@ func TestSum(t *testing.T) {
cty.NumberIntVal(234), cty.NumberIntVal(234),
}), }),
cty.NumberIntVal(66685532), cty.NumberIntVal(66685532),
false, "",
}, },
{ {
cty.ListVal([]cty.Value{ cty.ListVal([]cty.Value{
@ -1025,7 +1026,7 @@ func TestSum(t *testing.T) {
cty.StringVal("c"), cty.StringVal("c"),
}), }),
cty.UnknownVal(cty.String), cty.UnknownVal(cty.String),
true, "argument must be list, set, or tuple of number values",
}, },
{ {
cty.ListVal([]cty.Value{ cty.ListVal([]cty.Value{
@ -1034,7 +1035,7 @@ func TestSum(t *testing.T) {
cty.NumberIntVal(5), cty.NumberIntVal(5),
}), }),
cty.NumberIntVal(-4), cty.NumberIntVal(-4),
false, "",
}, },
{ {
cty.ListVal([]cty.Value{ cty.ListVal([]cty.Value{
@ -1043,7 +1044,7 @@ func TestSum(t *testing.T) {
cty.NumberFloatVal(5.7), cty.NumberFloatVal(5.7),
}), }),
cty.NumberFloatVal(35.3), cty.NumberFloatVal(35.3),
false, "",
}, },
{ {
cty.ListVal([]cty.Value{ cty.ListVal([]cty.Value{
@ -1052,12 +1053,20 @@ func TestSum(t *testing.T) {
cty.NumberFloatVal(-5.7), cty.NumberFloatVal(-5.7),
}), }),
cty.NumberFloatVal(-35.3), cty.NumberFloatVal(-35.3),
false, "",
}, },
{ {
cty.ListVal([]cty.Value{cty.NullVal(cty.Number)}), cty.ListVal([]cty.Value{cty.NullVal(cty.Number)}),
cty.NilVal, cty.NilVal,
true, "argument must be list, set, or tuple of number values",
},
{
cty.ListVal([]cty.Value{
cty.NumberIntVal(5),
cty.NullVal(cty.Number),
}),
cty.NilVal,
"argument must be list, set, or tuple of number values",
}, },
{ {
cty.SetVal([]cty.Value{ cty.SetVal([]cty.Value{
@ -1066,7 +1075,7 @@ func TestSum(t *testing.T) {
cty.StringVal("c"), cty.StringVal("c"),
}), }),
cty.UnknownVal(cty.String), cty.UnknownVal(cty.String),
true, "argument must be list, set, or tuple of number values",
}, },
{ {
cty.SetVal([]cty.Value{ cty.SetVal([]cty.Value{
@ -1075,7 +1084,7 @@ func TestSum(t *testing.T) {
cty.NumberIntVal(5), cty.NumberIntVal(5),
}), }),
cty.NumberIntVal(-4), cty.NumberIntVal(-4),
false, "",
}, },
{ {
cty.SetVal([]cty.Value{ cty.SetVal([]cty.Value{
@ -1084,7 +1093,7 @@ func TestSum(t *testing.T) {
cty.NumberIntVal(30), cty.NumberIntVal(30),
}), }),
cty.NumberIntVal(65), cty.NumberIntVal(65),
false, "",
}, },
{ {
cty.SetVal([]cty.Value{ cty.SetVal([]cty.Value{
@ -1093,14 +1102,14 @@ func TestSum(t *testing.T) {
cty.NumberFloatVal(3), cty.NumberFloatVal(3),
}), }),
cty.NumberFloatVal(2354), cty.NumberFloatVal(2354),
false, "",
}, },
{ {
cty.SetVal([]cty.Value{ cty.SetVal([]cty.Value{
cty.NumberFloatVal(2), cty.NumberFloatVal(2),
}), }),
cty.NumberFloatVal(2), cty.NumberFloatVal(2),
false, "",
}, },
{ {
cty.SetVal([]cty.Value{ cty.SetVal([]cty.Value{
@ -1111,7 +1120,7 @@ func TestSum(t *testing.T) {
cty.NumberFloatVal(-4), cty.NumberFloatVal(-4),
}), }),
cty.NumberFloatVal(-199), cty.NumberFloatVal(-199),
false, "",
}, },
{ {
cty.TupleVal([]cty.Value{ cty.TupleVal([]cty.Value{
@ -1120,27 +1129,53 @@ func TestSum(t *testing.T) {
cty.NumberIntVal(38), cty.NumberIntVal(38),
}), }),
cty.UnknownVal(cty.String), cty.UnknownVal(cty.String),
true, "argument must be list, set, or tuple of number values",
}, },
{ {
cty.NumberIntVal(12), cty.NumberIntVal(12),
cty.NilVal, cty.NilVal,
true, "cannot sum noniterable",
}, },
{ {
cty.ListValEmpty(cty.Number), cty.ListValEmpty(cty.Number),
cty.NilVal, cty.NilVal,
true, "cannot sum an empty list",
}, },
{ {
cty.MapVal(map[string]cty.Value{"hello": cty.True}), cty.MapVal(map[string]cty.Value{"hello": cty.True}),
cty.NilVal, cty.NilVal,
true, "argument must be list, set, or tuple. Received map of bool",
}, },
{ {
cty.UnknownVal(cty.Number), cty.UnknownVal(cty.Number),
cty.UnknownVal(cty.Number), cty.UnknownVal(cty.Number),
false, "",
},
{
cty.UnknownVal(cty.List(cty.Number)),
cty.UnknownVal(cty.Number),
"",
},
{ // known list containing unknown values
cty.ListVal([]cty.Value{cty.UnknownVal(cty.Number)}),
cty.UnknownVal(cty.Number),
"",
},
{ // numbers too large to represent as float64
cty.ListVal([]cty.Value{
cty.MustParseNumberVal("1e+500"),
cty.MustParseNumberVal("1e+500"),
}),
cty.MustParseNumberVal("2e+500"),
"",
},
{ // edge case we have a special error handler for
cty.ListVal([]cty.Value{
cty.NumberFloatVal(math.Inf(1)),
cty.NumberFloatVal(math.Inf(-1)),
}),
cty.NilVal,
"can't compute sum of opposing infinities",
}, },
} }
@ -1148,9 +1183,11 @@ func TestSum(t *testing.T) {
t.Run(fmt.Sprintf("sum(%#v)", test.List), func(t *testing.T) { t.Run(fmt.Sprintf("sum(%#v)", test.List), func(t *testing.T) {
got, err := Sum(test.List) got, err := Sum(test.List)
if test.Err { if test.Err != "" {
if err == nil { if err == nil {
t.Fatal("succeeded; want error") t.Fatal("succeeded; want error")
} else if got, want := err.Error(), test.Err; got != want {
t.Fatalf("wrong error\n got: %s\nwant: %s", got, want)
} }
return return
} else if err != nil { } else if err != nil {