From d1460d8c824fedfd9ce8e933a4b601717559f637 Mon Sep 17 00:00:00 2001 From: James Bardin Date: Mon, 3 Apr 2017 11:00:45 -0400 Subject: [PATCH] test LockWithContext --- state/inmem.go | 52 +++++++++++++++++++++++++++++++++++++++++++++ state/inmem_test.go | 36 +++++++++++++++++++++++++++++++ state/state_test.go | 51 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 139 insertions(+) diff --git a/state/inmem.go b/state/inmem.go index a930f78c7..2bbfb3d44 100644 --- a/state/inmem.go +++ b/state/inmem.go @@ -1,6 +1,10 @@ package state import ( + "errors" + "sync" + "time" + "github.com/hashicorp/terraform/terraform" ) @@ -34,3 +38,51 @@ func (s *InmemState) Lock(*LockInfo) (string, error) { func (s *InmemState) Unlock(string) error { return nil } + +// inmemLocker is an in-memory State implementation for testing locks. +type inmemLocker struct { + *InmemState + + mu sync.Mutex + lockInfo *LockInfo + // count the calls to Lock + lockCounter int +} + +func (s *inmemLocker) Lock(info *LockInfo) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.lockCounter++ + + lockErr := &LockError{ + Info: &LockInfo{}, + } + + if s.lockInfo != nil { + lockErr.Err = errors.New("state locked") + *lockErr.Info = *s.lockInfo + return "", lockErr + } + + info.Created = time.Now().UTC() + s.lockInfo = info + return s.lockInfo.ID, nil +} + +func (s *inmemLocker) Unlock(id string) error { + s.mu.Lock() + defer s.mu.Unlock() + + lockErr := &LockError{ + Info: &LockInfo{}, + } + + if id != s.lockInfo.ID { + lockErr.Err = errors.New("invalid lock id") + *lockErr.Info = *s.lockInfo + return lockErr + } + + s.lockInfo = nil + return nil +} diff --git a/state/inmem_test.go b/state/inmem_test.go index 885127122..6ca8a69a5 100644 --- a/state/inmem_test.go +++ b/state/inmem_test.go @@ -14,3 +14,39 @@ func TestInmemState_impl(t *testing.T) { var _ StatePersister = new(InmemState) var _ StateRefresher = new(InmemState) } + +func TestInmemLocker(t *testing.T) { + inmem := &InmemState{state: TestStateInitial()} + // test that it correctly wraps the inmem state + s := &inmemLocker{InmemState: inmem} + TestState(t, s) + + info := NewLockInfo() + + id, err := s.Lock(info) + if err != nil { + t.Fatal(err) + } + + if id == "" { + t.Fatal("no lock id from state lock") + } + + // locking again should fail + _, err = s.Lock(NewLockInfo()) + if err == nil { + t.Fatal("state locked while locked") + } + + if err.(*LockError).Info.ID != id { + t.Fatal("wrong lock id from lock failure") + } + + if err := s.Unlock(id); err != nil { + t.Fatal(err) + } + + if _, err := s.Lock(NewLockInfo()); err != nil { + t.Fatal(err) + } +} diff --git a/state/state_test.go b/state/state_test.go index e93f5680a..df7a6fd05 100644 --- a/state/state_test.go +++ b/state/state_test.go @@ -1,12 +1,14 @@ package state import ( + "context" "encoding/json" "flag" "io/ioutil" "log" "os" "testing" + "time" "github.com/hashicorp/terraform/helper/logging" ) @@ -50,3 +52,52 @@ func TestNewLockInfo(t *testing.T) { t.Fatal(err) } } + +func TestLockWithContext(t *testing.T) { + inmem := &InmemState{state: TestStateInitial()} + // test that it correctly wraps the inmem state + s := &inmemLocker{InmemState: inmem} + + id, err := s.Lock(NewLockInfo()) + if err != nil { + t.Fatal(err) + } + + // use a cancelled context for an immediate timeout + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + info := NewLockInfo() + info.Info = "lock with context" + _, err = LockWithContext(ctx, s, info) + if err == nil { + t.Fatal("lock should have failed immediately") + } + + // unlock the state during LockWithContext + unlocked := make(chan struct{}) + go func() { + defer close(unlocked) + time.Sleep(500 * time.Millisecond) + if err := s.Unlock(id); err != nil { + t.Fatal(err) + } + }() + + ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + id, err = LockWithContext(ctx, s, info) + if err != nil { + t.Fatal("lock should have completed within 2s:", err) + } + + // ensure the goruotine completes + <-unlocked + + // Lock should have been called a total of 4 times. + // 1 initial lock, 1 failure, 1 failure + 1 retry + if s.lockCounter != 4 { + t.Fatalf("lock only called %d times", s.lockCounter) + } +}