diff --git a/lang/funcs/collection.go b/lang/funcs/collection.go index 1257895e3..ab68a6411 100644 --- a/lang/funcs/collection.go +++ b/lang/funcs/collection.go @@ -800,10 +800,12 @@ var MatchkeysFunc = function.New(&function.Spec{ }, }, Type: func(args []cty.Value) (cty.Type, error) { - if !args[1].Type().Equals(args[2].Type()) { - return cty.NilType, errors.New("lists must be of the same type") + ty, _ := convert.UnifyUnsafe([]cty.Type{args[1].Type(), args[2].Type()}) + if ty == cty.NilType { + return cty.NilType, errors.New("keys and searchset must be of the same type") } + // the return type is based on args[0] (values) return args[0].Type(), nil }, Impl: func(args []cty.Value, retType cty.Type) (ret cty.Value, err error) { @@ -816,10 +818,14 @@ var MatchkeysFunc = function.New(&function.Spec{ } output := make([]cty.Value, 0) - values := args[0] - keys := args[1] - searchset := args[2] + + // Keys and searchset must be the same type. + // We can skip error checking here because we've already verified that + // they can be unified in the Type function + ty, _ := convert.UnifyUnsafe([]cty.Type{args[1].Type(), args[2].Type()}) + keys, _ := convert.Convert(args[1], ty) + searchset, _ := convert.Convert(args[2], ty) // if searchset is empty, return an empty list. if searchset.LengthInt() == 0 { diff --git a/lang/funcs/collection_test.go b/lang/funcs/collection_test.go index e408d385b..581e212ff 100644 --- a/lang/funcs/collection_test.go +++ b/lang/funcs/collection_test.go @@ -1883,8 +1883,7 @@ func TestMatchkeys(t *testing.T) { cty.UnknownVal(cty.List(cty.String)), false, }, - // errors - { // different types + { // different types that can be unified cty.ListVal([]cty.Value{ cty.StringVal("a"), }), @@ -1894,9 +1893,41 @@ func TestMatchkeys(t *testing.T) { cty.ListVal([]cty.Value{ cty.StringVal("a"), }), - cty.NilVal, - true, + cty.ListValEmpty(cty.String), + false, }, + { // complex values: values is a different type from keys and searchset + cty.ListVal([]cty.Value{ + cty.MapVal(map[string]cty.Value{ + "foo": cty.StringVal("bar"), + }), + cty.MapVal(map[string]cty.Value{ + "foo": cty.StringVal("baz"), + }), + cty.MapVal(map[string]cty.Value{ + "foo": cty.StringVal("beep"), + }), + }), + cty.ListVal([]cty.Value{ + cty.StringVal("a"), + cty.StringVal("b"), + cty.StringVal("c"), + }), + cty.ListVal([]cty.Value{ + cty.StringVal("a"), + cty.StringVal("c"), + }), + cty.ListVal([]cty.Value{ + cty.MapVal(map[string]cty.Value{ + "foo": cty.StringVal("bar"), + }), + cty.MapVal(map[string]cty.Value{ + "foo": cty.StringVal("beep"), + }), + }), + false, + }, + // errors { // different types cty.ListVal([]cty.Value{ cty.StringVal("a"), diff --git a/lang/functions_test.go b/lang/functions_test.go index 5fdca59d3..8717ec7cd 100644 --- a/lang/functions_test.go +++ b/lang/functions_test.go @@ -459,6 +459,13 @@ func TestFunctions(t *testing.T) { cty.StringVal("a"), }), }, + { // mixing types in searchset + `matchkeys(["a", "b", "c"], [1, 2, 3], [1, "3"])`, + cty.ListVal([]cty.Value{ + cty.StringVal("a"), + cty.StringVal("c"), + }), + }, }, "max": {