diff --git a/terraform/test-fixtures/transform-provider-prune/main.tf b/terraform/test-fixtures/transform-provider-prune/main.tf new file mode 100644 index 000000000..986f8840b --- /dev/null +++ b/terraform/test-fixtures/transform-provider-prune/main.tf @@ -0,0 +1,2 @@ +provider "aws" {} +resource "foo_instance" "web" {} diff --git a/terraform/transform_provider.go b/terraform/transform_provider.go index 7a6220df0..a7bcca478 100644 --- a/terraform/transform_provider.go +++ b/terraform/transform_provider.go @@ -71,6 +71,27 @@ func (t *MissingProviderTransformer) Transform(g *Graph) error { return nil } +// PruneProviderTransformer is a GraphTransformer that prunes all the +// providers that aren't needed from the graph. A provider is unneeded if +// no resource or module is using that provider. +type PruneProviderTransformer struct{} + +func (t *PruneProviderTransformer) Transform(g *Graph) error { + for _, v := range g.Vertices() { + // We only care about the providers + if _, ok := v.(GraphNodeProvider); !ok { + continue + } + + // Does anything depend on this? If not, then prune it. + if s := g.UpEdges(v); s.Len() == 0 { + g.Remove(v) + } + } + + return nil +} + type graphNodeMissingProvider struct { ProviderNameValue string } diff --git a/terraform/transform_provider_test.go b/terraform/transform_provider_test.go index 2389a9e58..fdf25b7a4 100644 --- a/terraform/transform_provider_test.go +++ b/terraform/transform_provider_test.go @@ -53,6 +53,45 @@ func TestMissingProviderTransformer(t *testing.T) { } } +func TestPruneProviderTransformer(t *testing.T) { + mod := testModule(t, "transform-provider-prune") + + g := Graph{Path: RootModulePath} + { + tf := &ConfigTransformer{Module: mod} + if err := tf.Transform(&g); err != nil { + t.Fatalf("err: %s", err) + } + } + + { + transform := &MissingProviderTransformer{Providers: []string{"foo"}} + if err := transform.Transform(&g); err != nil { + t.Fatalf("err: %s", err) + } + } + + { + transform := &ProviderTransformer{} + if err := transform.Transform(&g); err != nil { + t.Fatalf("err: %s", err) + } + } + + { + transform := &PruneProviderTransformer{} + if err := transform.Transform(&g); err != nil { + t.Fatalf("err: %s", err) + } + } + + actual := strings.TrimSpace(g.String()) + expected := strings.TrimSpace(testTransformPruneProviderBasicStr) + if actual != expected { + t.Fatalf("bad:\n\n%s", actual) + } +} + func TestGraphNodeMissingProvider_impl(t *testing.T) { var _ dag.Vertex = new(graphNodeMissingProvider) var _ dag.NamedVertex = new(graphNodeMissingProvider) @@ -77,3 +116,9 @@ aws_instance.web provider.aws provider.foo ` + +const testTransformPruneProviderBasicStr = ` +foo_instance.web + provider.foo +provider.foo +`