Merge pull request #24763 from hashicorp/alisdair/required-providers-refactor

configs: Simplify required_providers blocks
This commit is contained in:
Alisdair McDiarmid 2020-04-27 11:16:24 -04:00 committed by GitHub
commit c1137430c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 522 additions and 359 deletions

View File

@ -185,23 +185,21 @@ func (c *Config) addProviderRequirements(reqs getproviders.Requirements) hcl.Dia
var diags hcl.Diagnostics var diags hcl.Diagnostics
// First we'll deal with the requirements directly in _our_ module... // First we'll deal with the requirements directly in _our_ module...
for _, providerReqs := range c.Module.ProviderRequirements { for _, providerReqs := range c.Module.ProviderRequirements.RequiredProviders {
fqn := providerReqs.Type fqn := providerReqs.Type
if _, ok := reqs[fqn]; !ok { if _, ok := reqs[fqn]; !ok {
// We'll at least have an unconstrained dependency then, but might // We'll at least have an unconstrained dependency then, but might
// add to this in the loop below. // add to this in the loop below.
reqs[fqn] = nil reqs[fqn] = nil
} }
for _, constraintsSrc := range providerReqs.VersionConstraints { // The model of version constraints in this package is still the
// The model of version constraints in this package is still the // old one using a different upstream module to represent versions,
// old one using a different upstream module to represent versions, // so we'll need to shim that out here for now. We assume this
// so we'll need to shim that out here for now. We assume this // will always succeed because these constraints already succeeded
// will always succeed because these constraints already succeeded // parsing with the other constraint parser, which uses the same
// parsing with the other constraint parser, which uses the same // syntax.
// syntax. constraints := getproviders.MustParseVersionConstraints(providerReqs.Requirement.Required.String())
constraints := getproviders.MustParseVersionConstraints(constraintsSrc.Required.String()) reqs[fqn] = append(reqs[fqn], constraints...)
reqs[fqn] = append(reqs[fqn], constraints...)
}
} }
// Each resource in the configuration creates an *implicit* provider // Each resource in the configuration creates an *implicit* provider
// dependency, though we'll only record it if there isn't already // dependency, though we'll only record it if there isn't already
@ -321,7 +319,7 @@ func (c *Config) ResolveAbsProviderAddr(addr addrs.ProviderConfig, inModule addr
} }
var provider addrs.Provider var provider addrs.Provider
if providerReq, exists := c.Module.ProviderRequirements[addr.LocalName]; exists { if providerReq, exists := c.Module.ProviderRequirements.RequiredProviders[addr.LocalName]; exists {
provider = providerReq.Type provider = providerReq.Type
} else { } else {
provider = addrs.ImpliedProviderForUnqualifiedType(addr.LocalName) provider = addrs.ImpliedProviderForUnqualifiedType(addr.LocalName)
@ -343,7 +341,7 @@ func (c *Config) ResolveAbsProviderAddr(addr addrs.ProviderConfig, inModule addr
// by checking for the provider in module.ProviderRequirements and falling // by checking for the provider in module.ProviderRequirements and falling
// back to addrs.NewDefaultProvider if it is not found. // back to addrs.NewDefaultProvider if it is not found.
func (c *Config) ProviderForConfigAddr(addr addrs.LocalProviderConfig) addrs.Provider { func (c *Config) ProviderForConfigAddr(addr addrs.LocalProviderConfig) addrs.Provider {
if provider, exists := c.Module.ProviderRequirements[addr.LocalName]; exists { if provider, exists := c.Module.ProviderRequirements.RequiredProviders[addr.LocalName]; exists {
return provider.Type return provider.Type
} }
return c.ResolveAbsProviderAddr(addr, addrs.RootModule).Provider return c.ResolveAbsProviderAddr(addr, addrs.RootModule).Provider

View File

@ -7,7 +7,6 @@ import (
"github.com/hashicorp/terraform/addrs" "github.com/hashicorp/terraform/addrs"
"github.com/hashicorp/terraform/experiments" "github.com/hashicorp/terraform/experiments"
"github.com/hashicorp/terraform/tfdiags"
) )
// Module is a container for a set of configuration constructs that are // Module is a container for a set of configuration constructs that are
@ -31,7 +30,7 @@ type Module struct {
Backend *Backend Backend *Backend
ProviderConfigs map[string]*Provider ProviderConfigs map[string]*Provider
ProviderRequirements map[string]ProviderRequirements ProviderRequirements *RequiredProviders
ProviderLocalNames map[addrs.Provider]string ProviderLocalNames map[addrs.Provider]string
ProviderMetas map[addrs.Provider]*ProviderMeta ProviderMetas map[addrs.Provider]*ProviderMeta
@ -64,7 +63,7 @@ type File struct {
Backends []*Backend Backends []*Backend
ProviderConfigs []*Provider ProviderConfigs []*Provider
ProviderMetas []*ProviderMeta ProviderMetas []*ProviderMeta
RequiredProviders []*RequiredProvider RequiredProviders []*RequiredProviders
Variables []*Variable Variables []*Variable
Locals []*Local Locals []*Local
@ -87,16 +86,50 @@ type File struct {
func NewModule(primaryFiles, overrideFiles []*File) (*Module, hcl.Diagnostics) { func NewModule(primaryFiles, overrideFiles []*File) (*Module, hcl.Diagnostics) {
var diags hcl.Diagnostics var diags hcl.Diagnostics
mod := &Module{ mod := &Module{
ProviderConfigs: map[string]*Provider{}, ProviderConfigs: map[string]*Provider{},
ProviderRequirements: map[string]ProviderRequirements{}, ProviderLocalNames: map[addrs.Provider]string{},
ProviderLocalNames: map[addrs.Provider]string{}, Variables: map[string]*Variable{},
Variables: map[string]*Variable{}, Locals: map[string]*Local{},
Locals: map[string]*Local{}, Outputs: map[string]*Output{},
Outputs: map[string]*Output{}, ModuleCalls: map[string]*ModuleCall{},
ModuleCalls: map[string]*ModuleCall{}, ManagedResources: map[string]*Resource{},
ManagedResources: map[string]*Resource{}, DataResources: map[string]*Resource{},
DataResources: map[string]*Resource{}, ProviderMetas: map[addrs.Provider]*ProviderMeta{},
ProviderMetas: map[addrs.Provider]*ProviderMeta{}, }
// Process the required_providers blocks first, to ensure that all
// resources have access to the correct provider FQNs
for _, file := range primaryFiles {
for _, r := range file.RequiredProviders {
if mod.ProviderRequirements != nil {
diags = append(diags, &hcl.Diagnostic{
Severity: hcl.DiagError,
Summary: "Duplicate required providers configuration",
Detail: fmt.Sprintf("A module may have only one required providers configuration. The required providers were previously configured at %s.", mod.ProviderRequirements.DeclRange),
Subject: &r.DeclRange,
})
continue
}
mod.ProviderRequirements = r
}
}
// If no required_providers block is configured, create a useful empty
// state to reduce nil checks elsewhere
if mod.ProviderRequirements == nil {
mod.ProviderRequirements = &RequiredProviders{
RequiredProviders: make(map[string]*RequiredProvider),
}
}
// Any required_providers blocks in override files replace the entire
// block for each provider
for _, file := range overrideFiles {
for _, override := range file.RequiredProviders {
for name, rp := range override.RequiredProviders {
mod.ProviderRequirements.RequiredProviders[name] = rp
}
}
} }
for _, file := range primaryFiles { for _, file := range primaryFiles {
@ -178,35 +211,6 @@ func (m *Module) appendFile(file *File) hcl.Diagnostics {
m.ProviderConfigs[key] = pc m.ProviderConfigs[key] = pc
} }
for _, reqd := range file.RequiredProviders {
var fqn addrs.Provider
if reqd.Source.SourceStr != "" {
var sourceDiags tfdiags.Diagnostics
fqn, sourceDiags = addrs.ParseProviderSourceString(reqd.Source.SourceStr)
hclDiags := sourceDiags.ToHCL()
// The diagnostics from ParseProviderSourceString don't contain
// source location information because it has no context to compute
// them from, and so we'll add those in quickly here before we
// return.
for _, diag := range hclDiags {
if diag.Subject == nil {
diag.Subject = reqd.Source.DeclRange.Ptr()
}
}
diags = append(diags, hclDiags...)
} else {
fqn = addrs.ImpliedProviderForUnqualifiedType(reqd.Name)
}
if existing, exists := m.ProviderRequirements[reqd.Name]; exists {
if existing.Type != fqn {
panic("provider fqn mismatch")
}
existing.VersionConstraints = append(existing.VersionConstraints, reqd.Requirement)
} else {
m.ProviderRequirements[reqd.Name] = ProviderRequirements{Type: fqn, VersionConstraints: []VersionConstraint{reqd.Requirement}}
}
}
for _, pm := range file.ProviderMetas { for _, pm := range file.ProviderMetas {
provider := m.ProviderForLocalConfig(addrs.LocalProviderConfig{LocalName: pm.Provider}) provider := m.ProviderForLocalConfig(addrs.LocalProviderConfig{LocalName: pm.Provider})
if existing, exists := m.ProviderMetas[provider]; exists { if existing, exists := m.ProviderMetas[provider]; exists {
@ -283,7 +287,7 @@ func (m *Module) appendFile(file *File) hcl.Diagnostics {
// set the provider FQN for the resource // set the provider FQN for the resource
if r.ProviderConfigRef != nil { if r.ProviderConfigRef != nil {
if existing, exists := m.ProviderRequirements[r.ProviderConfigAddr().LocalName]; exists { if existing, exists := m.ProviderRequirements.RequiredProviders[r.ProviderConfigAddr().LocalName]; exists {
r.Provider = existing.Type r.Provider = existing.Type
} else { } else {
r.Provider = addrs.ImpliedProviderForUnqualifiedType(r.ProviderConfigAddr().LocalName) r.Provider = addrs.ImpliedProviderForUnqualifiedType(r.ProviderConfigAddr().LocalName)
@ -308,7 +312,7 @@ func (m *Module) appendFile(file *File) hcl.Diagnostics {
// set the provider FQN for the resource // set the provider FQN for the resource
if r.ProviderConfigRef != nil { if r.ProviderConfigRef != nil {
if existing, exists := m.ProviderRequirements[r.ProviderConfigAddr().LocalName]; exists { if existing, exists := m.ProviderRequirements.RequiredProviders[r.ProviderConfigAddr().LocalName]; exists {
r.Provider = existing.Type r.Provider = existing.Type
} else { } else {
r.Provider = addrs.ImpliedProviderForUnqualifiedType(r.ProviderConfigAddr().LocalName) r.Provider = addrs.ImpliedProviderForUnqualifiedType(r.ProviderConfigAddr().LocalName)
@ -382,10 +386,6 @@ func (m *Module) mergeFile(file *File) hcl.Diagnostics {
} }
} }
if len(file.RequiredProviders) != 0 {
mergeProviderVersionConstraints(m.ProviderRequirements, file.RequiredProviders)
}
for _, v := range file.Variables { for _, v := range file.Variables {
existing, exists := m.Variables[v.Name] existing, exists := m.Variables[v.Name]
if !exists { if !exists {
@ -458,7 +458,7 @@ func (m *Module) mergeFile(file *File) hcl.Diagnostics {
}) })
continue continue
} }
mergeDiags := existing.merge(r, m.ProviderRequirements) mergeDiags := existing.merge(r, m.ProviderRequirements.RequiredProviders)
diags = append(diags, mergeDiags...) diags = append(diags, mergeDiags...)
} }
@ -474,7 +474,7 @@ func (m *Module) mergeFile(file *File) hcl.Diagnostics {
}) })
continue continue
} }
mergeDiags := existing.merge(r, m.ProviderRequirements) mergeDiags := existing.merge(r, m.ProviderRequirements.RequiredProviders)
diags = append(diags, mergeDiags...) diags = append(diags, mergeDiags...)
} }
@ -487,7 +487,7 @@ func (m *Module) mergeFile(file *File) hcl.Diagnostics {
// only be populated after the module has been parsed. // only be populated after the module has been parsed.
func (m *Module) gatherProviderLocalNames() { func (m *Module) gatherProviderLocalNames() {
providers := make(map[addrs.Provider]string) providers := make(map[addrs.Provider]string)
for k, v := range m.ProviderRequirements { for k, v := range m.ProviderRequirements.RequiredProviders {
providers[v.Type] = k providers[v.Type] = k
} }
m.ProviderLocalNames = providers m.ProviderLocalNames = providers
@ -507,7 +507,7 @@ func (m *Module) LocalNameForProvider(p addrs.Provider) string {
// ProviderForLocalConfig returns the provider FQN for a given LocalProviderConfig // ProviderForLocalConfig returns the provider FQN for a given LocalProviderConfig
func (m *Module) ProviderForLocalConfig(pc addrs.LocalProviderConfig) addrs.Provider { func (m *Module) ProviderForLocalConfig(pc addrs.LocalProviderConfig) addrs.Provider {
if provider, exists := m.ProviderRequirements[pc.LocalName]; exists { if provider, exists := m.ProviderRequirements.RequiredProviders[pc.LocalName]; exists {
return provider.Type return provider.Type
} }
return addrs.ImpliedProviderForUnqualifiedType(pc.LocalName) return addrs.ImpliedProviderForUnqualifiedType(pc.LocalName)

View File

@ -35,25 +35,6 @@ func (p *Provider) merge(op *Provider) hcl.Diagnostics {
return diags return diags
} }
func mergeProviderVersionConstraints(recv map[string]ProviderRequirements, ovrd []*RequiredProvider) {
// Any provider name that's mentioned in the override gets nilled out in
// our map so that we'll rebuild it below. Any provider not mentioned is
// left unchanged.
for _, reqd := range ovrd {
delete(recv, reqd.Name)
}
for _, reqd := range ovrd {
var fqn addrs.Provider
if reqd.Source.SourceStr != "" {
// any errors parsing the source string will have already been captured.
fqn, _ = addrs.ParseProviderSourceString(reqd.Source.SourceStr)
} else {
fqn = addrs.ImpliedProviderForUnqualifiedType(reqd.Name)
}
recv[reqd.Name] = ProviderRequirements{Type: fqn, VersionConstraints: []VersionConstraint{reqd.Requirement}}
}
}
func (v *Variable) merge(ov *Variable) hcl.Diagnostics { func (v *Variable) merge(ov *Variable) hcl.Diagnostics {
var diags hcl.Diagnostics var diags hcl.Diagnostics
@ -197,7 +178,7 @@ func (mc *ModuleCall) merge(omc *ModuleCall) hcl.Diagnostics {
return diags return diags
} }
func (r *Resource) merge(or *Resource, prs map[string]ProviderRequirements) hcl.Diagnostics { func (r *Resource) merge(or *Resource, rps map[string]*RequiredProvider) hcl.Diagnostics {
var diags hcl.Diagnostics var diags hcl.Diagnostics
if r.Mode != or.Mode { if r.Mode != or.Mode {
@ -215,7 +196,7 @@ func (r *Resource) merge(or *Resource, prs map[string]ProviderRequirements) hcl.
if or.ProviderConfigRef != nil { if or.ProviderConfigRef != nil {
r.ProviderConfigRef = or.ProviderConfigRef r.ProviderConfigRef = or.ProviderConfigRef
if existing, exists := prs[or.ProviderConfigRef.Name]; exists { if existing, exists := rps[or.ProviderConfigRef.Name]; exists {
r.Provider = existing.Type r.Provider = existing.Type
} else { } else {
r.Provider = addrs.ImpliedProviderForUnqualifiedType(r.ProviderConfigRef.Name) r.Provider = addrs.ImpliedProviderForUnqualifiedType(r.ProviderConfigRef.Name)

View File

@ -3,7 +3,6 @@ package configs
import ( import (
"testing" "testing"
version "github.com/hashicorp/go-version"
"github.com/hashicorp/hcl/v2" "github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/gohcl" "github.com/hashicorp/hcl/v2/gohcl"
"github.com/hashicorp/terraform/addrs" "github.com/hashicorp/terraform/addrs"
@ -232,98 +231,3 @@ func TestModuleOverrideResourceFQNs(t *testing.T) {
t.Fatalf("wrong result: found provider config ref %s, expected nil", got.ProviderConfigRef) t.Fatalf("wrong result: found provider config ref %s, expected nil", got.ProviderConfigRef)
} }
} }
func TestMergeProviderVersionConstraints(t *testing.T) {
v1, _ := version.NewConstraint("1.0.0")
vc1 := VersionConstraint{
Required: v1,
}
v2, _ := version.NewConstraint("2.0.0")
vc2 := VersionConstraint{
Required: v2,
}
tests := map[string]struct {
Input map[string]ProviderRequirements
Override []*RequiredProvider
Want map[string]ProviderRequirements
}{
"basic merge": {
map[string]ProviderRequirements{
"random": ProviderRequirements{
Type: addrs.Provider{Type: "random"},
VersionConstraints: []VersionConstraint{},
},
},
[]*RequiredProvider{
&RequiredProvider{
Name: "null",
Requirement: VersionConstraint{},
},
},
map[string]ProviderRequirements{
"random": ProviderRequirements{
Type: addrs.Provider{Type: "random"},
VersionConstraints: []VersionConstraint{},
},
"null": ProviderRequirements{
Type: addrs.NewDefaultProvider("null"),
VersionConstraints: []VersionConstraint{
VersionConstraint{
Required: version.Constraints(nil),
DeclRange: hcl.Range{},
},
},
},
},
},
"override version constraint": {
map[string]ProviderRequirements{
"random": ProviderRequirements{
Type: addrs.Provider{Type: "random"},
VersionConstraints: []VersionConstraint{vc1},
},
},
[]*RequiredProvider{
&RequiredProvider{
Name: "random",
Requirement: vc2,
},
},
map[string]ProviderRequirements{
"random": ProviderRequirements{
Type: addrs.NewDefaultProvider("random"),
VersionConstraints: []VersionConstraint{vc2},
},
},
},
"merge with source constraint": {
map[string]ProviderRequirements{
"random": ProviderRequirements{
Type: addrs.Provider{Type: "random"},
VersionConstraints: []VersionConstraint{vc1},
},
},
[]*RequiredProvider{
&RequiredProvider{
Name: "random",
Source: Source{SourceStr: "hashicorp/random"},
Requirement: vc2,
},
},
map[string]ProviderRequirements{
"random": ProviderRequirements{
Type: addrs.NewDefaultProvider("random"),
VersionConstraints: []VersionConstraint{vc2},
},
},
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
mergeProviderVersionConstraints(test.Input, test.Override)
assertResultDeepEqual(t, test.Input, test.Want)
})
}
}

View File

@ -1,6 +1,7 @@
package configs package configs
import ( import (
"strings"
"testing" "testing"
"github.com/hashicorp/terraform/addrs" "github.com/hashicorp/terraform/addrs"
@ -111,3 +112,94 @@ func TestProviderForLocalConfig(t *testing.T) {
t.Fatalf("wrong result! got %#v, want %#v\n", got, want) t.Fatalf("wrong result! got %#v, want %#v\n", got, want)
} }
} }
// At most one required_providers block per module is permitted.
func TestModule_required_providers_multiple(t *testing.T) {
_, diags := testModuleFromDir("testdata/invalid-modules/multiple-required-providers")
if !diags.HasErrors() {
t.Fatal("module should have error diags, but does not")
}
want := `Duplicate required providers configuration`
if got := diags.Error(); !strings.Contains(got, want) {
t.Fatalf("expected error to contain %q\nerror was:\n%s", want, got)
}
}
// A module may have required_providers configured in files loaded later than
// resources. These provider settings should still be reflected in the
// resources' configuration.
func TestModule_required_providers_after_resource(t *testing.T) {
mod, diags := testModuleFromDir("testdata/valid-modules/required-providers-after-resource")
if diags.HasErrors() {
t.Fatal(diags.Error())
}
want := addrs.NewProvider(addrs.DefaultRegistryHost, "foo", "test")
req, exists := mod.ProviderRequirements.RequiredProviders["test"]
if !exists {
t.Fatal("no provider requirements found for \"test\"")
}
if req.Type != want {
t.Errorf("wrong provider addr for \"test\"\ngot: %s\nwant: %s",
req.Type, want,
)
}
if got := mod.ManagedResources["test_instance.my-instance"].Provider; !got.Equals(want) {
t.Errorf("wrong provider addr for \"test_instance.my-instance\"\ngot: %s\nwant: %s",
got, want,
)
}
}
// We support overrides for required_providers blocks, which should replace the
// entire block for each provider localname, leaving other blocks unaffected.
// This should also be reflected in any resources in the module using this
// provider.
func TestModule_required_provider_overrides(t *testing.T) {
mod, diags := testModuleFromDir("testdata/valid-modules/required-providers-overrides")
if diags.HasErrors() {
t.Fatal(diags.Error())
}
// The foo provider and resource should be unaffected
want := addrs.NewProvider(addrs.DefaultRegistryHost, "acme", "foo")
req, exists := mod.ProviderRequirements.RequiredProviders["foo"]
if !exists {
t.Fatal("no provider requirements found for \"foo\"")
}
if req.Type != want {
t.Errorf("wrong provider addr for \"foo\"\ngot: %s\nwant: %s",
req.Type, want,
)
}
if got := mod.ManagedResources["foo_thing.ft"].Provider; !got.Equals(want) {
t.Errorf("wrong provider addr for \"foo_thing.ft\"\ngot: %s\nwant: %s",
got, want,
)
}
// The bar provider and resource should be using the override config
want = addrs.NewProvider(addrs.DefaultRegistryHost, "blorp", "bar")
req, exists = mod.ProviderRequirements.RequiredProviders["bar"]
if !exists {
t.Fatal("no provider requirements found for \"bar\"")
}
if req.Type != want {
t.Errorf("wrong provider addr for \"bar\"\ngot: %s\nwant: %s",
req.Type, want,
)
}
if gotVer, wantVer := req.Requirement.Required.String(), "~>2.0.0"; gotVer != wantVer {
t.Errorf("wrong provider version constraint for \"bar\"\ngot: %s\nwant: %s",
gotVer, wantVer,
)
}
if got := mod.ManagedResources["bar_thing.bt"].Provider; !got.Equals(want) {
t.Errorf("wrong provider addr for \"bar_thing.bt\"\ngot: %s\nwant: %s",
got, want,
)
}
}

View File

@ -75,7 +75,7 @@ func (p *Parser) loadConfigFile(path string, override bool) (*File, hcl.Diagnost
case "required_providers": case "required_providers":
reqs, reqsDiags := decodeRequiredProvidersBlock(innerBlock) reqs, reqsDiags := decodeRequiredProvidersBlock(innerBlock)
diags = append(diags, reqsDiags...) diags = append(diags, reqsDiags...)
file.RequiredProviders = append(file.RequiredProviders, reqs...) file.RequiredProviders = append(file.RequiredProviders, reqs)
case "provider_meta": case "provider_meta":
providerCfg, cfgDiags := decodeProviderMetaBlock(innerBlock) providerCfg, cfgDiags := decodeProviderMetaBlock(innerBlock)

View File

@ -12,41 +12,40 @@ import (
// parent. // parent.
type RequiredProvider struct { type RequiredProvider struct {
Name string Name string
Source Source Type addrs.Provider
Requirement VersionConstraint Requirement VersionConstraint
DeclRange hcl.Range
} }
type Source struct { type RequiredProviders struct {
SourceStr string RequiredProviders map[string]*RequiredProvider
DeclRange hcl.Range DeclRange hcl.Range
} }
// ProviderRequirements represents provider version constraints from func decodeRequiredProvidersBlock(block *hcl.Block) (*RequiredProviders, hcl.Diagnostics) {
// required_providers blocks.
type ProviderRequirements struct {
Type addrs.Provider
VersionConstraints []VersionConstraint
}
func decodeRequiredProvidersBlock(block *hcl.Block) ([]*RequiredProvider, hcl.Diagnostics) {
attrs, diags := block.Body.JustAttributes() attrs, diags := block.Body.JustAttributes()
var reqs []*RequiredProvider ret := &RequiredProviders{
RequiredProviders: make(map[string]*RequiredProvider),
DeclRange: block.DefRange,
}
for name, attr := range attrs { for name, attr := range attrs {
expr, err := attr.Expr.Value(nil) expr, err := attr.Expr.Value(nil)
if err != nil { if err != nil {
diags = append(diags, err...) diags = append(diags, err...)
} }
rp := &RequiredProvider{
Name: name,
DeclRange: attr.Expr.Range(),
}
switch { switch {
case expr.Type().IsPrimitiveType(): case expr.Type().IsPrimitiveType():
vc, reqDiags := decodeVersionConstraint(attr) vc, reqDiags := decodeVersionConstraint(attr)
diags = append(diags, reqDiags...) diags = append(diags, reqDiags...)
reqs = append(reqs, &RequiredProvider{ rp.Requirement = vc
Name: name,
Requirement: vc,
})
case expr.Type().IsObjectType(): case expr.Type().IsObjectType():
ret := &RequiredProvider{Name: name}
if expr.Type().HasAttribute("version") { if expr.Type().HasAttribute("version") {
vc := VersionConstraint{ vc := VersionConstraint{
DeclRange: attr.Range, DeclRange: attr.Range,
@ -64,25 +63,55 @@ func decodeRequiredProvidersBlock(block *hcl.Block) ([]*RequiredProvider, hcl.Di
}) })
} else { } else {
vc.Required = constraints vc.Required = constraints
ret.Requirement = vc rp.Requirement = vc
} }
} }
if expr.Type().HasAttribute("source") { if expr.Type().HasAttribute("source") {
ret.Source.SourceStr = expr.GetAttr("source").AsString() fqn, sourceDiags := addrs.ParseProviderSourceString(expr.GetAttr("source").AsString())
ret.Source.DeclRange = attr.Range
if sourceDiags.HasErrors() {
hclDiags := sourceDiags.ToHCL()
// The diagnostics from ParseProviderSourceString don't contain
// source location information because it has no context to compute
// them from, and so we'll add those in quickly here before we
// return.
for _, diag := range hclDiags {
if diag.Subject == nil {
diag.Subject = attr.Expr.Range().Ptr()
}
}
diags = append(diags, hclDiags...)
} else {
rp.Type = fqn
}
} }
reqs = append(reqs, ret)
default: default:
// should not happen // should not happen
diags = append(diags, &hcl.Diagnostic{ diags = append(diags, &hcl.Diagnostic{
Severity: hcl.DiagError, Severity: hcl.DiagError,
Summary: "Invalid provider_requirements syntax", Summary: "Invalid required_providers syntax",
Detail: "provider_requirements entries must be strings or objects.", Detail: "required_providers entries must be strings or objects.",
Subject: attr.Expr.Range().Ptr(), Subject: attr.Expr.Range().Ptr(),
}) })
reqs = append(reqs, &RequiredProvider{Name: name})
return reqs, diags
} }
if rp.Type.IsZero() {
pType, err := addrs.ParseProviderPart(rp.Name)
if err != nil {
diags = append(diags, &hcl.Diagnostic{
Severity: hcl.DiagError,
Summary: "Invalid provider name",
Detail: err.Error(),
Subject: attr.Expr.Range().Ptr(),
})
} else {
rp.Type = addrs.ImpliedProviderForUnqualifiedType(pType)
}
}
ret.RequiredProviders[rp.Name] = rp
} }
return reqs, diags
return ret, diags
} }

View File

@ -1,8 +1,6 @@
package configs package configs
import ( import (
"fmt"
"sort"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
@ -10,6 +8,7 @@ import (
version "github.com/hashicorp/go-version" version "github.com/hashicorp/go-version"
"github.com/hashicorp/hcl/v2" "github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/hcltest" "github.com/hashicorp/hcl/v2/hcltest"
"github.com/hashicorp/terraform/addrs"
"github.com/zclconf/go-cty/cty" "github.com/zclconf/go-cty/cty"
) )
@ -19,173 +18,281 @@ var (
if x.Name != y.Name { if x.Name != y.Name {
return false return false
} }
if x.Source != y.Source { if x.Type != y.Type {
return false return false
} }
if x.Requirement.Required.String() != y.Requirement.Required.String() { if x.Requirement.Required.String() != y.Requirement.Required.String() {
return false return false
} }
if x.DeclRange != y.DeclRange {
return false
}
return true return true
}) })
) blockRange = hcl.Range{
func TestDecodeRequiredProvidersBlock_legacy(t *testing.T) {
block := &hcl.Block{
Type: "required_providers",
Body: hcltest.MockBody(&hcl.BodyContent{
Attributes: hcl.Attributes{
"default": {
Name: "default",
Expr: hcltest.MockExprLiteral(cty.StringVal("1.0.0")),
},
},
}),
}
want := &RequiredProvider{
Name: "default",
Requirement: testVC("1.0.0"),
}
got, diags := decodeRequiredProvidersBlock(block)
if diags.HasErrors() {
t.Fatalf("unexpected error")
}
if len(got) != 1 {
t.Fatalf("wrong number of results, got %d, wanted 1", len(got))
}
if !cmp.Equal(got[0], want, ignoreUnexported, comparer) {
t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], want, ignoreUnexported, comparer))
}
}
func TestDecodeRequiredProvidersBlock_provider_source(t *testing.T) {
mockRange := hcl.Range{
Filename: "mock.tf", Filename: "mock.tf",
Start: hcl.Pos{Line: 3, Column: 12, Byte: 27}, Start: hcl.Pos{Line: 3, Column: 12, Byte: 27},
End: hcl.Pos{Line: 3, Column: 19, Byte: 34}, End: hcl.Pos{Line: 3, Column: 19, Byte: 34},
} }
mockRange = hcl.Range{
Filename: "MockExprLiteral",
}
)
block := &hcl.Block{ func TestDecodeRequiredProvidersBlock(t *testing.T) {
Type: "required_providers", tests := map[string]struct {
Body: hcltest.MockBody(&hcl.BodyContent{ Block *hcl.Block
Attributes: hcl.Attributes{ Want *RequiredProviders
"my_test": { Error string
Name: "my_test", }{
Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{ "legacy": {
"source": cty.StringVal("mycloud/test"), Block: &hcl.Block{
"version": cty.StringVal("2.0.0"), Type: "required_providers",
})), Body: hcltest.MockBody(&hcl.BodyContent{
Range: mockRange, Attributes: hcl.Attributes{
}, "default": {
Name: "default",
Expr: hcltest.MockExprLiteral(cty.StringVal("1.0.0")),
},
},
}),
DefRange: blockRange,
}, },
}), Want: &RequiredProviders{
} RequiredProviders: map[string]*RequiredProvider{
"default": {
want := &RequiredProvider{ Name: "default",
Name: "my_test", Type: addrs.NewDefaultProvider("default"),
Source: Source{SourceStr: "mycloud/test", DeclRange: mockRange}, Requirement: testVC("1.0.0"),
Requirement: testVC("2.0.0"), DeclRange: mockRange,
} },
got, diags := decodeRequiredProvidersBlock(block)
if diags.HasErrors() {
t.Fatalf("unexpected error")
}
if len(got) != 1 {
t.Fatalf("wrong number of results, got %d, wanted 1", len(got))
}
if !cmp.Equal(got[0], want, ignoreUnexported, comparer) {
t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], want, ignoreUnexported, comparer))
}
}
func TestDecodeRequiredProvidersBlock_mixed(t *testing.T) {
block := &hcl.Block{
Type: "required_providers",
Body: hcltest.MockBody(&hcl.BodyContent{
Attributes: hcl.Attributes{
"legacy": {
Name: "legacy",
Expr: hcltest.MockExprLiteral(cty.StringVal("1.0.0")),
},
"my_test": {
Name: "my_test",
Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{
"source": cty.StringVal("mycloud/test"),
"version": cty.StringVal("2.0.0"),
})),
}, },
DeclRange: blockRange,
}, },
}),
}
want := []*RequiredProvider{
{
Name: "legacy",
Requirement: testVC("1.0.0"),
}, },
{ "provider source": {
Name: "my_test", Block: &hcl.Block{
Source: Source{SourceStr: "mycloud/test", DeclRange: hcl.Range{}}, Type: "required_providers",
Requirement: testVC("2.0.0"), Body: hcltest.MockBody(&hcl.BodyContent{
Attributes: hcl.Attributes{
"my_test": {
Name: "my_test",
Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{
"source": cty.StringVal("mycloud/test"),
"version": cty.StringVal("2.0.0"),
})),
},
},
}),
DefRange: blockRange,
},
Want: &RequiredProviders{
RequiredProviders: map[string]*RequiredProvider{
"my_test": {
Name: "my_test",
Type: addrs.NewProvider(addrs.DefaultRegistryHost, "mycloud", "test"),
Requirement: testVC("2.0.0"),
DeclRange: mockRange,
},
},
DeclRange: blockRange,
},
},
"mixed": {
Block: &hcl.Block{
Type: "required_providers",
Body: hcltest.MockBody(&hcl.BodyContent{
Attributes: hcl.Attributes{
"legacy": {
Name: "legacy",
Expr: hcltest.MockExprLiteral(cty.StringVal("1.0.0")),
},
"my_test": {
Name: "my_test",
Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{
"source": cty.StringVal("mycloud/test"),
"version": cty.StringVal("2.0.0"),
})),
},
},
}),
DefRange: blockRange,
},
Want: &RequiredProviders{
RequiredProviders: map[string]*RequiredProvider{
"legacy": {
Name: "legacy",
Type: addrs.NewDefaultProvider("legacy"),
Requirement: testVC("1.0.0"),
DeclRange: mockRange,
},
"my_test": {
Name: "my_test",
Type: addrs.NewProvider(addrs.DefaultRegistryHost, "mycloud", "test"),
Requirement: testVC("2.0.0"),
DeclRange: mockRange,
},
},
DeclRange: blockRange,
},
},
"version-only block": {
Block: &hcl.Block{
Type: "required_providers",
Body: hcltest.MockBody(&hcl.BodyContent{
Attributes: hcl.Attributes{
"test": {
Name: "test",
Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{
"version": cty.StringVal("~>2.0.0"),
})),
},
},
}),
DefRange: blockRange,
},
Want: &RequiredProviders{
RequiredProviders: map[string]*RequiredProvider{
"test": {
Name: "test",
Type: addrs.NewDefaultProvider("test"),
Requirement: testVC("~>2.0.0"),
DeclRange: mockRange,
},
},
DeclRange: blockRange,
},
},
"invalid source": {
Block: &hcl.Block{
Type: "required_providers",
Body: hcltest.MockBody(&hcl.BodyContent{
Attributes: hcl.Attributes{
"my_test": {
Name: "my_test",
Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{
"source": cty.StringVal("some/invalid/provider/source/test"),
"version": cty.StringVal("~>2.0.0"),
})),
},
},
}),
DefRange: blockRange,
},
Want: &RequiredProviders{
RequiredProviders: map[string]*RequiredProvider{
"my_test": {
Name: "my_test",
Type: addrs.Provider{},
Requirement: testVC("~>2.0.0"),
DeclRange: mockRange,
},
},
DeclRange: blockRange,
},
Error: "Invalid provider source string",
},
"localname is invalid provider name": {
Block: &hcl.Block{
Type: "required_providers",
Body: hcltest.MockBody(&hcl.BodyContent{
Attributes: hcl.Attributes{
"my_test": {
Name: "my_test",
Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{
"version": cty.StringVal("~>2.0.0"),
})),
},
},
}),
DefRange: blockRange,
},
Want: &RequiredProviders{
RequiredProviders: map[string]*RequiredProvider{
"my_test": {
Name: "my_test",
Type: addrs.Provider{},
Requirement: testVC("~>2.0.0"),
DeclRange: mockRange,
},
},
DeclRange: blockRange,
},
Error: "Invalid provider name",
},
"version constraint error": {
Block: &hcl.Block{
Type: "required_providers",
Body: hcltest.MockBody(&hcl.BodyContent{
Attributes: hcl.Attributes{
"my_test": {
Name: "my_test",
Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{
"source": cty.StringVal("mycloud/test"),
"version": cty.StringVal("invalid"),
})),
},
},
}),
DefRange: blockRange,
},
Want: &RequiredProviders{
RequiredProviders: map[string]*RequiredProvider{
"my_test": {
Name: "my_test",
Type: addrs.NewProvider(addrs.DefaultRegistryHost, "mycloud", "test"),
DeclRange: mockRange,
},
},
DeclRange: blockRange,
},
Error: "Invalid version constraint",
},
"invalid required_providers attribute value": {
Block: &hcl.Block{
Type: "required_providers",
Body: hcltest.MockBody(&hcl.BodyContent{
Attributes: hcl.Attributes{
"test": {
Name: "test",
Expr: hcltest.MockExprLiteral(cty.ListVal([]cty.Value{cty.StringVal("2.0.0")})),
},
},
}),
DefRange: blockRange,
},
Want: &RequiredProviders{
RequiredProviders: map[string]*RequiredProvider{
"test": {
Name: "test",
Type: addrs.NewDefaultProvider("test"),
DeclRange: mockRange,
},
},
DeclRange: blockRange,
},
Error: "Invalid required_providers syntax",
}, },
} }
got, diags := decodeRequiredProvidersBlock(block) for name, test := range tests {
t.Run(name, func(t *testing.T) {
got, diags := decodeRequiredProvidersBlock(test.Block)
if diags.HasErrors() {
if test.Error == "" {
t.Fatalf("unexpected error")
}
if gotErr := diags[0].Summary; gotErr != test.Error {
t.Errorf("wrong error, got %q, want %q", gotErr, test.Error)
}
} else if test.Error != "" {
t.Fatalf("expected error")
}
sort.SliceStable(got, func(i, j int) bool { if !cmp.Equal(got, test.Want, ignoreUnexported, comparer) {
return got[i].Name < got[j].Name t.Fatalf("wrong result:\n %s", cmp.Diff(got, test.Want, ignoreUnexported, comparer))
}) }
})
if diags.HasErrors() {
t.Fatalf("unexpected error")
}
if len(got) != 2 {
t.Fatalf("wrong number of results, got %d, wanted 2", len(got))
}
for i, rp := range want {
if !cmp.Equal(got[i], rp, ignoreUnexported, comparer) {
t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], rp, ignoreUnexported, comparer))
}
}
}
func TestDecodeRequiredProvidersBlock_version_error(t *testing.T) {
block := &hcl.Block{
Type: "required_providers",
Body: hcltest.MockBody(&hcl.BodyContent{
Attributes: hcl.Attributes{
"my_test": {
Name: "my_test",
Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{
"source": cty.StringVal("mycloud/test"),
"version": cty.StringVal("invalid"),
})),
},
},
}),
}
want := []*RequiredProvider{
{
Name: "my_test",
Source: Source{SourceStr: "mycloud/test", DeclRange: hcl.Range{}},
},
}
got, diags := decodeRequiredProvidersBlock(block)
if !diags.HasErrors() {
t.Fatalf("expected error, got success")
} else {
fmt.Printf(diags[0].Summary)
}
if len(got) != 1 {
t.Fatalf("wrong number of results, got %d, wanted 1", len(got))
}
for i, rp := range want {
if !cmp.Equal(got[i], rp, ignoreUnexported, comparer) {
t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], rp, ignoreUnexported, comparer))
}
} }
} }

View File

@ -0,0 +1,7 @@
terraform {
required_providers {
bar = {
version = "~>1.0.0"
}
}
}

View File

@ -0,0 +1,7 @@
terraform {
required_providers {
foo = {
version = "~>2.0.0"
}
}
}

View File

@ -0,0 +1,3 @@
resource test_instance "my-instance" {
provider = test
}

View File

@ -0,0 +1,8 @@
terraform {
required_providers {
test = {
source = "foo/test"
version = "~>1.0.0"
}
}
}

View File

@ -0,0 +1,9 @@
terraform {
required_providers {
bar = {
source = "blorp/bar"
version = "~>2.0.0"
}
}
}

View File

@ -0,0 +1,7 @@
resource bar_thing "bt" {
provider = bar
}
resource foo_thing "ft" {
provider = foo
}

View File

@ -0,0 +1,11 @@
terraform {
required_providers {
bar = {
source = "acme/bar"
}
foo = {
source = "acme/foo"
}
}
}