config: DetectVariables to detect interpolated variables in an AST

This commit is contained in:
Mitchell Hashimoto 2015-01-12 12:09:30 -08:00
parent c05d7a6acd
commit aa2c7b2764
2 changed files with 86 additions and 0 deletions

View File

@ -5,6 +5,8 @@ import (
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"github.com/hashicorp/terraform/config/lang/ast"
) )
// We really need to replace this with a real parser. // We really need to replace this with a real parser.
@ -317,3 +319,39 @@ func (v *UserVariable) FullKey() string {
func (v *UserVariable) GoString() string { func (v *UserVariable) GoString() string {
return fmt.Sprintf("*%#v", *v) return fmt.Sprintf("*%#v", *v)
} }
// DetectVariables takes an AST root and returns all the interpolated
// variables that are detected in the AST tree.
func DetectVariables(root ast.Node) ([]InterpolatedVariable, error) {
var result []InterpolatedVariable
var resultErr error
// Visitor callback
fn := func(n ast.Node) {
if resultErr != nil {
return
}
vn, ok := n.(*ast.VariableAccess)
if !ok {
return
}
v, err := NewInterpolatedVariable(vn.Name)
if err != nil {
resultErr = err
return
}
result = append(result, v)
}
// Visitor pattern
root.Accept(fn)
if resultErr != nil {
return nil, resultErr
}
return result, nil
}

View File

@ -4,6 +4,8 @@ import (
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
"github.com/hashicorp/terraform/config/lang"
) )
func TestNewInterpolatedVariable(t *testing.T) { func TestNewInterpolatedVariable(t *testing.T) {
@ -291,3 +293,49 @@ func TestVariableInterpolation_missing(t *testing.T) {
t.Fatal("should error") t.Fatal("should error")
} }
} }
func TestDetectVariables(t *testing.T) {
cases := []struct {
Input string
Result []InterpolatedVariable
}{
{
"foo ${var.foo}",
[]InterpolatedVariable{
&UserVariable{
Name: "foo",
key: "var.foo",
},
},
},
{
"foo ${var.foo} ${var.bar}",
[]InterpolatedVariable{
&UserVariable{
Name: "foo",
key: "var.foo",
},
&UserVariable{
Name: "bar",
key: "var.bar",
},
},
},
}
for _, tc := range cases {
ast, err := lang.Parse(tc.Input)
if err != nil {
t.Fatalf("%s\n\nInput: %s", err, tc.Input)
}
actual, err := DetectVariables(ast)
if err != nil {
t.Fatalf("err: %s", err)
}
if !reflect.DeepEqual(actual, tc.Result) {
t.Fatalf("bad: %#v\n\nInput: %s", actual, tc.Input)
}
}
}