diff --git a/config/lang/ast/ast.go b/config/lang/ast/ast.go index 31951621f..dfe863f99 100644 --- a/config/lang/ast/ast.go +++ b/config/lang/ast/ast.go @@ -6,8 +6,9 @@ import ( // Node is the interface that all AST nodes must implement. type Node interface { - // Accept is called to dispatch to the visitors. - Accept(Visitor) + // Accept is called to dispatch to the visitors. It must return the + // resulting Node (which might be different in an AST transform). + Accept(Visitor) Node // Pos returns the position of this node in some source. Pos() Pos @@ -24,11 +25,18 @@ func (p Pos) String() string { // Visitors are just implementations of this function. // +// The function must return the Node to replace this node with. "nil" is +// _not_ a valid return value. If there is no replacement, the original node +// should be returned. We build this replacement directly into the visitor +// pattern since AST transformations are a common and useful tool and +// building it into the AST itself makes it required for future Node +// implementations and very easy to do. +// // Note that this isn't a true implementation of the visitor pattern, which // generally requires proper type dispatch on the function. However, // implementing this basic visitor pattern style is still very useful even // if you have to type switch. -type Visitor func(Node) +type Visitor func(Node) Node //go:generate stringer -type=Type diff --git a/config/lang/ast/call.go b/config/lang/ast/call.go index 40b0773e1..bbb632b7b 100644 --- a/config/lang/ast/call.go +++ b/config/lang/ast/call.go @@ -12,12 +12,12 @@ type Call struct { Posx Pos } -func (n *Call) Accept(v Visitor) { - for _, a := range n.Args { - a.Accept(v) +func (n *Call) Accept(v Visitor) Node { + for i, a := range n.Args { + n.Args[i] = a.Accept(v) } - v(n) + return v(n) } func (n *Call) Pos() Pos { diff --git a/config/lang/ast/concat.go b/config/lang/ast/concat.go index 238912697..871b0f44a 100644 --- a/config/lang/ast/concat.go +++ b/config/lang/ast/concat.go @@ -12,12 +12,12 @@ type Concat struct { Posx Pos } -func (n *Concat) Accept(v Visitor) { - for _, n := range n.Exprs { - n.Accept(v) +func (n *Concat) Accept(v Visitor) Node { + for i, expr := range n.Exprs { + n.Exprs[i] = expr.Accept(v) } - v(n) + return v(n) } func (n *Concat) Pos() Pos { diff --git a/config/lang/ast/literal.go b/config/lang/ast/literal.go index 1fd7669ff..b314fcc21 100644 --- a/config/lang/ast/literal.go +++ b/config/lang/ast/literal.go @@ -12,8 +12,8 @@ type LiteralNode struct { Posx Pos } -func (n *LiteralNode) Accept(v Visitor) { - v(n) +func (n *LiteralNode) Accept(v Visitor) Node { + return v(n) } func (n *LiteralNode) Pos() Pos { diff --git a/config/lang/ast/variable_access.go b/config/lang/ast/variable_access.go index 1f86a260d..148094a6a 100644 --- a/config/lang/ast/variable_access.go +++ b/config/lang/ast/variable_access.go @@ -10,8 +10,8 @@ type VariableAccess struct { Posx Pos } -func (n *VariableAccess) Accept(v Visitor) { - v(n) +func (n *VariableAccess) Accept(v Visitor) Node { + return v(n) } func (n *VariableAccess) Pos() Pos { diff --git a/config/lang/check_identifier.go b/config/lang/check_identifier.go index 2e467c098..10ee2267d 100644 --- a/config/lang/check_identifier.go +++ b/config/lang/check_identifier.go @@ -25,9 +25,9 @@ func (c *IdentifierCheck) Visit(root ast.Node) error { return c.err } -func (c *IdentifierCheck) visit(raw ast.Node) { +func (c *IdentifierCheck) visit(raw ast.Node) ast.Node { if c.err != nil { - return + return raw } switch n := raw.(type) { @@ -42,6 +42,9 @@ func (c *IdentifierCheck) visit(raw ast.Node) { default: c.createErr(n, fmt.Sprintf("unknown node: %#v", raw)) } + + // We never do replacement with this visitor + return raw } func (c *IdentifierCheck) visitCall(n *ast.Call) { diff --git a/config/lang/check_types.go b/config/lang/check_types.go index 9aec9af19..4491ea496 100644 --- a/config/lang/check_types.go +++ b/config/lang/check_types.go @@ -37,9 +37,9 @@ func (v *TypeCheck) Visit(root ast.Node) error { return v.err } -func (v *TypeCheck) visit(raw ast.Node) { +func (v *TypeCheck) visit(raw ast.Node) ast.Node { if v.err != nil { - return + return raw } switch n := raw.(type) { @@ -54,6 +54,8 @@ func (v *TypeCheck) visit(raw ast.Node) { default: v.createErr(n, fmt.Sprintf("unknown node: %#v", raw)) } + + return raw } func (v *TypeCheck) visitCall(n *ast.Call) { diff --git a/config/lang/engine.go b/config/lang/engine.go index 23d4ca6f7..b18db0f39 100644 --- a/config/lang/engine.go +++ b/config/lang/engine.go @@ -105,9 +105,9 @@ func (v *executeVisitor) Visit(root ast.Node) (interface{}, ast.Type, error) { return result.Value, result.Type, resultErr } -func (v *executeVisitor) visit(raw ast.Node) { +func (v *executeVisitor) visit(raw ast.Node) ast.Node { if v.err != nil { - return + return raw } switch n := raw.(type) { @@ -122,6 +122,8 @@ func (v *executeVisitor) visit(raw ast.Node) { default: v.err = fmt.Errorf("unknown node: %#v", raw) } + + return raw } func (v *executeVisitor) visitCall(n *ast.Call) { diff --git a/config/lang/transform_implicit_types_test.go b/config/lang/types_test.go similarity index 85% rename from config/lang/transform_implicit_types_test.go rename to config/lang/types_test.go index 9eb0fd92f..74513e24a 100644 --- a/config/lang/transform_implicit_types_test.go +++ b/config/lang/types_test.go @@ -64,11 +64,11 @@ func TestLookupType(t *testing.T) { type customUntyped struct{} -func (n customUntyped) Accept(ast.Visitor) {} -func (n customUntyped) Pos() (v ast.Pos) { return } +func (n customUntyped) Accept(ast.Visitor) ast.Node { return n } +func (n customUntyped) Pos() (v ast.Pos) { return } type customTyped struct{} -func (n customTyped) Accept(ast.Visitor) {} +func (n customTyped) Accept(ast.Visitor) ast.Node { return n } func (n customTyped) Pos() (v ast.Pos) { return } func (n customTyped) Type(*Scope) (ast.Type, error) { return ast.TypeString, nil }