test LockWithContext
This commit is contained in:
parent
93b1dd6323
commit
d1460d8c82
|
@ -1,6 +1,10 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/terraform/terraform"
|
||||
)
|
||||
|
||||
|
@ -34,3 +38,51 @@ func (s *InmemState) Lock(*LockInfo) (string, error) {
|
|||
func (s *InmemState) Unlock(string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// inmemLocker is an in-memory State implementation for testing locks.
|
||||
type inmemLocker struct {
|
||||
*InmemState
|
||||
|
||||
mu sync.Mutex
|
||||
lockInfo *LockInfo
|
||||
// count the calls to Lock
|
||||
lockCounter int
|
||||
}
|
||||
|
||||
func (s *inmemLocker) Lock(info *LockInfo) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.lockCounter++
|
||||
|
||||
lockErr := &LockError{
|
||||
Info: &LockInfo{},
|
||||
}
|
||||
|
||||
if s.lockInfo != nil {
|
||||
lockErr.Err = errors.New("state locked")
|
||||
*lockErr.Info = *s.lockInfo
|
||||
return "", lockErr
|
||||
}
|
||||
|
||||
info.Created = time.Now().UTC()
|
||||
s.lockInfo = info
|
||||
return s.lockInfo.ID, nil
|
||||
}
|
||||
|
||||
func (s *inmemLocker) Unlock(id string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
lockErr := &LockError{
|
||||
Info: &LockInfo{},
|
||||
}
|
||||
|
||||
if id != s.lockInfo.ID {
|
||||
lockErr.Err = errors.New("invalid lock id")
|
||||
*lockErr.Info = *s.lockInfo
|
||||
return lockErr
|
||||
}
|
||||
|
||||
s.lockInfo = nil
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -14,3 +14,39 @@ func TestInmemState_impl(t *testing.T) {
|
|||
var _ StatePersister = new(InmemState)
|
||||
var _ StateRefresher = new(InmemState)
|
||||
}
|
||||
|
||||
func TestInmemLocker(t *testing.T) {
|
||||
inmem := &InmemState{state: TestStateInitial()}
|
||||
// test that it correctly wraps the inmem state
|
||||
s := &inmemLocker{InmemState: inmem}
|
||||
TestState(t, s)
|
||||
|
||||
info := NewLockInfo()
|
||||
|
||||
id, err := s.Lock(info)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if id == "" {
|
||||
t.Fatal("no lock id from state lock")
|
||||
}
|
||||
|
||||
// locking again should fail
|
||||
_, err = s.Lock(NewLockInfo())
|
||||
if err == nil {
|
||||
t.Fatal("state locked while locked")
|
||||
}
|
||||
|
||||
if err.(*LockError).Info.ID != id {
|
||||
t.Fatal("wrong lock id from lock failure")
|
||||
}
|
||||
|
||||
if err := s.Unlock(id); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := s.Lock(NewLockInfo()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/terraform/helper/logging"
|
||||
)
|
||||
|
@ -50,3 +52,52 @@ func TestNewLockInfo(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLockWithContext(t *testing.T) {
|
||||
inmem := &InmemState{state: TestStateInitial()}
|
||||
// test that it correctly wraps the inmem state
|
||||
s := &inmemLocker{InmemState: inmem}
|
||||
|
||||
id, err := s.Lock(NewLockInfo())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// use a cancelled context for an immediate timeout
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
info := NewLockInfo()
|
||||
info.Info = "lock with context"
|
||||
_, err = LockWithContext(ctx, s, info)
|
||||
if err == nil {
|
||||
t.Fatal("lock should have failed immediately")
|
||||
}
|
||||
|
||||
// unlock the state during LockWithContext
|
||||
unlocked := make(chan struct{})
|
||||
go func() {
|
||||
defer close(unlocked)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if err := s.Unlock(id); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
id, err = LockWithContext(ctx, s, info)
|
||||
if err != nil {
|
||||
t.Fatal("lock should have completed within 2s:", err)
|
||||
}
|
||||
|
||||
// ensure the goruotine completes
|
||||
<-unlocked
|
||||
|
||||
// Lock should have been called a total of 4 times.
|
||||
// 1 initial lock, 1 failure, 1 failure + 1 retry
|
||||
if s.lockCounter != 4 {
|
||||
t.Fatalf("lock only called %d times", s.lockCounter)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue