diff --git a/configs/config.go b/configs/config.go index dd562910d..39a0b717d 100644 --- a/configs/config.go +++ b/configs/config.go @@ -185,23 +185,21 @@ func (c *Config) addProviderRequirements(reqs getproviders.Requirements) hcl.Dia var diags hcl.Diagnostics // First we'll deal with the requirements directly in _our_ module... - for _, providerReqs := range c.Module.ProviderRequirements { + for _, providerReqs := range c.Module.ProviderRequirements.RequiredProviders { fqn := providerReqs.Type if _, ok := reqs[fqn]; !ok { // We'll at least have an unconstrained dependency then, but might // add to this in the loop below. reqs[fqn] = nil } - for _, constraintsSrc := range providerReqs.VersionConstraints { - // The model of version constraints in this package is still the - // old one using a different upstream module to represent versions, - // so we'll need to shim that out here for now. We assume this - // will always succeed because these constraints already succeeded - // parsing with the other constraint parser, which uses the same - // syntax. - constraints := getproviders.MustParseVersionConstraints(constraintsSrc.Required.String()) - reqs[fqn] = append(reqs[fqn], constraints...) - } + // The model of version constraints in this package is still the + // old one using a different upstream module to represent versions, + // so we'll need to shim that out here for now. We assume this + // will always succeed because these constraints already succeeded + // parsing with the other constraint parser, which uses the same + // syntax. + constraints := getproviders.MustParseVersionConstraints(providerReqs.Requirement.Required.String()) + reqs[fqn] = append(reqs[fqn], constraints...) } // Each resource in the configuration creates an *implicit* provider // dependency, though we'll only record it if there isn't already @@ -321,7 +319,7 @@ func (c *Config) ResolveAbsProviderAddr(addr addrs.ProviderConfig, inModule addr } var provider addrs.Provider - if providerReq, exists := c.Module.ProviderRequirements[addr.LocalName]; exists { + if providerReq, exists := c.Module.ProviderRequirements.RequiredProviders[addr.LocalName]; exists { provider = providerReq.Type } else { provider = addrs.ImpliedProviderForUnqualifiedType(addr.LocalName) @@ -343,7 +341,7 @@ func (c *Config) ResolveAbsProviderAddr(addr addrs.ProviderConfig, inModule addr // by checking for the provider in module.ProviderRequirements and falling // back to addrs.NewDefaultProvider if it is not found. func (c *Config) ProviderForConfigAddr(addr addrs.LocalProviderConfig) addrs.Provider { - if provider, exists := c.Module.ProviderRequirements[addr.LocalName]; exists { + if provider, exists := c.Module.ProviderRequirements.RequiredProviders[addr.LocalName]; exists { return provider.Type } return c.ResolveAbsProviderAddr(addr, addrs.RootModule).Provider diff --git a/configs/module.go b/configs/module.go index 97c9082c7..8f7dc10d9 100644 --- a/configs/module.go +++ b/configs/module.go @@ -7,7 +7,6 @@ import ( "github.com/hashicorp/terraform/addrs" "github.com/hashicorp/terraform/experiments" - "github.com/hashicorp/terraform/tfdiags" ) // Module is a container for a set of configuration constructs that are @@ -31,7 +30,7 @@ type Module struct { Backend *Backend ProviderConfigs map[string]*Provider - ProviderRequirements map[string]ProviderRequirements + ProviderRequirements *RequiredProviders ProviderLocalNames map[addrs.Provider]string ProviderMetas map[addrs.Provider]*ProviderMeta @@ -64,7 +63,7 @@ type File struct { Backends []*Backend ProviderConfigs []*Provider ProviderMetas []*ProviderMeta - RequiredProviders []*RequiredProvider + RequiredProviders []*RequiredProviders Variables []*Variable Locals []*Local @@ -87,16 +86,50 @@ type File struct { func NewModule(primaryFiles, overrideFiles []*File) (*Module, hcl.Diagnostics) { var diags hcl.Diagnostics mod := &Module{ - ProviderConfigs: map[string]*Provider{}, - ProviderRequirements: map[string]ProviderRequirements{}, - ProviderLocalNames: map[addrs.Provider]string{}, - Variables: map[string]*Variable{}, - Locals: map[string]*Local{}, - Outputs: map[string]*Output{}, - ModuleCalls: map[string]*ModuleCall{}, - ManagedResources: map[string]*Resource{}, - DataResources: map[string]*Resource{}, - ProviderMetas: map[addrs.Provider]*ProviderMeta{}, + ProviderConfigs: map[string]*Provider{}, + ProviderLocalNames: map[addrs.Provider]string{}, + Variables: map[string]*Variable{}, + Locals: map[string]*Local{}, + Outputs: map[string]*Output{}, + ModuleCalls: map[string]*ModuleCall{}, + ManagedResources: map[string]*Resource{}, + DataResources: map[string]*Resource{}, + ProviderMetas: map[addrs.Provider]*ProviderMeta{}, + } + + // Process the required_providers blocks first, to ensure that all + // resources have access to the correct provider FQNs + for _, file := range primaryFiles { + for _, r := range file.RequiredProviders { + if mod.ProviderRequirements != nil { + diags = append(diags, &hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Duplicate required providers configuration", + Detail: fmt.Sprintf("A module may have only one required providers configuration. The required providers were previously configured at %s.", mod.ProviderRequirements.DeclRange), + Subject: &r.DeclRange, + }) + continue + } + mod.ProviderRequirements = r + } + } + + // If no required_providers block is configured, create a useful empty + // state to reduce nil checks elsewhere + if mod.ProviderRequirements == nil { + mod.ProviderRequirements = &RequiredProviders{ + RequiredProviders: make(map[string]*RequiredProvider), + } + } + + // Any required_providers blocks in override files replace the entire + // block for each provider + for _, file := range overrideFiles { + for _, override := range file.RequiredProviders { + for name, rp := range override.RequiredProviders { + mod.ProviderRequirements.RequiredProviders[name] = rp + } + } } for _, file := range primaryFiles { @@ -178,35 +211,6 @@ func (m *Module) appendFile(file *File) hcl.Diagnostics { m.ProviderConfigs[key] = pc } - for _, reqd := range file.RequiredProviders { - var fqn addrs.Provider - if reqd.Source.SourceStr != "" { - var sourceDiags tfdiags.Diagnostics - fqn, sourceDiags = addrs.ParseProviderSourceString(reqd.Source.SourceStr) - hclDiags := sourceDiags.ToHCL() - // The diagnostics from ParseProviderSourceString don't contain - // source location information because it has no context to compute - // them from, and so we'll add those in quickly here before we - // return. - for _, diag := range hclDiags { - if diag.Subject == nil { - diag.Subject = reqd.Source.DeclRange.Ptr() - } - } - diags = append(diags, hclDiags...) - } else { - fqn = addrs.ImpliedProviderForUnqualifiedType(reqd.Name) - } - if existing, exists := m.ProviderRequirements[reqd.Name]; exists { - if existing.Type != fqn { - panic("provider fqn mismatch") - } - existing.VersionConstraints = append(existing.VersionConstraints, reqd.Requirement) - } else { - m.ProviderRequirements[reqd.Name] = ProviderRequirements{Type: fqn, VersionConstraints: []VersionConstraint{reqd.Requirement}} - } - } - for _, pm := range file.ProviderMetas { provider := m.ProviderForLocalConfig(addrs.LocalProviderConfig{LocalName: pm.Provider}) if existing, exists := m.ProviderMetas[provider]; exists { @@ -283,7 +287,7 @@ func (m *Module) appendFile(file *File) hcl.Diagnostics { // set the provider FQN for the resource if r.ProviderConfigRef != nil { - if existing, exists := m.ProviderRequirements[r.ProviderConfigAddr().LocalName]; exists { + if existing, exists := m.ProviderRequirements.RequiredProviders[r.ProviderConfigAddr().LocalName]; exists { r.Provider = existing.Type } else { r.Provider = addrs.ImpliedProviderForUnqualifiedType(r.ProviderConfigAddr().LocalName) @@ -308,7 +312,7 @@ func (m *Module) appendFile(file *File) hcl.Diagnostics { // set the provider FQN for the resource if r.ProviderConfigRef != nil { - if existing, exists := m.ProviderRequirements[r.ProviderConfigAddr().LocalName]; exists { + if existing, exists := m.ProviderRequirements.RequiredProviders[r.ProviderConfigAddr().LocalName]; exists { r.Provider = existing.Type } else { r.Provider = addrs.ImpliedProviderForUnqualifiedType(r.ProviderConfigAddr().LocalName) @@ -382,10 +386,6 @@ func (m *Module) mergeFile(file *File) hcl.Diagnostics { } } - if len(file.RequiredProviders) != 0 { - mergeProviderVersionConstraints(m.ProviderRequirements, file.RequiredProviders) - } - for _, v := range file.Variables { existing, exists := m.Variables[v.Name] if !exists { @@ -458,7 +458,7 @@ func (m *Module) mergeFile(file *File) hcl.Diagnostics { }) continue } - mergeDiags := existing.merge(r, m.ProviderRequirements) + mergeDiags := existing.merge(r, m.ProviderRequirements.RequiredProviders) diags = append(diags, mergeDiags...) } @@ -474,7 +474,7 @@ func (m *Module) mergeFile(file *File) hcl.Diagnostics { }) continue } - mergeDiags := existing.merge(r, m.ProviderRequirements) + mergeDiags := existing.merge(r, m.ProviderRequirements.RequiredProviders) diags = append(diags, mergeDiags...) } @@ -487,7 +487,7 @@ func (m *Module) mergeFile(file *File) hcl.Diagnostics { // only be populated after the module has been parsed. func (m *Module) gatherProviderLocalNames() { providers := make(map[addrs.Provider]string) - for k, v := range m.ProviderRequirements { + for k, v := range m.ProviderRequirements.RequiredProviders { providers[v.Type] = k } m.ProviderLocalNames = providers @@ -507,7 +507,7 @@ func (m *Module) LocalNameForProvider(p addrs.Provider) string { // ProviderForLocalConfig returns the provider FQN for a given LocalProviderConfig func (m *Module) ProviderForLocalConfig(pc addrs.LocalProviderConfig) addrs.Provider { - if provider, exists := m.ProviderRequirements[pc.LocalName]; exists { + if provider, exists := m.ProviderRequirements.RequiredProviders[pc.LocalName]; exists { return provider.Type } return addrs.ImpliedProviderForUnqualifiedType(pc.LocalName) diff --git a/configs/module_merge.go b/configs/module_merge.go index 045e02344..bf3fd8e49 100644 --- a/configs/module_merge.go +++ b/configs/module_merge.go @@ -35,25 +35,6 @@ func (p *Provider) merge(op *Provider) hcl.Diagnostics { return diags } -func mergeProviderVersionConstraints(recv map[string]ProviderRequirements, ovrd []*RequiredProvider) { - // Any provider name that's mentioned in the override gets nilled out in - // our map so that we'll rebuild it below. Any provider not mentioned is - // left unchanged. - for _, reqd := range ovrd { - delete(recv, reqd.Name) - } - for _, reqd := range ovrd { - var fqn addrs.Provider - if reqd.Source.SourceStr != "" { - // any errors parsing the source string will have already been captured. - fqn, _ = addrs.ParseProviderSourceString(reqd.Source.SourceStr) - } else { - fqn = addrs.ImpliedProviderForUnqualifiedType(reqd.Name) - } - recv[reqd.Name] = ProviderRequirements{Type: fqn, VersionConstraints: []VersionConstraint{reqd.Requirement}} - } -} - func (v *Variable) merge(ov *Variable) hcl.Diagnostics { var diags hcl.Diagnostics @@ -197,7 +178,7 @@ func (mc *ModuleCall) merge(omc *ModuleCall) hcl.Diagnostics { return diags } -func (r *Resource) merge(or *Resource, prs map[string]ProviderRequirements) hcl.Diagnostics { +func (r *Resource) merge(or *Resource, rps map[string]*RequiredProvider) hcl.Diagnostics { var diags hcl.Diagnostics if r.Mode != or.Mode { @@ -215,7 +196,7 @@ func (r *Resource) merge(or *Resource, prs map[string]ProviderRequirements) hcl. if or.ProviderConfigRef != nil { r.ProviderConfigRef = or.ProviderConfigRef - if existing, exists := prs[or.ProviderConfigRef.Name]; exists { + if existing, exists := rps[or.ProviderConfigRef.Name]; exists { r.Provider = existing.Type } else { r.Provider = addrs.ImpliedProviderForUnqualifiedType(r.ProviderConfigRef.Name) diff --git a/configs/module_merge_test.go b/configs/module_merge_test.go index 082c24447..d581ac836 100644 --- a/configs/module_merge_test.go +++ b/configs/module_merge_test.go @@ -3,7 +3,6 @@ package configs import ( "testing" - version "github.com/hashicorp/go-version" "github.com/hashicorp/hcl/v2" "github.com/hashicorp/hcl/v2/gohcl" "github.com/hashicorp/terraform/addrs" @@ -232,98 +231,3 @@ func TestModuleOverrideResourceFQNs(t *testing.T) { t.Fatalf("wrong result: found provider config ref %s, expected nil", got.ProviderConfigRef) } } - -func TestMergeProviderVersionConstraints(t *testing.T) { - v1, _ := version.NewConstraint("1.0.0") - vc1 := VersionConstraint{ - Required: v1, - } - v2, _ := version.NewConstraint("2.0.0") - vc2 := VersionConstraint{ - Required: v2, - } - - tests := map[string]struct { - Input map[string]ProviderRequirements - Override []*RequiredProvider - Want map[string]ProviderRequirements - }{ - "basic merge": { - map[string]ProviderRequirements{ - "random": ProviderRequirements{ - Type: addrs.Provider{Type: "random"}, - VersionConstraints: []VersionConstraint{}, - }, - }, - []*RequiredProvider{ - &RequiredProvider{ - Name: "null", - Requirement: VersionConstraint{}, - }, - }, - map[string]ProviderRequirements{ - "random": ProviderRequirements{ - Type: addrs.Provider{Type: "random"}, - VersionConstraints: []VersionConstraint{}, - }, - "null": ProviderRequirements{ - Type: addrs.NewDefaultProvider("null"), - VersionConstraints: []VersionConstraint{ - VersionConstraint{ - Required: version.Constraints(nil), - DeclRange: hcl.Range{}, - }, - }, - }, - }, - }, - "override version constraint": { - map[string]ProviderRequirements{ - "random": ProviderRequirements{ - Type: addrs.Provider{Type: "random"}, - VersionConstraints: []VersionConstraint{vc1}, - }, - }, - []*RequiredProvider{ - &RequiredProvider{ - Name: "random", - Requirement: vc2, - }, - }, - map[string]ProviderRequirements{ - "random": ProviderRequirements{ - Type: addrs.NewDefaultProvider("random"), - VersionConstraints: []VersionConstraint{vc2}, - }, - }, - }, - "merge with source constraint": { - map[string]ProviderRequirements{ - "random": ProviderRequirements{ - Type: addrs.Provider{Type: "random"}, - VersionConstraints: []VersionConstraint{vc1}, - }, - }, - []*RequiredProvider{ - &RequiredProvider{ - Name: "random", - Source: Source{SourceStr: "hashicorp/random"}, - Requirement: vc2, - }, - }, - map[string]ProviderRequirements{ - "random": ProviderRequirements{ - Type: addrs.NewDefaultProvider("random"), - VersionConstraints: []VersionConstraint{vc2}, - }, - }, - }, - } - - for name, test := range tests { - t.Run(name, func(t *testing.T) { - mergeProviderVersionConstraints(test.Input, test.Override) - assertResultDeepEqual(t, test.Input, test.Want) - }) - } -} diff --git a/configs/module_test.go b/configs/module_test.go index 3d9c47bdc..d5e3659ad 100644 --- a/configs/module_test.go +++ b/configs/module_test.go @@ -1,6 +1,7 @@ package configs import ( + "strings" "testing" "github.com/hashicorp/terraform/addrs" @@ -111,3 +112,94 @@ func TestProviderForLocalConfig(t *testing.T) { t.Fatalf("wrong result! got %#v, want %#v\n", got, want) } } + +// At most one required_providers block per module is permitted. +func TestModule_required_providers_multiple(t *testing.T) { + _, diags := testModuleFromDir("testdata/invalid-modules/multiple-required-providers") + if !diags.HasErrors() { + t.Fatal("module should have error diags, but does not") + } + + want := `Duplicate required providers configuration` + if got := diags.Error(); !strings.Contains(got, want) { + t.Fatalf("expected error to contain %q\nerror was:\n%s", want, got) + } +} + +// A module may have required_providers configured in files loaded later than +// resources. These provider settings should still be reflected in the +// resources' configuration. +func TestModule_required_providers_after_resource(t *testing.T) { + mod, diags := testModuleFromDir("testdata/valid-modules/required-providers-after-resource") + if diags.HasErrors() { + t.Fatal(diags.Error()) + } + + want := addrs.NewProvider(addrs.DefaultRegistryHost, "foo", "test") + + req, exists := mod.ProviderRequirements.RequiredProviders["test"] + if !exists { + t.Fatal("no provider requirements found for \"test\"") + } + if req.Type != want { + t.Errorf("wrong provider addr for \"test\"\ngot: %s\nwant: %s", + req.Type, want, + ) + } + + if got := mod.ManagedResources["test_instance.my-instance"].Provider; !got.Equals(want) { + t.Errorf("wrong provider addr for \"test_instance.my-instance\"\ngot: %s\nwant: %s", + got, want, + ) + } +} + +// We support overrides for required_providers blocks, which should replace the +// entire block for each provider localname, leaving other blocks unaffected. +// This should also be reflected in any resources in the module using this +// provider. +func TestModule_required_provider_overrides(t *testing.T) { + mod, diags := testModuleFromDir("testdata/valid-modules/required-providers-overrides") + if diags.HasErrors() { + t.Fatal(diags.Error()) + } + + // The foo provider and resource should be unaffected + want := addrs.NewProvider(addrs.DefaultRegistryHost, "acme", "foo") + req, exists := mod.ProviderRequirements.RequiredProviders["foo"] + if !exists { + t.Fatal("no provider requirements found for \"foo\"") + } + if req.Type != want { + t.Errorf("wrong provider addr for \"foo\"\ngot: %s\nwant: %s", + req.Type, want, + ) + } + if got := mod.ManagedResources["foo_thing.ft"].Provider; !got.Equals(want) { + t.Errorf("wrong provider addr for \"foo_thing.ft\"\ngot: %s\nwant: %s", + got, want, + ) + } + + // The bar provider and resource should be using the override config + want = addrs.NewProvider(addrs.DefaultRegistryHost, "blorp", "bar") + req, exists = mod.ProviderRequirements.RequiredProviders["bar"] + if !exists { + t.Fatal("no provider requirements found for \"bar\"") + } + if req.Type != want { + t.Errorf("wrong provider addr for \"bar\"\ngot: %s\nwant: %s", + req.Type, want, + ) + } + if gotVer, wantVer := req.Requirement.Required.String(), "~>2.0.0"; gotVer != wantVer { + t.Errorf("wrong provider version constraint for \"bar\"\ngot: %s\nwant: %s", + gotVer, wantVer, + ) + } + if got := mod.ManagedResources["bar_thing.bt"].Provider; !got.Equals(want) { + t.Errorf("wrong provider addr for \"bar_thing.bt\"\ngot: %s\nwant: %s", + got, want, + ) + } +} diff --git a/configs/parser_config.go b/configs/parser_config.go index e930a093c..354b96a72 100644 --- a/configs/parser_config.go +++ b/configs/parser_config.go @@ -75,7 +75,7 @@ func (p *Parser) loadConfigFile(path string, override bool) (*File, hcl.Diagnost case "required_providers": reqs, reqsDiags := decodeRequiredProvidersBlock(innerBlock) diags = append(diags, reqsDiags...) - file.RequiredProviders = append(file.RequiredProviders, reqs...) + file.RequiredProviders = append(file.RequiredProviders, reqs) case "provider_meta": providerCfg, cfgDiags := decodeProviderMetaBlock(innerBlock) diff --git a/configs/provider_requirements.go b/configs/provider_requirements.go index b535afe0c..d4218a6fc 100644 --- a/configs/provider_requirements.go +++ b/configs/provider_requirements.go @@ -12,41 +12,40 @@ import ( // parent. type RequiredProvider struct { Name string - Source Source + Type addrs.Provider Requirement VersionConstraint + DeclRange hcl.Range } -type Source struct { - SourceStr string - DeclRange hcl.Range +type RequiredProviders struct { + RequiredProviders map[string]*RequiredProvider + DeclRange hcl.Range } -// ProviderRequirements represents provider version constraints from -// required_providers blocks. -type ProviderRequirements struct { - Type addrs.Provider - VersionConstraints []VersionConstraint -} - -func decodeRequiredProvidersBlock(block *hcl.Block) ([]*RequiredProvider, hcl.Diagnostics) { +func decodeRequiredProvidersBlock(block *hcl.Block) (*RequiredProviders, hcl.Diagnostics) { attrs, diags := block.Body.JustAttributes() - var reqs []*RequiredProvider + ret := &RequiredProviders{ + RequiredProviders: make(map[string]*RequiredProvider), + DeclRange: block.DefRange, + } for name, attr := range attrs { expr, err := attr.Expr.Value(nil) if err != nil { diags = append(diags, err...) } + rp := &RequiredProvider{ + Name: name, + DeclRange: attr.Expr.Range(), + } + switch { case expr.Type().IsPrimitiveType(): vc, reqDiags := decodeVersionConstraint(attr) diags = append(diags, reqDiags...) - reqs = append(reqs, &RequiredProvider{ - Name: name, - Requirement: vc, - }) + rp.Requirement = vc + case expr.Type().IsObjectType(): - ret := &RequiredProvider{Name: name} if expr.Type().HasAttribute("version") { vc := VersionConstraint{ DeclRange: attr.Range, @@ -64,25 +63,55 @@ func decodeRequiredProvidersBlock(block *hcl.Block) ([]*RequiredProvider, hcl.Di }) } else { vc.Required = constraints - ret.Requirement = vc + rp.Requirement = vc } } if expr.Type().HasAttribute("source") { - ret.Source.SourceStr = expr.GetAttr("source").AsString() - ret.Source.DeclRange = attr.Range + fqn, sourceDiags := addrs.ParseProviderSourceString(expr.GetAttr("source").AsString()) + + if sourceDiags.HasErrors() { + hclDiags := sourceDiags.ToHCL() + // The diagnostics from ParseProviderSourceString don't contain + // source location information because it has no context to compute + // them from, and so we'll add those in quickly here before we + // return. + for _, diag := range hclDiags { + if diag.Subject == nil { + diag.Subject = attr.Expr.Range().Ptr() + } + } + diags = append(diags, hclDiags...) + } else { + rp.Type = fqn + } } - reqs = append(reqs, ret) + default: // should not happen diags = append(diags, &hcl.Diagnostic{ Severity: hcl.DiagError, - Summary: "Invalid provider_requirements syntax", - Detail: "provider_requirements entries must be strings or objects.", + Summary: "Invalid required_providers syntax", + Detail: "required_providers entries must be strings or objects.", Subject: attr.Expr.Range().Ptr(), }) - reqs = append(reqs, &RequiredProvider{Name: name}) - return reqs, diags } + + if rp.Type.IsZero() { + pType, err := addrs.ParseProviderPart(rp.Name) + if err != nil { + diags = append(diags, &hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Invalid provider name", + Detail: err.Error(), + Subject: attr.Expr.Range().Ptr(), + }) + } else { + rp.Type = addrs.ImpliedProviderForUnqualifiedType(pType) + } + } + + ret.RequiredProviders[rp.Name] = rp } - return reqs, diags + + return ret, diags } diff --git a/configs/provider_requirements_test.go b/configs/provider_requirements_test.go index 4abcc7ac3..a04714f55 100644 --- a/configs/provider_requirements_test.go +++ b/configs/provider_requirements_test.go @@ -1,8 +1,6 @@ package configs import ( - "fmt" - "sort" "testing" "github.com/google/go-cmp/cmp" @@ -10,6 +8,7 @@ import ( version "github.com/hashicorp/go-version" "github.com/hashicorp/hcl/v2" "github.com/hashicorp/hcl/v2/hcltest" + "github.com/hashicorp/terraform/addrs" "github.com/zclconf/go-cty/cty" ) @@ -19,173 +18,281 @@ var ( if x.Name != y.Name { return false } - if x.Source != y.Source { + if x.Type != y.Type { return false } if x.Requirement.Required.String() != y.Requirement.Required.String() { return false } + if x.DeclRange != y.DeclRange { + return false + } return true }) -) - -func TestDecodeRequiredProvidersBlock_legacy(t *testing.T) { - block := &hcl.Block{ - Type: "required_providers", - Body: hcltest.MockBody(&hcl.BodyContent{ - Attributes: hcl.Attributes{ - "default": { - Name: "default", - Expr: hcltest.MockExprLiteral(cty.StringVal("1.0.0")), - }, - }, - }), - } - - want := &RequiredProvider{ - Name: "default", - Requirement: testVC("1.0.0"), - } - - got, diags := decodeRequiredProvidersBlock(block) - if diags.HasErrors() { - t.Fatalf("unexpected error") - } - if len(got) != 1 { - t.Fatalf("wrong number of results, got %d, wanted 1", len(got)) - } - if !cmp.Equal(got[0], want, ignoreUnexported, comparer) { - t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], want, ignoreUnexported, comparer)) - } -} - -func TestDecodeRequiredProvidersBlock_provider_source(t *testing.T) { - mockRange := hcl.Range{ + blockRange = hcl.Range{ Filename: "mock.tf", Start: hcl.Pos{Line: 3, Column: 12, Byte: 27}, End: hcl.Pos{Line: 3, Column: 19, Byte: 34}, } + mockRange = hcl.Range{ + Filename: "MockExprLiteral", + } +) - block := &hcl.Block{ - Type: "required_providers", - Body: hcltest.MockBody(&hcl.BodyContent{ - Attributes: hcl.Attributes{ - "my_test": { - Name: "my_test", - Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ - "source": cty.StringVal("mycloud/test"), - "version": cty.StringVal("2.0.0"), - })), - Range: mockRange, - }, +func TestDecodeRequiredProvidersBlock(t *testing.T) { + tests := map[string]struct { + Block *hcl.Block + Want *RequiredProviders + Error string + }{ + "legacy": { + Block: &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "default": { + Name: "default", + Expr: hcltest.MockExprLiteral(cty.StringVal("1.0.0")), + }, + }, + }), + DefRange: blockRange, }, - }), - } - - want := &RequiredProvider{ - Name: "my_test", - Source: Source{SourceStr: "mycloud/test", DeclRange: mockRange}, - Requirement: testVC("2.0.0"), - } - got, diags := decodeRequiredProvidersBlock(block) - if diags.HasErrors() { - t.Fatalf("unexpected error") - } - if len(got) != 1 { - t.Fatalf("wrong number of results, got %d, wanted 1", len(got)) - } - if !cmp.Equal(got[0], want, ignoreUnexported, comparer) { - t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], want, ignoreUnexported, comparer)) - } -} - -func TestDecodeRequiredProvidersBlock_mixed(t *testing.T) { - block := &hcl.Block{ - Type: "required_providers", - Body: hcltest.MockBody(&hcl.BodyContent{ - Attributes: hcl.Attributes{ - "legacy": { - Name: "legacy", - Expr: hcltest.MockExprLiteral(cty.StringVal("1.0.0")), - }, - "my_test": { - Name: "my_test", - Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ - "source": cty.StringVal("mycloud/test"), - "version": cty.StringVal("2.0.0"), - })), + Want: &RequiredProviders{ + RequiredProviders: map[string]*RequiredProvider{ + "default": { + Name: "default", + Type: addrs.NewDefaultProvider("default"), + Requirement: testVC("1.0.0"), + DeclRange: mockRange, + }, }, + DeclRange: blockRange, }, - }), - } - - want := []*RequiredProvider{ - { - Name: "legacy", - Requirement: testVC("1.0.0"), }, - { - Name: "my_test", - Source: Source{SourceStr: "mycloud/test", DeclRange: hcl.Range{}}, - Requirement: testVC("2.0.0"), + "provider source": { + Block: &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "my_test": { + Name: "my_test", + Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ + "source": cty.StringVal("mycloud/test"), + "version": cty.StringVal("2.0.0"), + })), + }, + }, + }), + DefRange: blockRange, + }, + Want: &RequiredProviders{ + RequiredProviders: map[string]*RequiredProvider{ + "my_test": { + Name: "my_test", + Type: addrs.NewProvider(addrs.DefaultRegistryHost, "mycloud", "test"), + Requirement: testVC("2.0.0"), + DeclRange: mockRange, + }, + }, + DeclRange: blockRange, + }, + }, + "mixed": { + Block: &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "legacy": { + Name: "legacy", + Expr: hcltest.MockExprLiteral(cty.StringVal("1.0.0")), + }, + "my_test": { + Name: "my_test", + Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ + "source": cty.StringVal("mycloud/test"), + "version": cty.StringVal("2.0.0"), + })), + }, + }, + }), + DefRange: blockRange, + }, + Want: &RequiredProviders{ + RequiredProviders: map[string]*RequiredProvider{ + "legacy": { + Name: "legacy", + Type: addrs.NewDefaultProvider("legacy"), + Requirement: testVC("1.0.0"), + DeclRange: mockRange, + }, + "my_test": { + Name: "my_test", + Type: addrs.NewProvider(addrs.DefaultRegistryHost, "mycloud", "test"), + Requirement: testVC("2.0.0"), + DeclRange: mockRange, + }, + }, + DeclRange: blockRange, + }, + }, + "version-only block": { + Block: &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "test": { + Name: "test", + Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ + "version": cty.StringVal("~>2.0.0"), + })), + }, + }, + }), + DefRange: blockRange, + }, + Want: &RequiredProviders{ + RequiredProviders: map[string]*RequiredProvider{ + "test": { + Name: "test", + Type: addrs.NewDefaultProvider("test"), + Requirement: testVC("~>2.0.0"), + DeclRange: mockRange, + }, + }, + DeclRange: blockRange, + }, + }, + "invalid source": { + Block: &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "my_test": { + Name: "my_test", + Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ + "source": cty.StringVal("some/invalid/provider/source/test"), + "version": cty.StringVal("~>2.0.0"), + })), + }, + }, + }), + DefRange: blockRange, + }, + Want: &RequiredProviders{ + RequiredProviders: map[string]*RequiredProvider{ + "my_test": { + Name: "my_test", + Type: addrs.Provider{}, + Requirement: testVC("~>2.0.0"), + DeclRange: mockRange, + }, + }, + DeclRange: blockRange, + }, + Error: "Invalid provider source string", + }, + "localname is invalid provider name": { + Block: &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "my_test": { + Name: "my_test", + Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ + "version": cty.StringVal("~>2.0.0"), + })), + }, + }, + }), + DefRange: blockRange, + }, + Want: &RequiredProviders{ + RequiredProviders: map[string]*RequiredProvider{ + "my_test": { + Name: "my_test", + Type: addrs.Provider{}, + Requirement: testVC("~>2.0.0"), + DeclRange: mockRange, + }, + }, + DeclRange: blockRange, + }, + Error: "Invalid provider name", + }, + "version constraint error": { + Block: &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "my_test": { + Name: "my_test", + Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ + "source": cty.StringVal("mycloud/test"), + "version": cty.StringVal("invalid"), + })), + }, + }, + }), + DefRange: blockRange, + }, + Want: &RequiredProviders{ + RequiredProviders: map[string]*RequiredProvider{ + "my_test": { + Name: "my_test", + Type: addrs.NewProvider(addrs.DefaultRegistryHost, "mycloud", "test"), + DeclRange: mockRange, + }, + }, + DeclRange: blockRange, + }, + Error: "Invalid version constraint", + }, + "invalid required_providers attribute value": { + Block: &hcl.Block{ + Type: "required_providers", + Body: hcltest.MockBody(&hcl.BodyContent{ + Attributes: hcl.Attributes{ + "test": { + Name: "test", + Expr: hcltest.MockExprLiteral(cty.ListVal([]cty.Value{cty.StringVal("2.0.0")})), + }, + }, + }), + DefRange: blockRange, + }, + Want: &RequiredProviders{ + RequiredProviders: map[string]*RequiredProvider{ + "test": { + Name: "test", + Type: addrs.NewDefaultProvider("test"), + DeclRange: mockRange, + }, + }, + DeclRange: blockRange, + }, + Error: "Invalid required_providers syntax", }, } - got, diags := decodeRequiredProvidersBlock(block) + for name, test := range tests { + t.Run(name, func(t *testing.T) { + got, diags := decodeRequiredProvidersBlock(test.Block) + if diags.HasErrors() { + if test.Error == "" { + t.Fatalf("unexpected error") + } + if gotErr := diags[0].Summary; gotErr != test.Error { + t.Errorf("wrong error, got %q, want %q", gotErr, test.Error) + } + } else if test.Error != "" { + t.Fatalf("expected error") + } - sort.SliceStable(got, func(i, j int) bool { - return got[i].Name < got[j].Name - }) - - if diags.HasErrors() { - t.Fatalf("unexpected error") - } - if len(got) != 2 { - t.Fatalf("wrong number of results, got %d, wanted 2", len(got)) - } - for i, rp := range want { - if !cmp.Equal(got[i], rp, ignoreUnexported, comparer) { - t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], rp, ignoreUnexported, comparer)) - } - } -} - -func TestDecodeRequiredProvidersBlock_version_error(t *testing.T) { - block := &hcl.Block{ - Type: "required_providers", - Body: hcltest.MockBody(&hcl.BodyContent{ - Attributes: hcl.Attributes{ - "my_test": { - Name: "my_test", - Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ - "source": cty.StringVal("mycloud/test"), - "version": cty.StringVal("invalid"), - })), - }, - }, - }), - } - - want := []*RequiredProvider{ - { - Name: "my_test", - Source: Source{SourceStr: "mycloud/test", DeclRange: hcl.Range{}}, - }, - } - - got, diags := decodeRequiredProvidersBlock(block) - if !diags.HasErrors() { - t.Fatalf("expected error, got success") - } else { - fmt.Printf(diags[0].Summary) - } - if len(got) != 1 { - t.Fatalf("wrong number of results, got %d, wanted 1", len(got)) - } - for i, rp := range want { - if !cmp.Equal(got[i], rp, ignoreUnexported, comparer) { - t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], rp, ignoreUnexported, comparer)) - } + if !cmp.Equal(got, test.Want, ignoreUnexported, comparer) { + t.Fatalf("wrong result:\n %s", cmp.Diff(got, test.Want, ignoreUnexported, comparer)) + } + }) } } diff --git a/configs/testdata/invalid-modules/multiple-required-providers/a.tf b/configs/testdata/invalid-modules/multiple-required-providers/a.tf new file mode 100644 index 000000000..a607ddd1c --- /dev/null +++ b/configs/testdata/invalid-modules/multiple-required-providers/a.tf @@ -0,0 +1,7 @@ +terraform { + required_providers { + bar = { + version = "~>1.0.0" + } + } +} diff --git a/configs/testdata/invalid-modules/multiple-required-providers/b.tf b/configs/testdata/invalid-modules/multiple-required-providers/b.tf new file mode 100644 index 000000000..7672e9249 --- /dev/null +++ b/configs/testdata/invalid-modules/multiple-required-providers/b.tf @@ -0,0 +1,7 @@ +terraform { + required_providers { + foo = { + version = "~>2.0.0" + } + } +} diff --git a/configs/testdata/valid-modules/required-providers-after-resource/main.tf b/configs/testdata/valid-modules/required-providers-after-resource/main.tf new file mode 100644 index 000000000..8d40ec4b9 --- /dev/null +++ b/configs/testdata/valid-modules/required-providers-after-resource/main.tf @@ -0,0 +1,3 @@ +resource test_instance "my-instance" { + provider = test +} \ No newline at end of file diff --git a/configs/testdata/valid-modules/required-providers-after-resource/providers.tf b/configs/testdata/valid-modules/required-providers-after-resource/providers.tf new file mode 100644 index 000000000..687ef1bdd --- /dev/null +++ b/configs/testdata/valid-modules/required-providers-after-resource/providers.tf @@ -0,0 +1,8 @@ +terraform { + required_providers { + test = { + source = "foo/test" + version = "~>1.0.0" + } + } +} \ No newline at end of file diff --git a/configs/testdata/valid-modules/required-providers-overrides/bar_provider_override.tf b/configs/testdata/valid-modules/required-providers-overrides/bar_provider_override.tf new file mode 100644 index 000000000..b83ef7494 --- /dev/null +++ b/configs/testdata/valid-modules/required-providers-overrides/bar_provider_override.tf @@ -0,0 +1,9 @@ +terraform { + required_providers { + bar = { + source = "blorp/bar" + version = "~>2.0.0" + } + } +} + diff --git a/configs/testdata/valid-modules/required-providers-overrides/main.tf b/configs/testdata/valid-modules/required-providers-overrides/main.tf new file mode 100644 index 000000000..f7e58f2e2 --- /dev/null +++ b/configs/testdata/valid-modules/required-providers-overrides/main.tf @@ -0,0 +1,7 @@ +resource bar_thing "bt" { + provider = bar +} + +resource foo_thing "ft" { + provider = foo +} diff --git a/configs/testdata/valid-modules/required-providers-overrides/providers.tf b/configs/testdata/valid-modules/required-providers-overrides/providers.tf new file mode 100644 index 000000000..9b1d897f4 --- /dev/null +++ b/configs/testdata/valid-modules/required-providers-overrides/providers.tf @@ -0,0 +1,11 @@ +terraform { + required_providers { + bar = { + source = "acme/bar" + } + + foo = { + source = "acme/foo" + } + } +}