From 851e6dcdbb0f80c62da2ad9977d1cbe63975fff7 Mon Sep 17 00:00:00 2001 From: Kristin Laemmert Date: Tue, 11 Feb 2020 13:17:37 -0500 Subject: [PATCH] configs: added map to configs.Module for provider local name lookup (#24039) * configs: added map of ProviderLocalNames to configs.Module We will need to lookup any user-supplied local names for a given FQN. This PR adds a map of ProviderLocalNames to the Module, along with adding tests for this and for decodeRequiredProvidersBlock. This also introduces the appearance of support for a required_provider "source" attribute, but ignores any user-supplied source and instead continues to assume that addrs.NewLegacyProvider is the way to go. --- configs/module.go | 40 +++++- configs/module_test.go | 34 +++++ configs/provider_requirements.go | 16 +-- configs/provider_requirements_test.go | 185 ++++++++++++++++++++++++++ 4 files changed, 263 insertions(+), 12 deletions(-) create mode 100644 configs/module_test.go create mode 100644 configs/provider_requirements_test.go diff --git a/configs/module.go b/configs/module.go index 010dc9aee..8bc45f375 100644 --- a/configs/module.go +++ b/configs/module.go @@ -31,6 +31,7 @@ type Module struct { Backend *Backend ProviderConfigs map[string]*Provider ProviderRequirements map[string]ProviderRequirements + ProviderLocalNames map[addrs.Provider]string Variables map[string]*Variable Locals map[string]*Local @@ -85,6 +86,7 @@ func NewModule(primaryFiles, overrideFiles []*File) (*Module, hcl.Diagnostics) { mod := &Module{ ProviderConfigs: map[string]*Provider{}, ProviderRequirements: map[string]ProviderRequirements{}, + ProviderLocalNames: map[addrs.Provider]string{}, Variables: map[string]*Variable{}, Locals: map[string]*Local{}, Outputs: map[string]*Output{}, @@ -105,6 +107,9 @@ func NewModule(primaryFiles, overrideFiles []*File) (*Module, hcl.Diagnostics) { diags = append(diags, checkModuleExperiments(mod)...) + // Generate the FQN -> LocalProviderName map + mod.gatherProviderLocalNames() + return mod, diags } @@ -170,11 +175,14 @@ func (m *Module) appendFile(file *File) hcl.Diagnostics { } for _, reqd := range file.RequiredProviders { - // TODO: once the remaining provider source functionality is - // implemented, get addrs.Provider from source if set, or - // addrs.NewDefaultProvider(name) if not + // As an interim *testing* step, we will accept a source argument + // but assume that the source is a legacy provider. This allows us to + // exercise the provider local names -> fqn logic without changing + // terraform's behavior. if reqd.Source != "" { - panic("source is not yet supported") + // Fixme: once the rest of the provider source logic is implemented, + // update this to get the addrs.Provider by using + // addrs.ParseProviderSourceString() } fqn := addrs.NewLegacyProvider(reqd.Name) if existing, exists := m.ProviderRequirements[reqd.Name]; exists { @@ -425,3 +433,27 @@ func (m *Module) mergeFile(file *File) hcl.Diagnostics { return diags } + +// gatherProviderLocalNames is a helper function that populatesA a map of +// provider FQNs -> provider local names. This information is useful for +// user-facing output, which should include both the FQN and LocalName. It must +// only be populated after the module has been parsed. +func (m *Module) gatherProviderLocalNames() { + providers := make(map[addrs.Provider]string) + for k, v := range m.ProviderRequirements { + providers[v.Type] = k + } + m.ProviderLocalNames = providers +} + +// LocalNameForProvider returns the module-specific user-supplied local name for +// a given provider FQN, or the default local name if none was supplied. +func (m *Module) LocalNameForProvider(p addrs.Provider) string { + if existing, exists := m.ProviderLocalNames[p]; exists { + return existing + } else { + // If there isn't a map entry, fall back to the default: + // Type = LocalName + return p.Type + } +} diff --git a/configs/module_test.go b/configs/module_test.go new file mode 100644 index 000000000..4e58db4f5 --- /dev/null +++ b/configs/module_test.go @@ -0,0 +1,34 @@ +package configs + +import ( + "testing" + + "github.com/hashicorp/terraform/addrs" +) + +// TestNewModule_provider_fqns exercises module.gatherProviderLocalNames() +func TestNewModule_provider_local_name(t *testing.T) { + mod, diags := testModuleFromDir("testdata/providers-explicit-fqn") + if diags.HasErrors() { + t.Fatal(diags.Error()) + } + + // FIXME: while the provider source is set to "foo/test", terraform + // currently assumes everything is a legacy provider and the localname and + // type match. This test will be updated when provider source is fully + // implemented. + p := addrs.NewLegacyProvider("foo_test") + if name, exists := mod.ProviderLocalNames[p]; !exists { + t.Fatal("provider FQN foo/test not found") + } else { + if name != "foo_test" { + t.Fatalf("provider localname mismatch: got %s, want foo_test", name) + } + } + + // ensure the reverse lookup (fqn to local name) works as well + localName := mod.LocalNameForProvider(p) + if localName != "foo_test" { + t.Fatal("provider local name not found") + } +} diff --git a/configs/provider_requirements.go b/configs/provider_requirements.go index f2a78e501..7dbcf6ba3 100644 --- a/configs/provider_requirements.go +++ b/configs/provider_requirements.go @@ -9,8 +9,6 @@ import ( // RequiredProvider represents a declaration of a dependency on a particular // provider version without actually configuring that provider. This is used in // child modules that expect a provider to be passed in from their parent. -// -// TODO: "Source" is a placeholder for an attribute that is not yet supported. type RequiredProvider struct { Name string Source string // TODO @@ -43,6 +41,7 @@ func decodeRequiredProvidersBlock(block *hcl.Block) ([]*RequiredProvider, hcl.Di Requirement: vc, }) case expr.Type().IsObjectType(): + ret := &RequiredProvider{Name: name} if expr.Type().HasAttribute("version") { vc := VersionConstraint{ DeclRange: attr.Range, @@ -58,14 +57,15 @@ func decodeRequiredProvidersBlock(block *hcl.Block) ([]*RequiredProvider, hcl.Di Detail: "This string does not use correct version constraint syntax.", Subject: attr.Expr.Range().Ptr(), }) - reqs = append(reqs, &RequiredProvider{Name: name}) - return reqs, diags + } else { + vc.Required = constraints + ret.Requirement = vc } - vc.Required = constraints - reqs = append(reqs, &RequiredProvider{Name: name, Requirement: vc}) } - // No version - reqs = append(reqs, &RequiredProvider{Name: name}) + if expr.Type().HasAttribute("source") { + ret.Source = expr.GetAttr("source").AsString() + } + reqs = append(reqs, ret) default: // should not happen diags = append(diags, &hcl.Diagnostic{ diff --git a/configs/provider_requirements_test.go b/configs/provider_requirements_test.go new file mode 100644 index 000000000..a51149e53 --- /dev/null +++ b/configs/provider_requirements_test.go @@ -0,0 +1,185 @@ +package configs + +import ( + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + version "github.com/hashicorp/go-version" + "github.com/hashicorp/hcl/v2" + "github.com/hashicorp/hcl/v2/hcltest" + "github.com/zclconf/go-cty/cty" +) + +var ( + ignoreUnexported = cmpopts.IgnoreUnexported(version.Constraint{}) + comparer = cmp.Comparer(func(x, y RequiredProvider) bool { + if x.Name != y.Name { + return false + } + if x.Source != y.Source { + return false + } + if x.Requirement.Required.String() != y.Requirement.Required.String() { + return false + } + return true + }) +) + +func TestDecodeRequiredProvidersBlock_legacy(t *testing.T) { + block := &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "default": { + Name: "default", + Expr: hcltest.MockExprLiteral(cty.StringVal("1.0.0")), + }, + }, + }), + } + + want := &RequiredProvider{ + Name: "default", + Requirement: testVC("1.0.0"), + } + + got, diags := decodeRequiredProvidersBlock(block) + if diags.HasErrors() { + t.Fatalf("unexpected error") + } + if len(got) != 1 { + t.Fatalf("wrong number of results, got %d, wanted 1", len(got)) + } + if !cmp.Equal(got[0], want, ignoreUnexported, comparer) { + t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], want, ignoreUnexported, comparer)) + } +} + +func TestDecodeRequiredProvidersBlock_provider_source(t *testing.T) { + block := &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "my_test": { + Name: "my_test", + Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ + "source": cty.StringVal("mycloud/test"), + "version": cty.StringVal("2.0.0"), + })), + }, + }, + }), + } + + want := &RequiredProvider{ + Name: "my_test", + Source: "mycloud/test", + Requirement: testVC("2.0.0"), + } + got, diags := decodeRequiredProvidersBlock(block) + if diags.HasErrors() { + t.Fatalf("unexpected error") + } + if len(got) != 1 { + t.Fatalf("wrong number of results, got %d, wanted 1", len(got)) + } + if !cmp.Equal(got[0], want, ignoreUnexported, comparer) { + t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], want, ignoreUnexported, comparer)) + } +} + +func TestDecodeRequiredProvidersBlock_mixed(t *testing.T) { + block := &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "legacy": { + Name: "legacy", + Expr: hcltest.MockExprLiteral(cty.StringVal("1.0.0")), + }, + "my_test": { + Name: "my_test", + Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ + "source": cty.StringVal("mycloud/test"), + "version": cty.StringVal("2.0.0"), + })), + }, + }, + }), + } + + want := []*RequiredProvider{ + { + Name: "legacy", + Requirement: testVC("1.0.0"), + }, + { + Name: "my_test", + Source: "mycloud/test", + Requirement: testVC("2.0.0"), + }, + } + + got, diags := decodeRequiredProvidersBlock(block) + if diags.HasErrors() { + t.Fatalf("unexpected error") + } + if len(got) != 2 { + t.Fatalf("wrong number of results, got %d, wanted 2", len(got)) + } + for i, rp := range want { + if !cmp.Equal(got[i], rp, ignoreUnexported, comparer) { + t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], rp, ignoreUnexported, comparer)) + } + } +} + +func TestDecodeRequiredProvidersBlock_version_error(t *testing.T) { + block := &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "my_test": { + Name: "my_test", + Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ + "source": cty.StringVal("mycloud/test"), + "version": cty.StringVal("invalid"), + })), + }, + }, + }), + } + + want := []*RequiredProvider{ + { + Name: "my_test", + Source: "mycloud/test", + }, + } + + got, diags := decodeRequiredProvidersBlock(block) + if !diags.HasErrors() { + t.Fatalf("expected error, got success") + } else { + fmt.Printf(diags[0].Summary) + } + if len(got) != 1 { + t.Fatalf("wrong number of results, got %d, wanted 1", len(got)) + } + for i, rp := range want { + if !cmp.Equal(got[i], rp, ignoreUnexported, comparer) { + t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], rp, ignoreUnexported, comparer)) + } + } +} + +func testVC(ver string) VersionConstraint { + constraint, _ := version.NewConstraint(ver) + return VersionConstraint{ + Required: constraint, + DeclRange: hcl.Range{}, + } +}