diff --git a/state/cache.go b/state/cache.go index 8d562dea2..375b43e7a 100644 --- a/state/cache.go +++ b/state/cache.go @@ -1,6 +1,8 @@ package state import ( + "reflect" + "github.com/hashicorp/terraform/terraform" ) @@ -10,7 +12,8 @@ type CacheState struct { Cache CacheStateCache Durable CacheStateDurable - state *terraform.State + refreshResult CacheRefreshResult + state *terraform.State } // StateReader impl. @@ -26,6 +29,7 @@ func (s *CacheState) WriteState(state *terraform.State) error { return err } + s.state = state return s.Cache.PersistState() } @@ -38,9 +42,79 @@ func (s *CacheState) WriteState(state *terraform.State) error { // // StateRefresher impl. func (s *CacheState) RefreshState() error { + // Refresh the durable state + if err := s.Durable.RefreshState(); err != nil { + return err + } + + // Refresh the cached state + if err := s.Cache.RefreshState(); err != nil { + return err + } + + // Handle the matrix of cases that can happen when comparing these + // two states. + cached := s.Cache.State() + durable := s.Durable.State() + switch { + case cached == nil && durable == nil: + // Initialized + s.refreshResult = CacheRefreshInit + case cached != nil && durable == nil: + // Cache is newer than remote. Not a big deal, user can just + // persist to get correct state. + s.refreshResult = CacheRefreshLocalNewer + case cached == nil && durable != nil: + // Cache should be updated since the remote is set but cache isn't + s.refreshResult = CacheRefreshUpdateLocal + case durable.Serial < cached.Serial: + // Cache is newer than remote. Not a big deal, user can just + // persist to get correct state. + s.refreshResult = CacheRefreshLocalNewer + case durable.Serial > cached.Serial: + // Cache should be updated since the remote is newer + s.refreshResult = CacheRefreshUpdateLocal + case durable.Serial == cached.Serial: + // They're supposedly equal, verify. + if reflect.DeepEqual(cached, durable) { + // Hashes are the same, everything is great + s.refreshResult = CacheRefreshNoop + break + } + + // This is very bad. This means we have two state files that + // have the same serial but have a different hash. We can't + // reconcile this. The most likely cause is parallel apply + // operations. + s.refreshResult = CacheRefreshConflict + + // Return early so we don't updtae the state + return nil + default: + panic("unhandled cache refresh state") + } + + if s.refreshResult == CacheRefreshUpdateLocal { + if err := s.Cache.WriteState(durable); err != nil { + s.refreshResult = CacheRefreshNoop + return err + } + if err := s.Cache.PersistState(); err != nil { + s.refreshResult = CacheRefreshNoop + return err + } + } + + s.state = cached + return nil } +// RefreshResult returns the result of the last refresh. +func (s *CacheState) RefreshResult() CacheRefreshResult { + return s.refreshResult +} + // PersistState takes the local cache, assuming it is newer than the remote // state, and persists it to the durable storage. If you want to challenge the // assumption that the local state is the latest, call a RefreshState prior @@ -61,11 +135,57 @@ type CacheStateCache interface { StateReader StateWriter StatePersister + StateRefresher } // CacheStateDurable is the meta-interface that must be implemented for // the durable storage for CacheState. type CacheStateDurable interface { + StateReader StateWriter StatePersister + StateRefresher } + +// CacheRefreshResult is used to explain the result of the previous +// RefreshState for a CacheState. +type CacheRefreshResult int + +const ( + // CacheRefreshNoop indicates nothing has happened, + // but that does not indicate an error. Everything is + // just up to date. (Push/Pull) + CacheRefreshNoop CacheRefreshResult = iota + + // CacheRefreshInit indicates that there is no local or + // remote state, and that the state was initialized + CacheRefreshInit + + // CacheRefreshUpdateLocal indicates the local state + // was updated. (Pull) + CacheRefreshUpdateLocal + + // CacheRefreshUpdateRemote indicates the remote state + // was updated. (Push) + CacheRefreshUpdateRemote + + // CacheRefreshLocalNewer means the pull was a no-op + // because the local state is newer than that of the + // server. This means a Push should take place. (Pull) + CacheRefreshLocalNewer + + // CacheRefreshRemoteNewer means the push was a no-op + // because the remote state is newer than that of the + // local state. This means a Pull should take place. + // (Push) + CacheRefreshRemoteNewer + + // CacheRefreshConflict means that the push or pull + // was a no-op because there is a conflict. This means + // there are multiple state definitions at the same + // serial number with different contents. This requires + // an operator to intervene and resolve the conflict. + // Shame on the user for doing concurrent apply. + // (Push/Pull) + CacheRefreshConflict +) diff --git a/state/cache_test.go b/state/cache_test.go new file mode 100644 index 000000000..c99aeb826 --- /dev/null +++ b/state/cache_test.go @@ -0,0 +1,58 @@ +package state + +import ( + "os" + "reflect" + "testing" +) + +func TestCacheState(t *testing.T) { + cache := testLocalState(t) + durable := testLocalState(t) + defer os.Remove(cache.Path) + defer os.Remove(durable.Path) + + TestState(t, &CacheState{ + Cache: cache, + Durable: durable, + }) +} + +func TestCacheState_persistDurable(t *testing.T) { + cache := testLocalState(t) + durable := testLocalState(t) + defer os.Remove(cache.Path) + defer os.Remove(durable.Path) + + cs := &CacheState{ + Cache: cache, + Durable: durable, + } + + state := cache.State() + state.Modules = nil + if err := cs.WriteState(state); err != nil { + t.Fatalf("err: %s", err) + } + + if reflect.DeepEqual(cache.State(), durable.State()) { + t.Fatal("cache and durable should not be the same") + } + + if err := cs.PersistState(); err != nil { + t.Fatalf("err: %s", err) + } + + if !reflect.DeepEqual(cache.State(), durable.State()) { + t.Fatalf( + "cache and durable should be the same\n\n%#v\n\n%#v", + cache.State(), durable.State()) + } +} + +func TestCacheState_impl(t *testing.T) { + var _ StateReader = new(CacheState) + var _ StateWriter = new(CacheState) + var _ StatePersister = new(CacheState) + var _ StateRefresher = new(CacheState) +} diff --git a/state/local_test.go b/state/local_test.go index b77fced92..631c97948 100644 --- a/state/local_test.go +++ b/state/local_test.go @@ -9,21 +9,9 @@ import ( ) func TestLocalState(t *testing.T) { - f, err := ioutil.TempFile("", "tf") - if err != nil { - t.Fatalf("err: %s", err) - } - defer os.Remove(f.Name()) - - err = terraform.WriteState(TestStateInitial, f) - f.Close() - if err != nil { - t.Fatalf("err: %s", err) - } - - TestState(t, &LocalState{ - Path: f.Name(), - }) + ls := testLocalState(t) + defer os.Remove(ls.Path) + TestState(t, ls) } func TestLocalState_impl(t *testing.T) { @@ -32,3 +20,23 @@ func TestLocalState_impl(t *testing.T) { var _ StatePersister = new(LocalState) var _ StateRefresher = new(LocalState) } + +func testLocalState(t *testing.T) *LocalState { + f, err := ioutil.TempFile("", "tf") + if err != nil { + t.Fatalf("err: %s", err) + } + + err = terraform.WriteState(TestStateInitial(), f) + f.Close() + if err != nil { + t.Fatalf("err: %s", err) + } + + ls := &LocalState{Path: f.Name()} + if err := ls.RefreshState(); err != nil { + t.Fatalf("bad: %s", err) + } + + return ls +} diff --git a/state/remote/state_test.go b/state/remote/state_test.go index 18c76d449..08b51439b 100644 --- a/state/remote/state_test.go +++ b/state/remote/state_test.go @@ -8,7 +8,7 @@ import ( func TestState(t *testing.T) { s := &State{Client: new(InmemClient)} - s.WriteState(state.TestStateInitial) + s.WriteState(state.TestStateInitial()) if err := s.PersistState(); err != nil { t.Fatalf("err: %s", err) } diff --git a/state/testing.go b/state/testing.go index 5233d06f0..6b3f6b4ef 100644 --- a/state/testing.go +++ b/state/testing.go @@ -1,25 +1,13 @@ package state import ( + "bytes" "reflect" "testing" "github.com/hashicorp/terraform/terraform" ) -// TestStateInitial is the initial state that a State should have -// for TestState. -var TestStateInitial *terraform.State = &terraform.State{ - Modules: []*terraform.ModuleState{ - &terraform.ModuleState{ - Path: []string{"root", "child"}, - Outputs: map[string]string{ - "foo": "bar", - }, - }, - }, -} - // TestState is a helper for testing state implementations. It is expected // that the given implementation is pre-loaded with the TestStateInitial // state. @@ -37,11 +25,12 @@ func TestState(t *testing.T, s interface{}) { } // current will track our current state - current := TestStateInitial + current := TestStateInitial() + current.Serial++ // Check that the initial state is correct if !reflect.DeepEqual(reader.State(), current) { - t.Fatalf("not initial: %#v", reader.State()) + t.Fatalf("not initial: %#v\n\n%#v", reader.State(), current) } // Write a new state and verify that we have it @@ -58,7 +47,7 @@ func TestState(t *testing.T, s interface{}) { } if actual := reader.State(); !reflect.DeepEqual(actual, current) { - t.Fatalf("bad: %#v", actual) + t.Fatalf("bad: %#v\n\n%#v", actual, current) } } @@ -75,8 +64,30 @@ func TestState(t *testing.T, s interface{}) { } } - if actual := reader.State(); !reflect.DeepEqual(actual, current) { - t.Fatalf("bad: %#v", actual) + // Just set the serials the same... Then compare. + actual := reader.State() + actual.Serial = current.Serial + if !reflect.DeepEqual(actual, current) { + t.Fatalf("bad: %#v\n\n%#v", actual, current) } } } + +// TestStateInitial is the initial state that a State should have +// for TestState. +func TestStateInitial() *terraform.State { + initial := &terraform.State{ + Modules: []*terraform.ModuleState{ + &terraform.ModuleState{ + Path: []string{"root", "child"}, + Outputs: map[string]string{ + "foo": "bar", + }, + }, + }, + } + + var scratch bytes.Buffer + terraform.WriteState(initial, &scratch) + return initial +}