diff --git a/config/module/test-fixtures/child/foo/bar/main.tf b/config/module/test-fixtures/child/foo/bar/main.tf new file mode 100644 index 000000000..df5927501 --- /dev/null +++ b/config/module/test-fixtures/child/foo/bar/main.tf @@ -0,0 +1,2 @@ +# Hello + diff --git a/config/module/test-fixtures/child/foo/main.tf b/config/module/test-fixtures/child/foo/main.tf new file mode 100644 index 000000000..548d21b99 --- /dev/null +++ b/config/module/test-fixtures/child/foo/main.tf @@ -0,0 +1,5 @@ +# Hello + +module "bar" { + source = "./bar" +} diff --git a/config/module/test-fixtures/child/main.tf b/config/module/test-fixtures/child/main.tf new file mode 100644 index 000000000..383063715 --- /dev/null +++ b/config/module/test-fixtures/child/main.tf @@ -0,0 +1,5 @@ +# Hello + +module "foo" { + source = "./foo" +} diff --git a/config/module/tree.go b/config/module/tree.go index bb2afc16e..33afbd262 100644 --- a/config/module/tree.go +++ b/config/module/tree.go @@ -67,6 +67,24 @@ func (t *Tree) Config() *config.Config { return t.config } +// Child returns the child with the given path (by name). +func (t *Tree) Child(path []string) *Tree { + if len(path) == 0 { + return nil + } + + c := t.Children()[path[0]] + if c == nil { + return nil + } + + if len(path) == 1 { + return c + } + + return c.Child(path[1:]) +} + // Children returns the children of this tree (the modules that are // imported by this root). // diff --git a/config/module/tree_test.go b/config/module/tree_test.go index 837519d27..ae718a261 100644 --- a/config/module/tree_test.go +++ b/config/module/tree_test.go @@ -6,6 +6,28 @@ import ( "testing" ) +func TestTreeChild(t *testing.T) { + storage := testStorage(t) + tree := NewTree("", testConfig(t, "child")) + if err := tree.Load(storage, GetModeGet); err != nil { + t.Fatalf("err: %s", err) + } + + // Should be able to get the foo child + if c := tree.Child([]string{"foo"}); c == nil { + t.Fatal("should not be nil") + } else if c.Name() != "foo" { + t.Fatalf("bad: %#v", c.Name()) + } + + // Should be able to get the nested child + if c := tree.Child([]string{"foo", "bar"}); c == nil { + t.Fatal("should not be nil") + } else if c.Name() != "bar" { + t.Fatalf("bad: %#v", c.Name()) + } +} + func TestTreeLoad(t *testing.T) { storage := testStorage(t) tree := NewTree("", testConfig(t, "basic"))