From 7ca7b1f0fe69adf12d30cebd7063345618b5aa42 Mon Sep 17 00:00:00 2001 From: Alisdair McDiarmid Date: Fri, 24 Apr 2020 10:54:24 -0400 Subject: [PATCH] configs: Simplify required_providers blocks We now permit at most one `required_providers` block per module (except for overrides). This prevents users (and Terraform) from struggling to understand how to merge multiple `required_providers` configurations, with `version` and `source` attributes split across multiple blocks. Because only one `required_providers` block is permitted, there is no need to concatenate version constraints and resolve them. This allows us to simplify the structs used to represent provider requirements, aligning more closely with other structs in this package. This commit also fixes a semantic use-before-initialize bug, where resources defined before a `required_providers` block would be unable to use its source attribute. We achieve this by processing the module's `required_providers` configuration (and overrides) before resources. Overrides for `required_providers` work as before, replacing the entire block per provider. --- configs/config.go | 24 +- configs/module.go | 104 ++--- configs/module_merge.go | 23 +- configs/module_merge_test.go | 96 ----- configs/module_test.go | 92 ++++ configs/parser_config.go | 2 +- configs/provider_requirements.go | 83 ++-- configs/provider_requirements_test.go | 405 +++++++++++------- .../multiple-required-providers/a.tf | 7 + .../multiple-required-providers/b.tf | 7 + .../required-providers-after-resource/main.tf | 3 + .../providers.tf | 8 + .../bar_provider_override.tf | 9 + .../required-providers-overrides/main.tf | 7 + .../required-providers-overrides/providers.tf | 11 + 15 files changed, 522 insertions(+), 359 deletions(-) create mode 100644 configs/testdata/invalid-modules/multiple-required-providers/a.tf create mode 100644 configs/testdata/invalid-modules/multiple-required-providers/b.tf create mode 100644 configs/testdata/valid-modules/required-providers-after-resource/main.tf create mode 100644 configs/testdata/valid-modules/required-providers-after-resource/providers.tf create mode 100644 configs/testdata/valid-modules/required-providers-overrides/bar_provider_override.tf create mode 100644 configs/testdata/valid-modules/required-providers-overrides/main.tf create mode 100644 configs/testdata/valid-modules/required-providers-overrides/providers.tf diff --git a/configs/config.go b/configs/config.go index dd562910d..39a0b717d 100644 --- a/configs/config.go +++ b/configs/config.go @@ -185,23 +185,21 @@ func (c *Config) addProviderRequirements(reqs getproviders.Requirements) hcl.Dia var diags hcl.Diagnostics // First we'll deal with the requirements directly in _our_ module... - for _, providerReqs := range c.Module.ProviderRequirements { + for _, providerReqs := range c.Module.ProviderRequirements.RequiredProviders { fqn := providerReqs.Type if _, ok := reqs[fqn]; !ok { // We'll at least have an unconstrained dependency then, but might // add to this in the loop below. reqs[fqn] = nil } - for _, constraintsSrc := range providerReqs.VersionConstraints { - // The model of version constraints in this package is still the - // old one using a different upstream module to represent versions, - // so we'll need to shim that out here for now. We assume this - // will always succeed because these constraints already succeeded - // parsing with the other constraint parser, which uses the same - // syntax. - constraints := getproviders.MustParseVersionConstraints(constraintsSrc.Required.String()) - reqs[fqn] = append(reqs[fqn], constraints...) - } + // The model of version constraints in this package is still the + // old one using a different upstream module to represent versions, + // so we'll need to shim that out here for now. We assume this + // will always succeed because these constraints already succeeded + // parsing with the other constraint parser, which uses the same + // syntax. + constraints := getproviders.MustParseVersionConstraints(providerReqs.Requirement.Required.String()) + reqs[fqn] = append(reqs[fqn], constraints...) } // Each resource in the configuration creates an *implicit* provider // dependency, though we'll only record it if there isn't already @@ -321,7 +319,7 @@ func (c *Config) ResolveAbsProviderAddr(addr addrs.ProviderConfig, inModule addr } var provider addrs.Provider - if providerReq, exists := c.Module.ProviderRequirements[addr.LocalName]; exists { + if providerReq, exists := c.Module.ProviderRequirements.RequiredProviders[addr.LocalName]; exists { provider = providerReq.Type } else { provider = addrs.ImpliedProviderForUnqualifiedType(addr.LocalName) @@ -343,7 +341,7 @@ func (c *Config) ResolveAbsProviderAddr(addr addrs.ProviderConfig, inModule addr // by checking for the provider in module.ProviderRequirements and falling // back to addrs.NewDefaultProvider if it is not found. func (c *Config) ProviderForConfigAddr(addr addrs.LocalProviderConfig) addrs.Provider { - if provider, exists := c.Module.ProviderRequirements[addr.LocalName]; exists { + if provider, exists := c.Module.ProviderRequirements.RequiredProviders[addr.LocalName]; exists { return provider.Type } return c.ResolveAbsProviderAddr(addr, addrs.RootModule).Provider diff --git a/configs/module.go b/configs/module.go index 97c9082c7..8f7dc10d9 100644 --- a/configs/module.go +++ b/configs/module.go @@ -7,7 +7,6 @@ import ( "github.com/hashicorp/terraform/addrs" "github.com/hashicorp/terraform/experiments" - "github.com/hashicorp/terraform/tfdiags" ) // Module is a container for a set of configuration constructs that are @@ -31,7 +30,7 @@ type Module struct { Backend *Backend ProviderConfigs map[string]*Provider - ProviderRequirements map[string]ProviderRequirements + ProviderRequirements *RequiredProviders ProviderLocalNames map[addrs.Provider]string ProviderMetas map[addrs.Provider]*ProviderMeta @@ -64,7 +63,7 @@ type File struct { Backends []*Backend ProviderConfigs []*Provider ProviderMetas []*ProviderMeta - RequiredProviders []*RequiredProvider + RequiredProviders []*RequiredProviders Variables []*Variable Locals []*Local @@ -87,16 +86,50 @@ type File struct { func NewModule(primaryFiles, overrideFiles []*File) (*Module, hcl.Diagnostics) { var diags hcl.Diagnostics mod := &Module{ - ProviderConfigs: map[string]*Provider{}, - ProviderRequirements: map[string]ProviderRequirements{}, - ProviderLocalNames: map[addrs.Provider]string{}, - Variables: map[string]*Variable{}, - Locals: map[string]*Local{}, - Outputs: map[string]*Output{}, - ModuleCalls: map[string]*ModuleCall{}, - ManagedResources: map[string]*Resource{}, - DataResources: map[string]*Resource{}, - ProviderMetas: map[addrs.Provider]*ProviderMeta{}, + ProviderConfigs: map[string]*Provider{}, + ProviderLocalNames: map[addrs.Provider]string{}, + Variables: map[string]*Variable{}, + Locals: map[string]*Local{}, + Outputs: map[string]*Output{}, + ModuleCalls: map[string]*ModuleCall{}, + ManagedResources: map[string]*Resource{}, + DataResources: map[string]*Resource{}, + ProviderMetas: map[addrs.Provider]*ProviderMeta{}, + } + + // Process the required_providers blocks first, to ensure that all + // resources have access to the correct provider FQNs + for _, file := range primaryFiles { + for _, r := range file.RequiredProviders { + if mod.ProviderRequirements != nil { + diags = append(diags, &hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Duplicate required providers configuration", + Detail: fmt.Sprintf("A module may have only one required providers configuration. The required providers were previously configured at %s.", mod.ProviderRequirements.DeclRange), + Subject: &r.DeclRange, + }) + continue + } + mod.ProviderRequirements = r + } + } + + // If no required_providers block is configured, create a useful empty + // state to reduce nil checks elsewhere + if mod.ProviderRequirements == nil { + mod.ProviderRequirements = &RequiredProviders{ + RequiredProviders: make(map[string]*RequiredProvider), + } + } + + // Any required_providers blocks in override files replace the entire + // block for each provider + for _, file := range overrideFiles { + for _, override := range file.RequiredProviders { + for name, rp := range override.RequiredProviders { + mod.ProviderRequirements.RequiredProviders[name] = rp + } + } } for _, file := range primaryFiles { @@ -178,35 +211,6 @@ func (m *Module) appendFile(file *File) hcl.Diagnostics { m.ProviderConfigs[key] = pc } - for _, reqd := range file.RequiredProviders { - var fqn addrs.Provider - if reqd.Source.SourceStr != "" { - var sourceDiags tfdiags.Diagnostics - fqn, sourceDiags = addrs.ParseProviderSourceString(reqd.Source.SourceStr) - hclDiags := sourceDiags.ToHCL() - // The diagnostics from ParseProviderSourceString don't contain - // source location information because it has no context to compute - // them from, and so we'll add those in quickly here before we - // return. - for _, diag := range hclDiags { - if diag.Subject == nil { - diag.Subject = reqd.Source.DeclRange.Ptr() - } - } - diags = append(diags, hclDiags...) - } else { - fqn = addrs.ImpliedProviderForUnqualifiedType(reqd.Name) - } - if existing, exists := m.ProviderRequirements[reqd.Name]; exists { - if existing.Type != fqn { - panic("provider fqn mismatch") - } - existing.VersionConstraints = append(existing.VersionConstraints, reqd.Requirement) - } else { - m.ProviderRequirements[reqd.Name] = ProviderRequirements{Type: fqn, VersionConstraints: []VersionConstraint{reqd.Requirement}} - } - } - for _, pm := range file.ProviderMetas { provider := m.ProviderForLocalConfig(addrs.LocalProviderConfig{LocalName: pm.Provider}) if existing, exists := m.ProviderMetas[provider]; exists { @@ -283,7 +287,7 @@ func (m *Module) appendFile(file *File) hcl.Diagnostics { // set the provider FQN for the resource if r.ProviderConfigRef != nil { - if existing, exists := m.ProviderRequirements[r.ProviderConfigAddr().LocalName]; exists { + if existing, exists := m.ProviderRequirements.RequiredProviders[r.ProviderConfigAddr().LocalName]; exists { r.Provider = existing.Type } else { r.Provider = addrs.ImpliedProviderForUnqualifiedType(r.ProviderConfigAddr().LocalName) @@ -308,7 +312,7 @@ func (m *Module) appendFile(file *File) hcl.Diagnostics { // set the provider FQN for the resource if r.ProviderConfigRef != nil { - if existing, exists := m.ProviderRequirements[r.ProviderConfigAddr().LocalName]; exists { + if existing, exists := m.ProviderRequirements.RequiredProviders[r.ProviderConfigAddr().LocalName]; exists { r.Provider = existing.Type } else { r.Provider = addrs.ImpliedProviderForUnqualifiedType(r.ProviderConfigAddr().LocalName) @@ -382,10 +386,6 @@ func (m *Module) mergeFile(file *File) hcl.Diagnostics { } } - if len(file.RequiredProviders) != 0 { - mergeProviderVersionConstraints(m.ProviderRequirements, file.RequiredProviders) - } - for _, v := range file.Variables { existing, exists := m.Variables[v.Name] if !exists { @@ -458,7 +458,7 @@ func (m *Module) mergeFile(file *File) hcl.Diagnostics { }) continue } - mergeDiags := existing.merge(r, m.ProviderRequirements) + mergeDiags := existing.merge(r, m.ProviderRequirements.RequiredProviders) diags = append(diags, mergeDiags...) } @@ -474,7 +474,7 @@ func (m *Module) mergeFile(file *File) hcl.Diagnostics { }) continue } - mergeDiags := existing.merge(r, m.ProviderRequirements) + mergeDiags := existing.merge(r, m.ProviderRequirements.RequiredProviders) diags = append(diags, mergeDiags...) } @@ -487,7 +487,7 @@ func (m *Module) mergeFile(file *File) hcl.Diagnostics { // only be populated after the module has been parsed. func (m *Module) gatherProviderLocalNames() { providers := make(map[addrs.Provider]string) - for k, v := range m.ProviderRequirements { + for k, v := range m.ProviderRequirements.RequiredProviders { providers[v.Type] = k } m.ProviderLocalNames = providers @@ -507,7 +507,7 @@ func (m *Module) LocalNameForProvider(p addrs.Provider) string { // ProviderForLocalConfig returns the provider FQN for a given LocalProviderConfig func (m *Module) ProviderForLocalConfig(pc addrs.LocalProviderConfig) addrs.Provider { - if provider, exists := m.ProviderRequirements[pc.LocalName]; exists { + if provider, exists := m.ProviderRequirements.RequiredProviders[pc.LocalName]; exists { return provider.Type } return addrs.ImpliedProviderForUnqualifiedType(pc.LocalName) diff --git a/configs/module_merge.go b/configs/module_merge.go index 045e02344..bf3fd8e49 100644 --- a/configs/module_merge.go +++ b/configs/module_merge.go @@ -35,25 +35,6 @@ func (p *Provider) merge(op *Provider) hcl.Diagnostics { return diags } -func mergeProviderVersionConstraints(recv map[string]ProviderRequirements, ovrd []*RequiredProvider) { - // Any provider name that's mentioned in the override gets nilled out in - // our map so that we'll rebuild it below. Any provider not mentioned is - // left unchanged. - for _, reqd := range ovrd { - delete(recv, reqd.Name) - } - for _, reqd := range ovrd { - var fqn addrs.Provider - if reqd.Source.SourceStr != "" { - // any errors parsing the source string will have already been captured. - fqn, _ = addrs.ParseProviderSourceString(reqd.Source.SourceStr) - } else { - fqn = addrs.ImpliedProviderForUnqualifiedType(reqd.Name) - } - recv[reqd.Name] = ProviderRequirements{Type: fqn, VersionConstraints: []VersionConstraint{reqd.Requirement}} - } -} - func (v *Variable) merge(ov *Variable) hcl.Diagnostics { var diags hcl.Diagnostics @@ -197,7 +178,7 @@ func (mc *ModuleCall) merge(omc *ModuleCall) hcl.Diagnostics { return diags } -func (r *Resource) merge(or *Resource, prs map[string]ProviderRequirements) hcl.Diagnostics { +func (r *Resource) merge(or *Resource, rps map[string]*RequiredProvider) hcl.Diagnostics { var diags hcl.Diagnostics if r.Mode != or.Mode { @@ -215,7 +196,7 @@ func (r *Resource) merge(or *Resource, prs map[string]ProviderRequirements) hcl. if or.ProviderConfigRef != nil { r.ProviderConfigRef = or.ProviderConfigRef - if existing, exists := prs[or.ProviderConfigRef.Name]; exists { + if existing, exists := rps[or.ProviderConfigRef.Name]; exists { r.Provider = existing.Type } else { r.Provider = addrs.ImpliedProviderForUnqualifiedType(r.ProviderConfigRef.Name) diff --git a/configs/module_merge_test.go b/configs/module_merge_test.go index 082c24447..d581ac836 100644 --- a/configs/module_merge_test.go +++ b/configs/module_merge_test.go @@ -3,7 +3,6 @@ package configs import ( "testing" - version "github.com/hashicorp/go-version" "github.com/hashicorp/hcl/v2" "github.com/hashicorp/hcl/v2/gohcl" "github.com/hashicorp/terraform/addrs" @@ -232,98 +231,3 @@ func TestModuleOverrideResourceFQNs(t *testing.T) { t.Fatalf("wrong result: found provider config ref %s, expected nil", got.ProviderConfigRef) } } - -func TestMergeProviderVersionConstraints(t *testing.T) { - v1, _ := version.NewConstraint("1.0.0") - vc1 := VersionConstraint{ - Required: v1, - } - v2, _ := version.NewConstraint("2.0.0") - vc2 := VersionConstraint{ - Required: v2, - } - - tests := map[string]struct { - Input map[string]ProviderRequirements - Override []*RequiredProvider - Want map[string]ProviderRequirements - }{ - "basic merge": { - map[string]ProviderRequirements{ - "random": ProviderRequirements{ - Type: addrs.Provider{Type: "random"}, - VersionConstraints: []VersionConstraint{}, - }, - }, - []*RequiredProvider{ - &RequiredProvider{ - Name: "null", - Requirement: VersionConstraint{}, - }, - }, - map[string]ProviderRequirements{ - "random": ProviderRequirements{ - Type: addrs.Provider{Type: "random"}, - VersionConstraints: []VersionConstraint{}, - }, - "null": ProviderRequirements{ - Type: addrs.NewDefaultProvider("null"), - VersionConstraints: []VersionConstraint{ - VersionConstraint{ - Required: version.Constraints(nil), - DeclRange: hcl.Range{}, - }, - }, - }, - }, - }, - "override version constraint": { - map[string]ProviderRequirements{ - "random": ProviderRequirements{ - Type: addrs.Provider{Type: "random"}, - VersionConstraints: []VersionConstraint{vc1}, - }, - }, - []*RequiredProvider{ - &RequiredProvider{ - Name: "random", - Requirement: vc2, - }, - }, - map[string]ProviderRequirements{ - "random": ProviderRequirements{ - Type: addrs.NewDefaultProvider("random"), - VersionConstraints: []VersionConstraint{vc2}, - }, - }, - }, - "merge with source constraint": { - map[string]ProviderRequirements{ - "random": ProviderRequirements{ - Type: addrs.Provider{Type: "random"}, - VersionConstraints: []VersionConstraint{vc1}, - }, - }, - []*RequiredProvider{ - &RequiredProvider{ - Name: "random", - Source: Source{SourceStr: "hashicorp/random"}, - Requirement: vc2, - }, - }, - map[string]ProviderRequirements{ - "random": ProviderRequirements{ - Type: addrs.NewDefaultProvider("random"), - VersionConstraints: []VersionConstraint{vc2}, - }, - }, - }, - } - - for name, test := range tests { - t.Run(name, func(t *testing.T) { - mergeProviderVersionConstraints(test.Input, test.Override) - assertResultDeepEqual(t, test.Input, test.Want) - }) - } -} diff --git a/configs/module_test.go b/configs/module_test.go index 3d9c47bdc..d5e3659ad 100644 --- a/configs/module_test.go +++ b/configs/module_test.go @@ -1,6 +1,7 @@ package configs import ( + "strings" "testing" "github.com/hashicorp/terraform/addrs" @@ -111,3 +112,94 @@ func TestProviderForLocalConfig(t *testing.T) { t.Fatalf("wrong result! got %#v, want %#v\n", got, want) } } + +// At most one required_providers block per module is permitted. +func TestModule_required_providers_multiple(t *testing.T) { + _, diags := testModuleFromDir("testdata/invalid-modules/multiple-required-providers") + if !diags.HasErrors() { + t.Fatal("module should have error diags, but does not") + } + + want := `Duplicate required providers configuration` + if got := diags.Error(); !strings.Contains(got, want) { + t.Fatalf("expected error to contain %q\nerror was:\n%s", want, got) + } +} + +// A module may have required_providers configured in files loaded later than +// resources. These provider settings should still be reflected in the +// resources' configuration. +func TestModule_required_providers_after_resource(t *testing.T) { + mod, diags := testModuleFromDir("testdata/valid-modules/required-providers-after-resource") + if diags.HasErrors() { + t.Fatal(diags.Error()) + } + + want := addrs.NewProvider(addrs.DefaultRegistryHost, "foo", "test") + + req, exists := mod.ProviderRequirements.RequiredProviders["test"] + if !exists { + t.Fatal("no provider requirements found for \"test\"") + } + if req.Type != want { + t.Errorf("wrong provider addr for \"test\"\ngot: %s\nwant: %s", + req.Type, want, + ) + } + + if got := mod.ManagedResources["test_instance.my-instance"].Provider; !got.Equals(want) { + t.Errorf("wrong provider addr for \"test_instance.my-instance\"\ngot: %s\nwant: %s", + got, want, + ) + } +} + +// We support overrides for required_providers blocks, which should replace the +// entire block for each provider localname, leaving other blocks unaffected. +// This should also be reflected in any resources in the module using this +// provider. +func TestModule_required_provider_overrides(t *testing.T) { + mod, diags := testModuleFromDir("testdata/valid-modules/required-providers-overrides") + if diags.HasErrors() { + t.Fatal(diags.Error()) + } + + // The foo provider and resource should be unaffected + want := addrs.NewProvider(addrs.DefaultRegistryHost, "acme", "foo") + req, exists := mod.ProviderRequirements.RequiredProviders["foo"] + if !exists { + t.Fatal("no provider requirements found for \"foo\"") + } + if req.Type != want { + t.Errorf("wrong provider addr for \"foo\"\ngot: %s\nwant: %s", + req.Type, want, + ) + } + if got := mod.ManagedResources["foo_thing.ft"].Provider; !got.Equals(want) { + t.Errorf("wrong provider addr for \"foo_thing.ft\"\ngot: %s\nwant: %s", + got, want, + ) + } + + // The bar provider and resource should be using the override config + want = addrs.NewProvider(addrs.DefaultRegistryHost, "blorp", "bar") + req, exists = mod.ProviderRequirements.RequiredProviders["bar"] + if !exists { + t.Fatal("no provider requirements found for \"bar\"") + } + if req.Type != want { + t.Errorf("wrong provider addr for \"bar\"\ngot: %s\nwant: %s", + req.Type, want, + ) + } + if gotVer, wantVer := req.Requirement.Required.String(), "~>2.0.0"; gotVer != wantVer { + t.Errorf("wrong provider version constraint for \"bar\"\ngot: %s\nwant: %s", + gotVer, wantVer, + ) + } + if got := mod.ManagedResources["bar_thing.bt"].Provider; !got.Equals(want) { + t.Errorf("wrong provider addr for \"bar_thing.bt\"\ngot: %s\nwant: %s", + got, want, + ) + } +} diff --git a/configs/parser_config.go b/configs/parser_config.go index e930a093c..354b96a72 100644 --- a/configs/parser_config.go +++ b/configs/parser_config.go @@ -75,7 +75,7 @@ func (p *Parser) loadConfigFile(path string, override bool) (*File, hcl.Diagnost case "required_providers": reqs, reqsDiags := decodeRequiredProvidersBlock(innerBlock) diags = append(diags, reqsDiags...) - file.RequiredProviders = append(file.RequiredProviders, reqs...) + file.RequiredProviders = append(file.RequiredProviders, reqs) case "provider_meta": providerCfg, cfgDiags := decodeProviderMetaBlock(innerBlock) diff --git a/configs/provider_requirements.go b/configs/provider_requirements.go index b535afe0c..d4218a6fc 100644 --- a/configs/provider_requirements.go +++ b/configs/provider_requirements.go @@ -12,41 +12,40 @@ import ( // parent. type RequiredProvider struct { Name string - Source Source + Type addrs.Provider Requirement VersionConstraint + DeclRange hcl.Range } -type Source struct { - SourceStr string - DeclRange hcl.Range +type RequiredProviders struct { + RequiredProviders map[string]*RequiredProvider + DeclRange hcl.Range } -// ProviderRequirements represents provider version constraints from -// required_providers blocks. -type ProviderRequirements struct { - Type addrs.Provider - VersionConstraints []VersionConstraint -} - -func decodeRequiredProvidersBlock(block *hcl.Block) ([]*RequiredProvider, hcl.Diagnostics) { +func decodeRequiredProvidersBlock(block *hcl.Block) (*RequiredProviders, hcl.Diagnostics) { attrs, diags := block.Body.JustAttributes() - var reqs []*RequiredProvider + ret := &RequiredProviders{ + RequiredProviders: make(map[string]*RequiredProvider), + DeclRange: block.DefRange, + } for name, attr := range attrs { expr, err := attr.Expr.Value(nil) if err != nil { diags = append(diags, err...) } + rp := &RequiredProvider{ + Name: name, + DeclRange: attr.Expr.Range(), + } + switch { case expr.Type().IsPrimitiveType(): vc, reqDiags := decodeVersionConstraint(attr) diags = append(diags, reqDiags...) - reqs = append(reqs, &RequiredProvider{ - Name: name, - Requirement: vc, - }) + rp.Requirement = vc + case expr.Type().IsObjectType(): - ret := &RequiredProvider{Name: name} if expr.Type().HasAttribute("version") { vc := VersionConstraint{ DeclRange: attr.Range, @@ -64,25 +63,55 @@ func decodeRequiredProvidersBlock(block *hcl.Block) ([]*RequiredProvider, hcl.Di }) } else { vc.Required = constraints - ret.Requirement = vc + rp.Requirement = vc } } if expr.Type().HasAttribute("source") { - ret.Source.SourceStr = expr.GetAttr("source").AsString() - ret.Source.DeclRange = attr.Range + fqn, sourceDiags := addrs.ParseProviderSourceString(expr.GetAttr("source").AsString()) + + if sourceDiags.HasErrors() { + hclDiags := sourceDiags.ToHCL() + // The diagnostics from ParseProviderSourceString don't contain + // source location information because it has no context to compute + // them from, and so we'll add those in quickly here before we + // return. + for _, diag := range hclDiags { + if diag.Subject == nil { + diag.Subject = attr.Expr.Range().Ptr() + } + } + diags = append(diags, hclDiags...) + } else { + rp.Type = fqn + } } - reqs = append(reqs, ret) + default: // should not happen diags = append(diags, &hcl.Diagnostic{ Severity: hcl.DiagError, - Summary: "Invalid provider_requirements syntax", - Detail: "provider_requirements entries must be strings or objects.", + Summary: "Invalid required_providers syntax", + Detail: "required_providers entries must be strings or objects.", Subject: attr.Expr.Range().Ptr(), }) - reqs = append(reqs, &RequiredProvider{Name: name}) - return reqs, diags } + + if rp.Type.IsZero() { + pType, err := addrs.ParseProviderPart(rp.Name) + if err != nil { + diags = append(diags, &hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Invalid provider name", + Detail: err.Error(), + Subject: attr.Expr.Range().Ptr(), + }) + } else { + rp.Type = addrs.ImpliedProviderForUnqualifiedType(pType) + } + } + + ret.RequiredProviders[rp.Name] = rp } - return reqs, diags + + return ret, diags } diff --git a/configs/provider_requirements_test.go b/configs/provider_requirements_test.go index 4abcc7ac3..a04714f55 100644 --- a/configs/provider_requirements_test.go +++ b/configs/provider_requirements_test.go @@ -1,8 +1,6 @@ package configs import ( - "fmt" - "sort" "testing" "github.com/google/go-cmp/cmp" @@ -10,6 +8,7 @@ import ( version "github.com/hashicorp/go-version" "github.com/hashicorp/hcl/v2" "github.com/hashicorp/hcl/v2/hcltest" + "github.com/hashicorp/terraform/addrs" "github.com/zclconf/go-cty/cty" ) @@ -19,173 +18,281 @@ var ( if x.Name != y.Name { return false } - if x.Source != y.Source { + if x.Type != y.Type { return false } if x.Requirement.Required.String() != y.Requirement.Required.String() { return false } + if x.DeclRange != y.DeclRange { + return false + } return true }) -) - -func TestDecodeRequiredProvidersBlock_legacy(t *testing.T) { - block := &hcl.Block{ - Type: "required_providers", - Body: hcltest.MockBody(&hcl.BodyContent{ - Attributes: hcl.Attributes{ - "default": { - Name: "default", - Expr: hcltest.MockExprLiteral(cty.StringVal("1.0.0")), - }, - }, - }), - } - - want := &RequiredProvider{ - Name: "default", - Requirement: testVC("1.0.0"), - } - - got, diags := decodeRequiredProvidersBlock(block) - if diags.HasErrors() { - t.Fatalf("unexpected error") - } - if len(got) != 1 { - t.Fatalf("wrong number of results, got %d, wanted 1", len(got)) - } - if !cmp.Equal(got[0], want, ignoreUnexported, comparer) { - t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], want, ignoreUnexported, comparer)) - } -} - -func TestDecodeRequiredProvidersBlock_provider_source(t *testing.T) { - mockRange := hcl.Range{ + blockRange = hcl.Range{ Filename: "mock.tf", Start: hcl.Pos{Line: 3, Column: 12, Byte: 27}, End: hcl.Pos{Line: 3, Column: 19, Byte: 34}, } + mockRange = hcl.Range{ + Filename: "MockExprLiteral", + } +) - block := &hcl.Block{ - Type: "required_providers", - Body: hcltest.MockBody(&hcl.BodyContent{ - Attributes: hcl.Attributes{ - "my_test": { - Name: "my_test", - Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ - "source": cty.StringVal("mycloud/test"), - "version": cty.StringVal("2.0.0"), - })), - Range: mockRange, - }, +func TestDecodeRequiredProvidersBlock(t *testing.T) { + tests := map[string]struct { + Block *hcl.Block + Want *RequiredProviders + Error string + }{ + "legacy": { + Block: &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "default": { + Name: "default", + Expr: hcltest.MockExprLiteral(cty.StringVal("1.0.0")), + }, + }, + }), + DefRange: blockRange, }, - }), - } - - want := &RequiredProvider{ - Name: "my_test", - Source: Source{SourceStr: "mycloud/test", DeclRange: mockRange}, - Requirement: testVC("2.0.0"), - } - got, diags := decodeRequiredProvidersBlock(block) - if diags.HasErrors() { - t.Fatalf("unexpected error") - } - if len(got) != 1 { - t.Fatalf("wrong number of results, got %d, wanted 1", len(got)) - } - if !cmp.Equal(got[0], want, ignoreUnexported, comparer) { - t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], want, ignoreUnexported, comparer)) - } -} - -func TestDecodeRequiredProvidersBlock_mixed(t *testing.T) { - block := &hcl.Block{ - Type: "required_providers", - Body: hcltest.MockBody(&hcl.BodyContent{ - Attributes: hcl.Attributes{ - "legacy": { - Name: "legacy", - Expr: hcltest.MockExprLiteral(cty.StringVal("1.0.0")), - }, - "my_test": { - Name: "my_test", - Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ - "source": cty.StringVal("mycloud/test"), - "version": cty.StringVal("2.0.0"), - })), + Want: &RequiredProviders{ + RequiredProviders: map[string]*RequiredProvider{ + "default": { + Name: "default", + Type: addrs.NewDefaultProvider("default"), + Requirement: testVC("1.0.0"), + DeclRange: mockRange, + }, }, + DeclRange: blockRange, }, - }), - } - - want := []*RequiredProvider{ - { - Name: "legacy", - Requirement: testVC("1.0.0"), }, - { - Name: "my_test", - Source: Source{SourceStr: "mycloud/test", DeclRange: hcl.Range{}}, - Requirement: testVC("2.0.0"), + "provider source": { + Block: &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "my_test": { + Name: "my_test", + Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ + "source": cty.StringVal("mycloud/test"), + "version": cty.StringVal("2.0.0"), + })), + }, + }, + }), + DefRange: blockRange, + }, + Want: &RequiredProviders{ + RequiredProviders: map[string]*RequiredProvider{ + "my_test": { + Name: "my_test", + Type: addrs.NewProvider(addrs.DefaultRegistryHost, "mycloud", "test"), + Requirement: testVC("2.0.0"), + DeclRange: mockRange, + }, + }, + DeclRange: blockRange, + }, + }, + "mixed": { + Block: &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "legacy": { + Name: "legacy", + Expr: hcltest.MockExprLiteral(cty.StringVal("1.0.0")), + }, + "my_test": { + Name: "my_test", + Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ + "source": cty.StringVal("mycloud/test"), + "version": cty.StringVal("2.0.0"), + })), + }, + }, + }), + DefRange: blockRange, + }, + Want: &RequiredProviders{ + RequiredProviders: map[string]*RequiredProvider{ + "legacy": { + Name: "legacy", + Type: addrs.NewDefaultProvider("legacy"), + Requirement: testVC("1.0.0"), + DeclRange: mockRange, + }, + "my_test": { + Name: "my_test", + Type: addrs.NewProvider(addrs.DefaultRegistryHost, "mycloud", "test"), + Requirement: testVC("2.0.0"), + DeclRange: mockRange, + }, + }, + DeclRange: blockRange, + }, + }, + "version-only block": { + Block: &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "test": { + Name: "test", + Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ + "version": cty.StringVal("~>2.0.0"), + })), + }, + }, + }), + DefRange: blockRange, + }, + Want: &RequiredProviders{ + RequiredProviders: map[string]*RequiredProvider{ + "test": { + Name: "test", + Type: addrs.NewDefaultProvider("test"), + Requirement: testVC("~>2.0.0"), + DeclRange: mockRange, + }, + }, + DeclRange: blockRange, + }, + }, + "invalid source": { + Block: &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "my_test": { + Name: "my_test", + Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ + "source": cty.StringVal("some/invalid/provider/source/test"), + "version": cty.StringVal("~>2.0.0"), + })), + }, + }, + }), + DefRange: blockRange, + }, + Want: &RequiredProviders{ + RequiredProviders: map[string]*RequiredProvider{ + "my_test": { + Name: "my_test", + Type: addrs.Provider{}, + Requirement: testVC("~>2.0.0"), + DeclRange: mockRange, + }, + }, + DeclRange: blockRange, + }, + Error: "Invalid provider source string", + }, + "localname is invalid provider name": { + Block: &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "my_test": { + Name: "my_test", + Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ + "version": cty.StringVal("~>2.0.0"), + })), + }, + }, + }), + DefRange: blockRange, + }, + Want: &RequiredProviders{ + RequiredProviders: map[string]*RequiredProvider{ + "my_test": { + Name: "my_test", + Type: addrs.Provider{}, + Requirement: testVC("~>2.0.0"), + DeclRange: mockRange, + }, + }, + DeclRange: blockRange, + }, + Error: "Invalid provider name", + }, + "version constraint error": { + Block: &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "my_test": { + Name: "my_test", + Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ + "source": cty.StringVal("mycloud/test"), + "version": cty.StringVal("invalid"), + })), + }, + }, + }), + DefRange: blockRange, + }, + Want: &RequiredProviders{ + RequiredProviders: map[string]*RequiredProvider{ + "my_test": { + Name: "my_test", + Type: addrs.NewProvider(addrs.DefaultRegistryHost, "mycloud", "test"), + DeclRange: mockRange, + }, + }, + DeclRange: blockRange, + }, + Error: "Invalid version constraint", + }, + "invalid required_providers attribute value": { + Block: &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "test": { + Name: "test", + Expr: hcltest.MockExprLiteral(cty.ListVal([]cty.Value{cty.StringVal("2.0.0")})), + }, + }, + }), + DefRange: blockRange, + }, + Want: &RequiredProviders{ + RequiredProviders: map[string]*RequiredProvider{ + "test": { + Name: "test", + Type: addrs.NewDefaultProvider("test"), + DeclRange: mockRange, + }, + }, + DeclRange: blockRange, + }, + Error: "Invalid required_providers syntax", }, } - got, diags := decodeRequiredProvidersBlock(block) + for name, test := range tests { + t.Run(name, func(t *testing.T) { + got, diags := decodeRequiredProvidersBlock(test.Block) + if diags.HasErrors() { + if test.Error == "" { + t.Fatalf("unexpected error") + } + if gotErr := diags[0].Summary; gotErr != test.Error { + t.Errorf("wrong error, got %q, want %q", gotErr, test.Error) + } + } else if test.Error != "" { + t.Fatalf("expected error") + } - sort.SliceStable(got, func(i, j int) bool { - return got[i].Name < got[j].Name - }) - - if diags.HasErrors() { - t.Fatalf("unexpected error") - } - if len(got) != 2 { - t.Fatalf("wrong number of results, got %d, wanted 2", len(got)) - } - for i, rp := range want { - if !cmp.Equal(got[i], rp, ignoreUnexported, comparer) { - t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], rp, ignoreUnexported, comparer)) - } - } -} - -func TestDecodeRequiredProvidersBlock_version_error(t *testing.T) { - block := &hcl.Block{ - Type: "required_providers", - Body: hcltest.MockBody(&hcl.BodyContent{ - Attributes: hcl.Attributes{ - "my_test": { - Name: "my_test", - Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ - "source": cty.StringVal("mycloud/test"), - "version": cty.StringVal("invalid"), - })), - }, - }, - }), - } - - want := []*RequiredProvider{ - { - Name: "my_test", - Source: Source{SourceStr: "mycloud/test", DeclRange: hcl.Range{}}, - }, - } - - got, diags := decodeRequiredProvidersBlock(block) - if !diags.HasErrors() { - t.Fatalf("expected error, got success") - } else { - fmt.Printf(diags[0].Summary) - } - if len(got) != 1 { - t.Fatalf("wrong number of results, got %d, wanted 1", len(got)) - } - for i, rp := range want { - if !cmp.Equal(got[i], rp, ignoreUnexported, comparer) { - t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], rp, ignoreUnexported, comparer)) - } + if !cmp.Equal(got, test.Want, ignoreUnexported, comparer) { + t.Fatalf("wrong result:\n %s", cmp.Diff(got, test.Want, ignoreUnexported, comparer)) + } + }) } } diff --git a/configs/testdata/invalid-modules/multiple-required-providers/a.tf b/configs/testdata/invalid-modules/multiple-required-providers/a.tf new file mode 100644 index 000000000..a607ddd1c --- /dev/null +++ b/configs/testdata/invalid-modules/multiple-required-providers/a.tf @@ -0,0 +1,7 @@ +terraform { + required_providers { + bar = { + version = "~>1.0.0" + } + } +} diff --git a/configs/testdata/invalid-modules/multiple-required-providers/b.tf b/configs/testdata/invalid-modules/multiple-required-providers/b.tf new file mode 100644 index 000000000..7672e9249 --- /dev/null +++ b/configs/testdata/invalid-modules/multiple-required-providers/b.tf @@ -0,0 +1,7 @@ +terraform { + required_providers { + foo = { + version = "~>2.0.0" + } + } +} diff --git a/configs/testdata/valid-modules/required-providers-after-resource/main.tf b/configs/testdata/valid-modules/required-providers-after-resource/main.tf new file mode 100644 index 000000000..8d40ec4b9 --- /dev/null +++ b/configs/testdata/valid-modules/required-providers-after-resource/main.tf @@ -0,0 +1,3 @@ +resource test_instance "my-instance" { + provider = test +} \ No newline at end of file diff --git a/configs/testdata/valid-modules/required-providers-after-resource/providers.tf b/configs/testdata/valid-modules/required-providers-after-resource/providers.tf new file mode 100644 index 000000000..687ef1bdd --- /dev/null +++ b/configs/testdata/valid-modules/required-providers-after-resource/providers.tf @@ -0,0 +1,8 @@ +terraform { + required_providers { + test = { + source = "foo/test" + version = "~>1.0.0" + } + } +} \ No newline at end of file diff --git a/configs/testdata/valid-modules/required-providers-overrides/bar_provider_override.tf b/configs/testdata/valid-modules/required-providers-overrides/bar_provider_override.tf new file mode 100644 index 000000000..b83ef7494 --- /dev/null +++ b/configs/testdata/valid-modules/required-providers-overrides/bar_provider_override.tf @@ -0,0 +1,9 @@ +terraform { + required_providers { + bar = { + source = "blorp/bar" + version = "~>2.0.0" + } + } +} + diff --git a/configs/testdata/valid-modules/required-providers-overrides/main.tf b/configs/testdata/valid-modules/required-providers-overrides/main.tf new file mode 100644 index 000000000..f7e58f2e2 --- /dev/null +++ b/configs/testdata/valid-modules/required-providers-overrides/main.tf @@ -0,0 +1,7 @@ +resource bar_thing "bt" { + provider = bar +} + +resource foo_thing "ft" { + provider = foo +} diff --git a/configs/testdata/valid-modules/required-providers-overrides/providers.tf b/configs/testdata/valid-modules/required-providers-overrides/providers.tf new file mode 100644 index 000000000..9b1d897f4 --- /dev/null +++ b/configs/testdata/valid-modules/required-providers-overrides/providers.tf @@ -0,0 +1,11 @@ +terraform { + required_providers { + bar = { + source = "acme/bar" + } + + foo = { + source = "acme/foo" + } + } +}