configs: Simplify required_providers blocks

We now permit at most one `required_providers` block per module (except
for overrides). This prevents users (and Terraform) from struggling to
understand how to merge multiple `required_providers` configurations,
with `version` and `source` attributes split across multiple blocks.

Because only one `required_providers` block is permitted, there is no
need to concatenate version constraints and resolve them. This allows us
to simplify the structs used to represent provider requirements,
aligning more closely with other structs in this package.

This commit also fixes a semantic use-before-initialize bug, where
resources defined before a `required_providers` block would be unable to
use its source attribute. We achieve this by processing the module's
`required_providers` configuration (and overrides) before resources.

Overrides for `required_providers` work as before, replacing the entire
block per provider.
This commit is contained in:
Alisdair McDiarmid 2020-04-24 10:54:24 -04:00
parent dad1262fb8
commit 7ca7b1f0fe
15 changed files with 522 additions and 359 deletions

View File

@ -185,24 +185,22 @@ func (c *Config) addProviderRequirements(reqs getproviders.Requirements) hcl.Dia
var diags hcl.Diagnostics var diags hcl.Diagnostics
// First we'll deal with the requirements directly in _our_ module... // First we'll deal with the requirements directly in _our_ module...
for _, providerReqs := range c.Module.ProviderRequirements { for _, providerReqs := range c.Module.ProviderRequirements.RequiredProviders {
fqn := providerReqs.Type fqn := providerReqs.Type
if _, ok := reqs[fqn]; !ok { if _, ok := reqs[fqn]; !ok {
// We'll at least have an unconstrained dependency then, but might // We'll at least have an unconstrained dependency then, but might
// add to this in the loop below. // add to this in the loop below.
reqs[fqn] = nil reqs[fqn] = nil
} }
for _, constraintsSrc := range providerReqs.VersionConstraints {
// The model of version constraints in this package is still the // The model of version constraints in this package is still the
// old one using a different upstream module to represent versions, // old one using a different upstream module to represent versions,
// so we'll need to shim that out here for now. We assume this // so we'll need to shim that out here for now. We assume this
// will always succeed because these constraints already succeeded // will always succeed because these constraints already succeeded
// parsing with the other constraint parser, which uses the same // parsing with the other constraint parser, which uses the same
// syntax. // syntax.
constraints := getproviders.MustParseVersionConstraints(constraintsSrc.Required.String()) constraints := getproviders.MustParseVersionConstraints(providerReqs.Requirement.Required.String())
reqs[fqn] = append(reqs[fqn], constraints...) reqs[fqn] = append(reqs[fqn], constraints...)
} }
}
// Each resource in the configuration creates an *implicit* provider // Each resource in the configuration creates an *implicit* provider
// dependency, though we'll only record it if there isn't already // dependency, though we'll only record it if there isn't already
// an explicit dependency on the same provider. // an explicit dependency on the same provider.
@ -321,7 +319,7 @@ func (c *Config) ResolveAbsProviderAddr(addr addrs.ProviderConfig, inModule addr
} }
var provider addrs.Provider var provider addrs.Provider
if providerReq, exists := c.Module.ProviderRequirements[addr.LocalName]; exists { if providerReq, exists := c.Module.ProviderRequirements.RequiredProviders[addr.LocalName]; exists {
provider = providerReq.Type provider = providerReq.Type
} else { } else {
provider = addrs.ImpliedProviderForUnqualifiedType(addr.LocalName) provider = addrs.ImpliedProviderForUnqualifiedType(addr.LocalName)
@ -343,7 +341,7 @@ func (c *Config) ResolveAbsProviderAddr(addr addrs.ProviderConfig, inModule addr
// by checking for the provider in module.ProviderRequirements and falling // by checking for the provider in module.ProviderRequirements and falling
// back to addrs.NewDefaultProvider if it is not found. // back to addrs.NewDefaultProvider if it is not found.
func (c *Config) ProviderForConfigAddr(addr addrs.LocalProviderConfig) addrs.Provider { func (c *Config) ProviderForConfigAddr(addr addrs.LocalProviderConfig) addrs.Provider {
if provider, exists := c.Module.ProviderRequirements[addr.LocalName]; exists { if provider, exists := c.Module.ProviderRequirements.RequiredProviders[addr.LocalName]; exists {
return provider.Type return provider.Type
} }
return c.ResolveAbsProviderAddr(addr, addrs.RootModule).Provider return c.ResolveAbsProviderAddr(addr, addrs.RootModule).Provider

View File

@ -7,7 +7,6 @@ import (
"github.com/hashicorp/terraform/addrs" "github.com/hashicorp/terraform/addrs"
"github.com/hashicorp/terraform/experiments" "github.com/hashicorp/terraform/experiments"
"github.com/hashicorp/terraform/tfdiags"
) )
// Module is a container for a set of configuration constructs that are // Module is a container for a set of configuration constructs that are
@ -31,7 +30,7 @@ type Module struct {
Backend *Backend Backend *Backend
ProviderConfigs map[string]*Provider ProviderConfigs map[string]*Provider
ProviderRequirements map[string]ProviderRequirements ProviderRequirements *RequiredProviders
ProviderLocalNames map[addrs.Provider]string ProviderLocalNames map[addrs.Provider]string
ProviderMetas map[addrs.Provider]*ProviderMeta ProviderMetas map[addrs.Provider]*ProviderMeta
@ -64,7 +63,7 @@ type File struct {
Backends []*Backend Backends []*Backend
ProviderConfigs []*Provider ProviderConfigs []*Provider
ProviderMetas []*ProviderMeta ProviderMetas []*ProviderMeta
RequiredProviders []*RequiredProvider RequiredProviders []*RequiredProviders
Variables []*Variable Variables []*Variable
Locals []*Local Locals []*Local
@ -88,7 +87,6 @@ func NewModule(primaryFiles, overrideFiles []*File) (*Module, hcl.Diagnostics) {
var diags hcl.Diagnostics var diags hcl.Diagnostics
mod := &Module{ mod := &Module{
ProviderConfigs: map[string]*Provider{}, ProviderConfigs: map[string]*Provider{},
ProviderRequirements: map[string]ProviderRequirements{},
ProviderLocalNames: map[addrs.Provider]string{}, ProviderLocalNames: map[addrs.Provider]string{},
Variables: map[string]*Variable{}, Variables: map[string]*Variable{},
Locals: map[string]*Local{}, Locals: map[string]*Local{},
@ -99,6 +97,41 @@ func NewModule(primaryFiles, overrideFiles []*File) (*Module, hcl.Diagnostics) {
ProviderMetas: map[addrs.Provider]*ProviderMeta{}, ProviderMetas: map[addrs.Provider]*ProviderMeta{},
} }
// Process the required_providers blocks first, to ensure that all
// resources have access to the correct provider FQNs
for _, file := range primaryFiles {
for _, r := range file.RequiredProviders {
if mod.ProviderRequirements != nil {
diags = append(diags, &hcl.Diagnostic{
Severity: hcl.DiagError,
Summary: "Duplicate required providers configuration",
Detail: fmt.Sprintf("A module may have only one required providers configuration. The required providers were previously configured at %s.", mod.ProviderRequirements.DeclRange),
Subject: &r.DeclRange,
})
continue
}
mod.ProviderRequirements = r
}
}
// If no required_providers block is configured, create a useful empty
// state to reduce nil checks elsewhere
if mod.ProviderRequirements == nil {
mod.ProviderRequirements = &RequiredProviders{
RequiredProviders: make(map[string]*RequiredProvider),
}
}
// Any required_providers blocks in override files replace the entire
// block for each provider
for _, file := range overrideFiles {
for _, override := range file.RequiredProviders {
for name, rp := range override.RequiredProviders {
mod.ProviderRequirements.RequiredProviders[name] = rp
}
}
}
for _, file := range primaryFiles { for _, file := range primaryFiles {
fileDiags := mod.appendFile(file) fileDiags := mod.appendFile(file)
diags = append(diags, fileDiags...) diags = append(diags, fileDiags...)
@ -178,35 +211,6 @@ func (m *Module) appendFile(file *File) hcl.Diagnostics {
m.ProviderConfigs[key] = pc m.ProviderConfigs[key] = pc
} }
for _, reqd := range file.RequiredProviders {
var fqn addrs.Provider
if reqd.Source.SourceStr != "" {
var sourceDiags tfdiags.Diagnostics
fqn, sourceDiags = addrs.ParseProviderSourceString(reqd.Source.SourceStr)
hclDiags := sourceDiags.ToHCL()
// The diagnostics from ParseProviderSourceString don't contain
// source location information because it has no context to compute
// them from, and so we'll add those in quickly here before we
// return.
for _, diag := range hclDiags {
if diag.Subject == nil {
diag.Subject = reqd.Source.DeclRange.Ptr()
}
}
diags = append(diags, hclDiags...)
} else {
fqn = addrs.ImpliedProviderForUnqualifiedType(reqd.Name)
}
if existing, exists := m.ProviderRequirements[reqd.Name]; exists {
if existing.Type != fqn {
panic("provider fqn mismatch")
}
existing.VersionConstraints = append(existing.VersionConstraints, reqd.Requirement)
} else {
m.ProviderRequirements[reqd.Name] = ProviderRequirements{Type: fqn, VersionConstraints: []VersionConstraint{reqd.Requirement}}
}
}
for _, pm := range file.ProviderMetas { for _, pm := range file.ProviderMetas {
provider := m.ProviderForLocalConfig(addrs.LocalProviderConfig{LocalName: pm.Provider}) provider := m.ProviderForLocalConfig(addrs.LocalProviderConfig{LocalName: pm.Provider})
if existing, exists := m.ProviderMetas[provider]; exists { if existing, exists := m.ProviderMetas[provider]; exists {
@ -283,7 +287,7 @@ func (m *Module) appendFile(file *File) hcl.Diagnostics {
// set the provider FQN for the resource // set the provider FQN for the resource
if r.ProviderConfigRef != nil { if r.ProviderConfigRef != nil {
if existing, exists := m.ProviderRequirements[r.ProviderConfigAddr().LocalName]; exists { if existing, exists := m.ProviderRequirements.RequiredProviders[r.ProviderConfigAddr().LocalName]; exists {
r.Provider = existing.Type r.Provider = existing.Type
} else { } else {
r.Provider = addrs.ImpliedProviderForUnqualifiedType(r.ProviderConfigAddr().LocalName) r.Provider = addrs.ImpliedProviderForUnqualifiedType(r.ProviderConfigAddr().LocalName)
@ -308,7 +312,7 @@ func (m *Module) appendFile(file *File) hcl.Diagnostics {
// set the provider FQN for the resource // set the provider FQN for the resource
if r.ProviderConfigRef != nil { if r.ProviderConfigRef != nil {
if existing, exists := m.ProviderRequirements[r.ProviderConfigAddr().LocalName]; exists { if existing, exists := m.ProviderRequirements.RequiredProviders[r.ProviderConfigAddr().LocalName]; exists {
r.Provider = existing.Type r.Provider = existing.Type
} else { } else {
r.Provider = addrs.ImpliedProviderForUnqualifiedType(r.ProviderConfigAddr().LocalName) r.Provider = addrs.ImpliedProviderForUnqualifiedType(r.ProviderConfigAddr().LocalName)
@ -382,10 +386,6 @@ func (m *Module) mergeFile(file *File) hcl.Diagnostics {
} }
} }
if len(file.RequiredProviders) != 0 {
mergeProviderVersionConstraints(m.ProviderRequirements, file.RequiredProviders)
}
for _, v := range file.Variables { for _, v := range file.Variables {
existing, exists := m.Variables[v.Name] existing, exists := m.Variables[v.Name]
if !exists { if !exists {
@ -458,7 +458,7 @@ func (m *Module) mergeFile(file *File) hcl.Diagnostics {
}) })
continue continue
} }
mergeDiags := existing.merge(r, m.ProviderRequirements) mergeDiags := existing.merge(r, m.ProviderRequirements.RequiredProviders)
diags = append(diags, mergeDiags...) diags = append(diags, mergeDiags...)
} }
@ -474,7 +474,7 @@ func (m *Module) mergeFile(file *File) hcl.Diagnostics {
}) })
continue continue
} }
mergeDiags := existing.merge(r, m.ProviderRequirements) mergeDiags := existing.merge(r, m.ProviderRequirements.RequiredProviders)
diags = append(diags, mergeDiags...) diags = append(diags, mergeDiags...)
} }
@ -487,7 +487,7 @@ func (m *Module) mergeFile(file *File) hcl.Diagnostics {
// only be populated after the module has been parsed. // only be populated after the module has been parsed.
func (m *Module) gatherProviderLocalNames() { func (m *Module) gatherProviderLocalNames() {
providers := make(map[addrs.Provider]string) providers := make(map[addrs.Provider]string)
for k, v := range m.ProviderRequirements { for k, v := range m.ProviderRequirements.RequiredProviders {
providers[v.Type] = k providers[v.Type] = k
} }
m.ProviderLocalNames = providers m.ProviderLocalNames = providers
@ -507,7 +507,7 @@ func (m *Module) LocalNameForProvider(p addrs.Provider) string {
// ProviderForLocalConfig returns the provider FQN for a given LocalProviderConfig // ProviderForLocalConfig returns the provider FQN for a given LocalProviderConfig
func (m *Module) ProviderForLocalConfig(pc addrs.LocalProviderConfig) addrs.Provider { func (m *Module) ProviderForLocalConfig(pc addrs.LocalProviderConfig) addrs.Provider {
if provider, exists := m.ProviderRequirements[pc.LocalName]; exists { if provider, exists := m.ProviderRequirements.RequiredProviders[pc.LocalName]; exists {
return provider.Type return provider.Type
} }
return addrs.ImpliedProviderForUnqualifiedType(pc.LocalName) return addrs.ImpliedProviderForUnqualifiedType(pc.LocalName)

View File

@ -35,25 +35,6 @@ func (p *Provider) merge(op *Provider) hcl.Diagnostics {
return diags return diags
} }
func mergeProviderVersionConstraints(recv map[string]ProviderRequirements, ovrd []*RequiredProvider) {
// Any provider name that's mentioned in the override gets nilled out in
// our map so that we'll rebuild it below. Any provider not mentioned is
// left unchanged.
for _, reqd := range ovrd {
delete(recv, reqd.Name)
}
for _, reqd := range ovrd {
var fqn addrs.Provider
if reqd.Source.SourceStr != "" {
// any errors parsing the source string will have already been captured.
fqn, _ = addrs.ParseProviderSourceString(reqd.Source.SourceStr)
} else {
fqn = addrs.ImpliedProviderForUnqualifiedType(reqd.Name)
}
recv[reqd.Name] = ProviderRequirements{Type: fqn, VersionConstraints: []VersionConstraint{reqd.Requirement}}
}
}
func (v *Variable) merge(ov *Variable) hcl.Diagnostics { func (v *Variable) merge(ov *Variable) hcl.Diagnostics {
var diags hcl.Diagnostics var diags hcl.Diagnostics
@ -197,7 +178,7 @@ func (mc *ModuleCall) merge(omc *ModuleCall) hcl.Diagnostics {
return diags return diags
} }
func (r *Resource) merge(or *Resource, prs map[string]ProviderRequirements) hcl.Diagnostics { func (r *Resource) merge(or *Resource, rps map[string]*RequiredProvider) hcl.Diagnostics {
var diags hcl.Diagnostics var diags hcl.Diagnostics
if r.Mode != or.Mode { if r.Mode != or.Mode {
@ -215,7 +196,7 @@ func (r *Resource) merge(or *Resource, prs map[string]ProviderRequirements) hcl.
if or.ProviderConfigRef != nil { if or.ProviderConfigRef != nil {
r.ProviderConfigRef = or.ProviderConfigRef r.ProviderConfigRef = or.ProviderConfigRef
if existing, exists := prs[or.ProviderConfigRef.Name]; exists { if existing, exists := rps[or.ProviderConfigRef.Name]; exists {
r.Provider = existing.Type r.Provider = existing.Type
} else { } else {
r.Provider = addrs.ImpliedProviderForUnqualifiedType(r.ProviderConfigRef.Name) r.Provider = addrs.ImpliedProviderForUnqualifiedType(r.ProviderConfigRef.Name)

View File

@ -3,7 +3,6 @@ package configs
import ( import (
"testing" "testing"
version "github.com/hashicorp/go-version"
"github.com/hashicorp/hcl/v2" "github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/gohcl" "github.com/hashicorp/hcl/v2/gohcl"
"github.com/hashicorp/terraform/addrs" "github.com/hashicorp/terraform/addrs"
@ -232,98 +231,3 @@ func TestModuleOverrideResourceFQNs(t *testing.T) {
t.Fatalf("wrong result: found provider config ref %s, expected nil", got.ProviderConfigRef) t.Fatalf("wrong result: found provider config ref %s, expected nil", got.ProviderConfigRef)
} }
} }
func TestMergeProviderVersionConstraints(t *testing.T) {
v1, _ := version.NewConstraint("1.0.0")
vc1 := VersionConstraint{
Required: v1,
}
v2, _ := version.NewConstraint("2.0.0")
vc2 := VersionConstraint{
Required: v2,
}
tests := map[string]struct {
Input map[string]ProviderRequirements
Override []*RequiredProvider
Want map[string]ProviderRequirements
}{
"basic merge": {
map[string]ProviderRequirements{
"random": ProviderRequirements{
Type: addrs.Provider{Type: "random"},
VersionConstraints: []VersionConstraint{},
},
},
[]*RequiredProvider{
&RequiredProvider{
Name: "null",
Requirement: VersionConstraint{},
},
},
map[string]ProviderRequirements{
"random": ProviderRequirements{
Type: addrs.Provider{Type: "random"},
VersionConstraints: []VersionConstraint{},
},
"null": ProviderRequirements{
Type: addrs.NewDefaultProvider("null"),
VersionConstraints: []VersionConstraint{
VersionConstraint{
Required: version.Constraints(nil),
DeclRange: hcl.Range{},
},
},
},
},
},
"override version constraint": {
map[string]ProviderRequirements{
"random": ProviderRequirements{
Type: addrs.Provider{Type: "random"},
VersionConstraints: []VersionConstraint{vc1},
},
},
[]*RequiredProvider{
&RequiredProvider{
Name: "random",
Requirement: vc2,
},
},
map[string]ProviderRequirements{
"random": ProviderRequirements{
Type: addrs.NewDefaultProvider("random"),
VersionConstraints: []VersionConstraint{vc2},
},
},
},
"merge with source constraint": {
map[string]ProviderRequirements{
"random": ProviderRequirements{
Type: addrs.Provider{Type: "random"},
VersionConstraints: []VersionConstraint{vc1},
},
},
[]*RequiredProvider{
&RequiredProvider{
Name: "random",
Source: Source{SourceStr: "hashicorp/random"},
Requirement: vc2,
},
},
map[string]ProviderRequirements{
"random": ProviderRequirements{
Type: addrs.NewDefaultProvider("random"),
VersionConstraints: []VersionConstraint{vc2},
},
},
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
mergeProviderVersionConstraints(test.Input, test.Override)
assertResultDeepEqual(t, test.Input, test.Want)
})
}
}

View File

@ -1,6 +1,7 @@
package configs package configs
import ( import (
"strings"
"testing" "testing"
"github.com/hashicorp/terraform/addrs" "github.com/hashicorp/terraform/addrs"
@ -111,3 +112,94 @@ func TestProviderForLocalConfig(t *testing.T) {
t.Fatalf("wrong result! got %#v, want %#v\n", got, want) t.Fatalf("wrong result! got %#v, want %#v\n", got, want)
} }
} }
// At most one required_providers block per module is permitted.
func TestModule_required_providers_multiple(t *testing.T) {
_, diags := testModuleFromDir("testdata/invalid-modules/multiple-required-providers")
if !diags.HasErrors() {
t.Fatal("module should have error diags, but does not")
}
want := `Duplicate required providers configuration`
if got := diags.Error(); !strings.Contains(got, want) {
t.Fatalf("expected error to contain %q\nerror was:\n%s", want, got)
}
}
// A module may have required_providers configured in files loaded later than
// resources. These provider settings should still be reflected in the
// resources' configuration.
func TestModule_required_providers_after_resource(t *testing.T) {
mod, diags := testModuleFromDir("testdata/valid-modules/required-providers-after-resource")
if diags.HasErrors() {
t.Fatal(diags.Error())
}
want := addrs.NewProvider(addrs.DefaultRegistryHost, "foo", "test")
req, exists := mod.ProviderRequirements.RequiredProviders["test"]
if !exists {
t.Fatal("no provider requirements found for \"test\"")
}
if req.Type != want {
t.Errorf("wrong provider addr for \"test\"\ngot: %s\nwant: %s",
req.Type, want,
)
}
if got := mod.ManagedResources["test_instance.my-instance"].Provider; !got.Equals(want) {
t.Errorf("wrong provider addr for \"test_instance.my-instance\"\ngot: %s\nwant: %s",
got, want,
)
}
}
// We support overrides for required_providers blocks, which should replace the
// entire block for each provider localname, leaving other blocks unaffected.
// This should also be reflected in any resources in the module using this
// provider.
func TestModule_required_provider_overrides(t *testing.T) {
mod, diags := testModuleFromDir("testdata/valid-modules/required-providers-overrides")
if diags.HasErrors() {
t.Fatal(diags.Error())
}
// The foo provider and resource should be unaffected
want := addrs.NewProvider(addrs.DefaultRegistryHost, "acme", "foo")
req, exists := mod.ProviderRequirements.RequiredProviders["foo"]
if !exists {
t.Fatal("no provider requirements found for \"foo\"")
}
if req.Type != want {
t.Errorf("wrong provider addr for \"foo\"\ngot: %s\nwant: %s",
req.Type, want,
)
}
if got := mod.ManagedResources["foo_thing.ft"].Provider; !got.Equals(want) {
t.Errorf("wrong provider addr for \"foo_thing.ft\"\ngot: %s\nwant: %s",
got, want,
)
}
// The bar provider and resource should be using the override config
want = addrs.NewProvider(addrs.DefaultRegistryHost, "blorp", "bar")
req, exists = mod.ProviderRequirements.RequiredProviders["bar"]
if !exists {
t.Fatal("no provider requirements found for \"bar\"")
}
if req.Type != want {
t.Errorf("wrong provider addr for \"bar\"\ngot: %s\nwant: %s",
req.Type, want,
)
}
if gotVer, wantVer := req.Requirement.Required.String(), "~>2.0.0"; gotVer != wantVer {
t.Errorf("wrong provider version constraint for \"bar\"\ngot: %s\nwant: %s",
gotVer, wantVer,
)
}
if got := mod.ManagedResources["bar_thing.bt"].Provider; !got.Equals(want) {
t.Errorf("wrong provider addr for \"bar_thing.bt\"\ngot: %s\nwant: %s",
got, want,
)
}
}

View File

@ -75,7 +75,7 @@ func (p *Parser) loadConfigFile(path string, override bool) (*File, hcl.Diagnost
case "required_providers": case "required_providers":
reqs, reqsDiags := decodeRequiredProvidersBlock(innerBlock) reqs, reqsDiags := decodeRequiredProvidersBlock(innerBlock)
diags = append(diags, reqsDiags...) diags = append(diags, reqsDiags...)
file.RequiredProviders = append(file.RequiredProviders, reqs...) file.RequiredProviders = append(file.RequiredProviders, reqs)
case "provider_meta": case "provider_meta":
providerCfg, cfgDiags := decodeProviderMetaBlock(innerBlock) providerCfg, cfgDiags := decodeProviderMetaBlock(innerBlock)

View File

@ -12,41 +12,40 @@ import (
// parent. // parent.
type RequiredProvider struct { type RequiredProvider struct {
Name string Name string
Source Source Type addrs.Provider
Requirement VersionConstraint Requirement VersionConstraint
}
type Source struct {
SourceStr string
DeclRange hcl.Range DeclRange hcl.Range
} }
// ProviderRequirements represents provider version constraints from type RequiredProviders struct {
// required_providers blocks. RequiredProviders map[string]*RequiredProvider
type ProviderRequirements struct { DeclRange hcl.Range
Type addrs.Provider
VersionConstraints []VersionConstraint
} }
func decodeRequiredProvidersBlock(block *hcl.Block) ([]*RequiredProvider, hcl.Diagnostics) { func decodeRequiredProvidersBlock(block *hcl.Block) (*RequiredProviders, hcl.Diagnostics) {
attrs, diags := block.Body.JustAttributes() attrs, diags := block.Body.JustAttributes()
var reqs []*RequiredProvider ret := &RequiredProviders{
RequiredProviders: make(map[string]*RequiredProvider),
DeclRange: block.DefRange,
}
for name, attr := range attrs { for name, attr := range attrs {
expr, err := attr.Expr.Value(nil) expr, err := attr.Expr.Value(nil)
if err != nil { if err != nil {
diags = append(diags, err...) diags = append(diags, err...)
} }
rp := &RequiredProvider{
Name: name,
DeclRange: attr.Expr.Range(),
}
switch { switch {
case expr.Type().IsPrimitiveType(): case expr.Type().IsPrimitiveType():
vc, reqDiags := decodeVersionConstraint(attr) vc, reqDiags := decodeVersionConstraint(attr)
diags = append(diags, reqDiags...) diags = append(diags, reqDiags...)
reqs = append(reqs, &RequiredProvider{ rp.Requirement = vc
Name: name,
Requirement: vc,
})
case expr.Type().IsObjectType(): case expr.Type().IsObjectType():
ret := &RequiredProvider{Name: name}
if expr.Type().HasAttribute("version") { if expr.Type().HasAttribute("version") {
vc := VersionConstraint{ vc := VersionConstraint{
DeclRange: attr.Range, DeclRange: attr.Range,
@ -64,25 +63,55 @@ func decodeRequiredProvidersBlock(block *hcl.Block) ([]*RequiredProvider, hcl.Di
}) })
} else { } else {
vc.Required = constraints vc.Required = constraints
ret.Requirement = vc rp.Requirement = vc
} }
} }
if expr.Type().HasAttribute("source") { if expr.Type().HasAttribute("source") {
ret.Source.SourceStr = expr.GetAttr("source").AsString() fqn, sourceDiags := addrs.ParseProviderSourceString(expr.GetAttr("source").AsString())
ret.Source.DeclRange = attr.Range
if sourceDiags.HasErrors() {
hclDiags := sourceDiags.ToHCL()
// The diagnostics from ParseProviderSourceString don't contain
// source location information because it has no context to compute
// them from, and so we'll add those in quickly here before we
// return.
for _, diag := range hclDiags {
if diag.Subject == nil {
diag.Subject = attr.Expr.Range().Ptr()
} }
reqs = append(reqs, ret) }
diags = append(diags, hclDiags...)
} else {
rp.Type = fqn
}
}
default: default:
// should not happen // should not happen
diags = append(diags, &hcl.Diagnostic{ diags = append(diags, &hcl.Diagnostic{
Severity: hcl.DiagError, Severity: hcl.DiagError,
Summary: "Invalid provider_requirements syntax", Summary: "Invalid required_providers syntax",
Detail: "provider_requirements entries must be strings or objects.", Detail: "required_providers entries must be strings or objects.",
Subject: attr.Expr.Range().Ptr(), Subject: attr.Expr.Range().Ptr(),
}) })
reqs = append(reqs, &RequiredProvider{Name: name}) }
return reqs, diags
if rp.Type.IsZero() {
pType, err := addrs.ParseProviderPart(rp.Name)
if err != nil {
diags = append(diags, &hcl.Diagnostic{
Severity: hcl.DiagError,
Summary: "Invalid provider name",
Detail: err.Error(),
Subject: attr.Expr.Range().Ptr(),
})
} else {
rp.Type = addrs.ImpliedProviderForUnqualifiedType(pType)
} }
} }
return reqs, diags
ret.RequiredProviders[rp.Name] = rp
}
return ret, diags
} }

View File

@ -1,8 +1,6 @@
package configs package configs
import ( import (
"fmt"
"sort"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
@ -10,6 +8,7 @@ import (
version "github.com/hashicorp/go-version" version "github.com/hashicorp/go-version"
"github.com/hashicorp/hcl/v2" "github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/hcltest" "github.com/hashicorp/hcl/v2/hcltest"
"github.com/hashicorp/terraform/addrs"
"github.com/zclconf/go-cty/cty" "github.com/zclconf/go-cty/cty"
) )
@ -19,18 +18,35 @@ var (
if x.Name != y.Name { if x.Name != y.Name {
return false return false
} }
if x.Source != y.Source { if x.Type != y.Type {
return false return false
} }
if x.Requirement.Required.String() != y.Requirement.Required.String() { if x.Requirement.Required.String() != y.Requirement.Required.String() {
return false return false
} }
if x.DeclRange != y.DeclRange {
return false
}
return true return true
}) })
blockRange = hcl.Range{
Filename: "mock.tf",
Start: hcl.Pos{Line: 3, Column: 12, Byte: 27},
End: hcl.Pos{Line: 3, Column: 19, Byte: 34},
}
mockRange = hcl.Range{
Filename: "MockExprLiteral",
}
) )
func TestDecodeRequiredProvidersBlock_legacy(t *testing.T) { func TestDecodeRequiredProvidersBlock(t *testing.T) {
block := &hcl.Block{ tests := map[string]struct {
Block *hcl.Block
Want *RequiredProviders
Error string
}{
"legacy": {
Block: &hcl.Block{
Type: "required_providers", Type: "required_providers",
Body: hcltest.MockBody(&hcl.BodyContent{ Body: hcltest.MockBody(&hcl.BodyContent{
Attributes: hcl.Attributes{ Attributes: hcl.Attributes{
@ -40,33 +56,22 @@ func TestDecodeRequiredProvidersBlock_legacy(t *testing.T) {
}, },
}, },
}), }),
} DefRange: blockRange,
},
want := &RequiredProvider{ Want: &RequiredProviders{
RequiredProviders: map[string]*RequiredProvider{
"default": {
Name: "default", Name: "default",
Type: addrs.NewDefaultProvider("default"),
Requirement: testVC("1.0.0"), Requirement: testVC("1.0.0"),
} DeclRange: mockRange,
},
got, diags := decodeRequiredProvidersBlock(block) },
if diags.HasErrors() { DeclRange: blockRange,
t.Fatalf("unexpected error") },
} },
if len(got) != 1 { "provider source": {
t.Fatalf("wrong number of results, got %d, wanted 1", len(got)) Block: &hcl.Block{
}
if !cmp.Equal(got[0], want, ignoreUnexported, comparer) {
t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], want, ignoreUnexported, comparer))
}
}
func TestDecodeRequiredProvidersBlock_provider_source(t *testing.T) {
mockRange := hcl.Range{
Filename: "mock.tf",
Start: hcl.Pos{Line: 3, Column: 12, Byte: 27},
End: hcl.Pos{Line: 3, Column: 19, Byte: 34},
}
block := &hcl.Block{
Type: "required_providers", Type: "required_providers",
Body: hcltest.MockBody(&hcl.BodyContent{ Body: hcltest.MockBody(&hcl.BodyContent{
Attributes: hcl.Attributes{ Attributes: hcl.Attributes{
@ -76,31 +81,25 @@ func TestDecodeRequiredProvidersBlock_provider_source(t *testing.T) {
"source": cty.StringVal("mycloud/test"), "source": cty.StringVal("mycloud/test"),
"version": cty.StringVal("2.0.0"), "version": cty.StringVal("2.0.0"),
})), })),
Range: mockRange,
}, },
}, },
}), }),
} DefRange: blockRange,
},
want := &RequiredProvider{ Want: &RequiredProviders{
RequiredProviders: map[string]*RequiredProvider{
"my_test": {
Name: "my_test", Name: "my_test",
Source: Source{SourceStr: "mycloud/test", DeclRange: mockRange}, Type: addrs.NewProvider(addrs.DefaultRegistryHost, "mycloud", "test"),
Requirement: testVC("2.0.0"), Requirement: testVC("2.0.0"),
} DeclRange: mockRange,
got, diags := decodeRequiredProvidersBlock(block) },
if diags.HasErrors() { },
t.Fatalf("unexpected error") DeclRange: blockRange,
} },
if len(got) != 1 { },
t.Fatalf("wrong number of results, got %d, wanted 1", len(got)) "mixed": {
} Block: &hcl.Block{
if !cmp.Equal(got[0], want, ignoreUnexported, comparer) {
t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], want, ignoreUnexported, comparer))
}
}
func TestDecodeRequiredProvidersBlock_mixed(t *testing.T) {
block := &hcl.Block{
Type: "required_providers", Type: "required_providers",
Body: hcltest.MockBody(&hcl.BodyContent{ Body: hcltest.MockBody(&hcl.BodyContent{
Attributes: hcl.Attributes{ Attributes: hcl.Attributes{
@ -117,41 +116,112 @@ func TestDecodeRequiredProvidersBlock_mixed(t *testing.T) {
}, },
}, },
}), }),
} DefRange: blockRange,
},
want := []*RequiredProvider{ Want: &RequiredProviders{
{ RequiredProviders: map[string]*RequiredProvider{
"legacy": {
Name: "legacy", Name: "legacy",
Type: addrs.NewDefaultProvider("legacy"),
Requirement: testVC("1.0.0"), Requirement: testVC("1.0.0"),
DeclRange: mockRange,
}, },
{ "my_test": {
Name: "my_test", Name: "my_test",
Source: Source{SourceStr: "mycloud/test", DeclRange: hcl.Range{}}, Type: addrs.NewProvider(addrs.DefaultRegistryHost, "mycloud", "test"),
Requirement: testVC("2.0.0"), Requirement: testVC("2.0.0"),
DeclRange: mockRange,
}, },
} },
DeclRange: blockRange,
got, diags := decodeRequiredProvidersBlock(block) },
},
sort.SliceStable(got, func(i, j int) bool { "version-only block": {
return got[i].Name < got[j].Name Block: &hcl.Block{
}) Type: "required_providers",
Body: hcltest.MockBody(&hcl.BodyContent{
if diags.HasErrors() { Attributes: hcl.Attributes{
t.Fatalf("unexpected error") "test": {
} Name: "test",
if len(got) != 2 { Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{
t.Fatalf("wrong number of results, got %d, wanted 2", len(got)) "version": cty.StringVal("~>2.0.0"),
} })),
for i, rp := range want { },
if !cmp.Equal(got[i], rp, ignoreUnexported, comparer) { },
t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], rp, ignoreUnexported, comparer)) }),
} DefRange: blockRange,
} },
} Want: &RequiredProviders{
RequiredProviders: map[string]*RequiredProvider{
func TestDecodeRequiredProvidersBlock_version_error(t *testing.T) { "test": {
block := &hcl.Block{ Name: "test",
Type: addrs.NewDefaultProvider("test"),
Requirement: testVC("~>2.0.0"),
DeclRange: mockRange,
},
},
DeclRange: blockRange,
},
},
"invalid source": {
Block: &hcl.Block{
Type: "required_providers",
Body: hcltest.MockBody(&hcl.BodyContent{
Attributes: hcl.Attributes{
"my_test": {
Name: "my_test",
Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{
"source": cty.StringVal("some/invalid/provider/source/test"),
"version": cty.StringVal("~>2.0.0"),
})),
},
},
}),
DefRange: blockRange,
},
Want: &RequiredProviders{
RequiredProviders: map[string]*RequiredProvider{
"my_test": {
Name: "my_test",
Type: addrs.Provider{},
Requirement: testVC("~>2.0.0"),
DeclRange: mockRange,
},
},
DeclRange: blockRange,
},
Error: "Invalid provider source string",
},
"localname is invalid provider name": {
Block: &hcl.Block{
Type: "required_providers",
Body: hcltest.MockBody(&hcl.BodyContent{
Attributes: hcl.Attributes{
"my_test": {
Name: "my_test",
Expr: hcltest.MockExprLiteral(cty.ObjectVal(map[string]cty.Value{
"version": cty.StringVal("~>2.0.0"),
})),
},
},
}),
DefRange: blockRange,
},
Want: &RequiredProviders{
RequiredProviders: map[string]*RequiredProvider{
"my_test": {
Name: "my_test",
Type: addrs.Provider{},
Requirement: testVC("~>2.0.0"),
DeclRange: mockRange,
},
},
DeclRange: blockRange,
},
Error: "Invalid provider name",
},
"version constraint error": {
Block: &hcl.Block{
Type: "required_providers", Type: "required_providers",
Body: hcltest.MockBody(&hcl.BodyContent{ Body: hcltest.MockBody(&hcl.BodyContent{
Attributes: hcl.Attributes{ Attributes: hcl.Attributes{
@ -164,28 +234,65 @@ func TestDecodeRequiredProvidersBlock_version_error(t *testing.T) {
}, },
}, },
}), }),
} DefRange: blockRange,
},
want := []*RequiredProvider{ Want: &RequiredProviders{
{ RequiredProviders: map[string]*RequiredProvider{
"my_test": {
Name: "my_test", Name: "my_test",
Source: Source{SourceStr: "mycloud/test", DeclRange: hcl.Range{}}, Type: addrs.NewProvider(addrs.DefaultRegistryHost, "mycloud", "test"),
DeclRange: mockRange,
},
},
DeclRange: blockRange,
},
Error: "Invalid version constraint",
},
"invalid required_providers attribute value": {
Block: &hcl.Block{
Type: "required_providers",
Body: hcltest.MockBody(&hcl.BodyContent{
Attributes: hcl.Attributes{
"test": {
Name: "test",
Expr: hcltest.MockExprLiteral(cty.ListVal([]cty.Value{cty.StringVal("2.0.0")})),
},
},
}),
DefRange: blockRange,
},
Want: &RequiredProviders{
RequiredProviders: map[string]*RequiredProvider{
"test": {
Name: "test",
Type: addrs.NewDefaultProvider("test"),
DeclRange: mockRange,
},
},
DeclRange: blockRange,
},
Error: "Invalid required_providers syntax",
}, },
} }
got, diags := decodeRequiredProvidersBlock(block) for name, test := range tests {
if !diags.HasErrors() { t.Run(name, func(t *testing.T) {
t.Fatalf("expected error, got success") got, diags := decodeRequiredProvidersBlock(test.Block)
} else { if diags.HasErrors() {
fmt.Printf(diags[0].Summary) if test.Error == "" {
t.Fatalf("unexpected error")
} }
if len(got) != 1 { if gotErr := diags[0].Summary; gotErr != test.Error {
t.Fatalf("wrong number of results, got %d, wanted 1", len(got)) t.Errorf("wrong error, got %q, want %q", gotErr, test.Error)
} }
for i, rp := range want { } else if test.Error != "" {
if !cmp.Equal(got[i], rp, ignoreUnexported, comparer) { t.Fatalf("expected error")
t.Fatalf("wrong result:\n %s", cmp.Diff(got[0], rp, ignoreUnexported, comparer))
} }
if !cmp.Equal(got, test.Want, ignoreUnexported, comparer) {
t.Fatalf("wrong result:\n %s", cmp.Diff(got, test.Want, ignoreUnexported, comparer))
}
})
} }
} }

View File

@ -0,0 +1,7 @@
terraform {
required_providers {
bar = {
version = "~>1.0.0"
}
}
}

View File

@ -0,0 +1,7 @@
terraform {
required_providers {
foo = {
version = "~>2.0.0"
}
}
}

View File

@ -0,0 +1,3 @@
resource test_instance "my-instance" {
provider = test
}

View File

@ -0,0 +1,8 @@
terraform {
required_providers {
test = {
source = "foo/test"
version = "~>1.0.0"
}
}
}

View File

@ -0,0 +1,9 @@
terraform {
required_providers {
bar = {
source = "blorp/bar"
version = "~>2.0.0"
}
}
}

View File

@ -0,0 +1,7 @@
resource bar_thing "bt" {
provider = bar
}
resource foo_thing "ft" {
provider = foo
}

View File

@ -0,0 +1,11 @@
terraform {
required_providers {
bar = {
source = "acme/bar"
}
foo = {
source = "acme/foo"
}
}
}