diff --git a/internal/terraform/context_apply_test.go b/internal/terraform/context_apply_test.go index 06ff0de10..09babdbe0 100644 --- a/internal/terraform/context_apply_test.go +++ b/internal/terraform/context_apply_test.go @@ -577,14 +577,18 @@ func TestContext2Apply_refCount(t *testing.T) { func TestContext2Apply_providerAlias(t *testing.T) { m := testModule(t, "apply-provider-alias") - p := testProvider("aws") - p.PlanResourceChangeFn = testDiffFn - p.ApplyResourceChangeFn = testApplyFn + + // Each provider instance must be completely independent to ensure that we + // are verifying the correct state of each. + p := func() (providers.Interface, error) { + p := testProvider("aws") + p.PlanResourceChangeFn = testDiffFn + p.ApplyResourceChangeFn = testApplyFn + return p, nil + } ctx := testContext2(t, &ContextOpts{ Providers: map[addrs.Provider]providers.Factory{ - addrs.NewDefaultProvider("aws"): func() (providers.Interface, error) { - return p, nil - }, + addrs.NewDefaultProvider("aws"): p, }, }) @@ -612,15 +616,18 @@ func TestContext2Apply_providerAlias(t *testing.T) { func TestContext2Apply_providerAliasConfigure(t *testing.T) { m := testModule(t, "apply-provider-alias-configure") - p2 := testProvider("another") - p2.ApplyResourceChangeFn = testApplyFn - p2.PlanResourceChangeFn = testDiffFn + // Each provider instance must be completely independent to ensure that we + // are verifying the correct state of each. + p := func() (providers.Interface, error) { + p := testProvider("another") + p.ApplyResourceChangeFn = testApplyFn + p.PlanResourceChangeFn = testDiffFn + return p, nil + } ctx := testContext2(t, &ContextOpts{ Providers: map[addrs.Provider]providers.Factory{ - addrs.NewDefaultProvider("another"): func() (providers.Interface, error) { - return p2, nil - }, + addrs.NewDefaultProvider("another"): p, }, }) @@ -633,17 +640,29 @@ func TestContext2Apply_providerAliasConfigure(t *testing.T) { // Configure to record calls AFTER Plan above var configCount int32 - p2.ConfigureProviderFn = func(req providers.ConfigureProviderRequest) (resp providers.ConfigureProviderResponse) { - atomic.AddInt32(&configCount, 1) + p = func() (providers.Interface, error) { + p := testProvider("another") + p.ApplyResourceChangeFn = testApplyFn + p.PlanResourceChangeFn = testDiffFn + p.ConfigureProviderFn = func(req providers.ConfigureProviderRequest) (resp providers.ConfigureProviderResponse) { + atomic.AddInt32(&configCount, 1) - foo := req.Config.GetAttr("foo").AsString() - if foo != "bar" { - resp.Diagnostics = resp.Diagnostics.Append(fmt.Errorf("foo: %#v", foo)) + foo := req.Config.GetAttr("foo").AsString() + if foo != "bar" { + resp.Diagnostics = resp.Diagnostics.Append(fmt.Errorf("foo: %#v", foo)) + } + + return } - - return + return p, nil } + ctx = testContext2(t, &ContextOpts{ + Providers: map[addrs.Provider]providers.Factory{ + addrs.NewDefaultProvider("another"): p, + }, + }) + state, diags := ctx.Apply(plan, m) if diags.HasErrors() { t.Fatalf("diags: %s", diags.Err()) @@ -1544,8 +1563,11 @@ func TestContext2Apply_destroySkipsCBD(t *testing.T) { func TestContext2Apply_destroyModuleVarProviderConfig(t *testing.T) { m := testModule(t, "apply-destroy-mod-var-provider-config") - p := testProvider("aws") - p.PlanResourceChangeFn = testDiffFn + p := func() (providers.Interface, error) { + p := testProvider("aws") + p.PlanResourceChangeFn = testDiffFn + return p, nil + } state := states.NewState() root := state.EnsureModule(addrs.RootModuleInstance) root.SetResourceInstanceCurrent( @@ -1558,9 +1580,7 @@ func TestContext2Apply_destroyModuleVarProviderConfig(t *testing.T) { ) ctx := testContext2(t, &ContextOpts{ Providers: map[addrs.Provider]providers.Factory{ - addrs.NewDefaultProvider("aws"): func() (providers.Interface, error) { - return p, nil - }, + addrs.NewDefaultProvider("aws"): p, }, }) diff --git a/internal/terraform/context_input_test.go b/internal/terraform/context_input_test.go index a9ec11b38..5216efb59 100644 --- a/internal/terraform/context_input_test.go +++ b/internal/terraform/context_input_test.go @@ -85,8 +85,7 @@ func TestContext2Input_provider(t *testing.T) { func TestContext2Input_providerMulti(t *testing.T) { m := testModule(t, "input-provider-multi") - p := testProvider("aws") - p.GetProviderSchemaResponse = getProviderSchemaResponseFromProviderSchema(&ProviderSchema{ + getProviderSchemaResponse := getProviderSchemaResponseFromProviderSchema(&ProviderSchema{ Provider: &configschema.Block{ Attributes: map[string]*configschema.Attribute{ "foo": { @@ -108,6 +107,17 @@ func TestContext2Input_providerMulti(t *testing.T) { }, }) + // In order to update the provider to check only the configure calls during + // apply, we will need to inject a new factory function after plan. We must + // use a closure around the factory, because in order for the inputs to + // work during apply we need to maintain the same context value, preventing + // us from assigning a new Providers map. + providerFactory := func() (providers.Interface, error) { + p := testProvider("aws") + p.GetProviderSchemaResponse = getProviderSchemaResponse + return p, nil + } + inp := &MockUIInput{ InputReturnMap: map[string]string{ "provider.aws.foo": "bar", @@ -118,7 +128,7 @@ func TestContext2Input_providerMulti(t *testing.T) { ctx := testContext2(t, &ContextOpts{ Providers: map[addrs.Provider]providers.Factory{ addrs.NewDefaultProvider("aws"): func() (providers.Interface, error) { - return p, nil + return providerFactory() }, }, UIInput: inp, @@ -134,12 +144,18 @@ func TestContext2Input_providerMulti(t *testing.T) { plan, diags := ctx.Plan(m, states.NewState(), DefaultPlanOpts) assertNoErrors(t, diags) - p.ConfigureProviderFn = func(req providers.ConfigureProviderRequest) (resp providers.ConfigureProviderResponse) { - lock.Lock() - defer lock.Unlock() - actual = append(actual, req.Config.GetAttr("foo").AsString()) - return + providerFactory = func() (providers.Interface, error) { + p := testProvider("aws") + p.GetProviderSchemaResponse = getProviderSchemaResponse + p.ConfigureProviderFn = func(req providers.ConfigureProviderRequest) (resp providers.ConfigureProviderResponse) { + lock.Lock() + defer lock.Unlock() + actual = append(actual, req.Config.GetAttr("foo").AsString()) + return + } + return p, nil } + if _, diags := ctx.Apply(plan, m); diags.HasErrors() { t.Fatalf("apply errors: %s", diags.Err()) }