diff --git a/backend/local/backend.go b/backend/local/backend.go index 966f5e82d..afb1af9eb 100644 --- a/backend/local/backend.go +++ b/backend/local/backend.go @@ -82,6 +82,10 @@ type Local struct { schema *schema.Backend opLock sync.Mutex once sync.Once + + // workingDir is where the State* paths should be relative to. + // This is currently only used for tests. + workingDir string } func (b *Local) Input( @@ -140,7 +144,7 @@ func (b *Local) States() ([]string, string, error) { current = name } - entries, err := ioutil.ReadDir(DefaultEnvDir) + entries, err := ioutil.ReadDir(filepath.Join(b.workingDir, DefaultEnvDir)) // no error if there's no envs configured if os.IsNotExist(err) { return envs, current, nil @@ -182,14 +186,19 @@ func (b *Local) DeleteState(name string) error { return errors.New("cannot delete default state") } + _, current, err := b.States() + if err != nil { + return err + } + // if we're deleting the current state, we change back to the default - if name == b.currentState { + if name == current { if err := b.ChangeState(backend.DefaultStateName); err != nil { return err } } - return os.RemoveAll(filepath.Join(DefaultEnvDir, name)) + return os.RemoveAll(filepath.Join(b.workingDir, DefaultEnvDir, name)) } // Change to the named state, creating it if it doesn't exist. @@ -231,13 +240,13 @@ func (b *Local) ChangeState(name string) error { } } - err = os.MkdirAll(DefaultDataDir, 0755) + err = os.MkdirAll(filepath.Join(b.workingDir, DefaultDataDir), 0755) if err != nil { return err } err = ioutil.WriteFile( - filepath.Join(DefaultDataDir, DefaultEnvFile), + filepath.Join(b.workingDir, DefaultDataDir, DefaultEnvFile), []byte(name), 0644, ) @@ -412,7 +421,7 @@ func (b *Local) statePath() (string, error) { path := DefaultStateFilename if current != backend.DefaultStateName && current != "" { - path = filepath.Join(DefaultEnvDir, b.currentState, DefaultStateFilename) + path = filepath.Join(b.workingDir, DefaultEnvDir, b.currentState, DefaultStateFilename) } return path, nil } @@ -430,7 +439,7 @@ func (b *Local) createState(name string) error { } } - err = os.MkdirAll(filepath.Join(DefaultEnvDir, name), 0755) + err = os.MkdirAll(filepath.Join(b.workingDir, DefaultEnvDir, name), 0755) if err != nil { return err } @@ -442,7 +451,7 @@ func (b *Local) createState(name string) error { // configuration files. // If there are no configured environments, currentStateName returns "default" func (b *Local) currentStateName() (string, error) { - contents, err := ioutil.ReadFile(filepath.Join(DefaultDataDir, DefaultEnvFile)) + contents, err := ioutil.ReadFile(filepath.Join(b.workingDir, DefaultDataDir, DefaultEnvFile)) if os.IsNotExist(err) { return backend.DefaultStateName, nil } diff --git a/backend/local/backend_test.go b/backend/local/backend_test.go index c97a8e7b9..a72a0fcc0 100644 --- a/backend/local/backend_test.go +++ b/backend/local/backend_test.go @@ -1,8 +1,10 @@ package local import ( + "fmt" "io/ioutil" "os" + "path/filepath" "reflect" "strings" "testing" @@ -15,6 +17,7 @@ func TestLocal_impl(t *testing.T) { var _ backend.Enhanced = new(Local) var _ backend.Local = new(Local) var _ backend.CLI = new(Local) + var _ backend.MultiState = new(Local) } func checkState(t *testing.T, path, expected string) { @@ -53,7 +56,7 @@ func TestLocal_addAndRemoveStates(t *testing.T) { } if !reflect.DeepEqual(states, expectedStates) { - t.Fatal("expected []string{%q}, got %q", dflt, states) + t.Fatalf("expected []string{%q}, got %q", dflt, states) } expectedA := "test_A" @@ -62,6 +65,9 @@ func TestLocal_addAndRemoveStates(t *testing.T) { } states, current, err = b.States() + if err != nil { + t.Fatal(err) + } if current != expectedA { t.Fatalf("expected %q, got %q", expectedA, current) } @@ -77,6 +83,9 @@ func TestLocal_addAndRemoveStates(t *testing.T) { } states, current, err = b.States() + if err != nil { + t.Fatal(err) + } if current != expectedB { t.Fatalf("expected %q, got %q", expectedB, current) } @@ -91,6 +100,9 @@ func TestLocal_addAndRemoveStates(t *testing.T) { } states, current, err = b.States() + if err != nil { + t.Fatal(err) + } if current != expectedB { t.Fatalf("expected %q, got %q", dflt, current) } @@ -105,6 +117,9 @@ func TestLocal_addAndRemoveStates(t *testing.T) { } states, current, err = b.States() + if err != nil { + t.Fatal(err) + } if current != dflt { t.Fatalf("expected %q, got %q", dflt, current) } @@ -119,6 +134,101 @@ func TestLocal_addAndRemoveStates(t *testing.T) { } } +// verify the behavior with a backend that doesn't support multiple states +func TestLocal_noMultiStateBackend(t *testing.T) { + type noMultiState struct { + backend.Backend + } + + b := &Local{ + Backend: &noMultiState{}, + } + + _, _, err := b.States() + if err != ErrEnvNotSupported { + t.Fatal("backend does not support environments.", err) + } + + err = b.ChangeState("test") + if err != ErrEnvNotSupported { + t.Fatal("backend does not support environments.", err) + } + + err = b.ChangeState("test") + if err != ErrEnvNotSupported { + t.Fatal("backend does not support environments.", err) + } +} + +// verify that the MultiState methods are dispatched to the correct Backend. +func TestLocal_multiStateBackend(t *testing.T) { + defer testTmpDir(t)() + + dflt := backend.DefaultStateName + expectedStates := []string{dflt} + + // make a second tmp dir for the sub-Backend. + // we verify the corret backend was called by checking the paths. + tmp, err := ioutil.TempDir("", "tf") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmp) + + fmt.Println("second tmp:", tmp) + + b := &Local{ + Backend: &Local{ + workingDir: tmp, + }, + } + + testA := "test_A" + if err := b.ChangeState(testA); err != nil { + t.Fatal(err) + } + + states, current, err := b.States() + if err != nil { + t.Fatal(err) + } + if current != testA { + t.Fatalf("expected %q, got %q", testA, current) + } + + expectedStates = append(expectedStates, testA) + if !reflect.DeepEqual(states, expectedStates) { + t.Fatalf("expected %q, got %q", expectedStates, states) + } + + // verify that no environment paths were created for the top-level Backend + if _, err := os.Stat(DefaultDataDir); !os.IsNotExist(err) { + t.Fatal("remote state operations should not have written local files") + } + + if _, err := os.Stat(filepath.Join(DefaultEnvDir, testA)); !os.IsNotExist(err) { + t.Fatal("remote state operations should not have written local files") + } + + // remove the new state + if err := b.DeleteState(testA); err != nil { + t.Fatal(err) + } + + states, current, err = b.States() + if err != nil { + t.Fatal(err) + } + if current != dflt { + t.Fatalf("expected %q, got %q", dflt, current) + } + + if !reflect.DeepEqual(states, expectedStates[:1]) { + t.Fatalf("expected %q, got %q", expectedStates, states) + } + +} + // change into a tmp dir and return a deferable func to change back and cleanup func testTmpDir(t *testing.T) func() { tmp, err := ioutil.TempDir("", "tf")