diff --git a/helper/resource/wait.go b/helper/resource/wait.go new file mode 100644 index 000000000..407ce5085 --- /dev/null +++ b/helper/resource/wait.go @@ -0,0 +1,100 @@ +package resource + +import ( + "errors" + "fmt" + "log" + "time" +) + +// StateRefreshFunc is a function type used for StateChangeConf that is +// responsible for refreshing the item being watched for a state change. +// +// It returns three results. `result` is any object that will be returned +// as the final object after waiting for state change. This allows you to +// return the final updated object, for example an EC2 instance after refreshing +// it. +// +// `state` is the latest state of that object. And `err` is any error that +// may have happened while refreshing the state. +type StateRefreshFunc func() (result interface{}, state string, err error) + +// StateChangeConf is the configuration struct used for `WaitForState`. +type StateChangeConf struct { + Pending []string // States that are "allowed" and will continue trying + Refresh StateRefreshFunc // Refreshes the current state + Target string // Target state + Timeout time.Duration // The amount of time to wait before timeout +} + +type waitResult struct { + obj interface{} + err error +} + +// WaitForState watches an object and waits for it to achieve the state +// specified in the configuration using the specified Refresh() func, +// waiting the number of seconds specified in the timeout configuration. +func (conf *StateChangeConf) WaitForState() (i interface{}, err error) { + log.Printf("Waiting for state to become: %s", conf.Target) + + notfoundTick := 0 + + result := make(chan waitResult, 1) + + go func() { + for { + var currentState string + i, currentState, err = conf.Refresh() + if err != nil { + result <- waitResult{nil, err} + return + } + + if i == nil { + // If we didn't find the resource, check if we have been + // not finding it for awhile, and if so, report an error. + notfoundTick += 1 + if notfoundTick > 20 { + result <- waitResult{nil, errors.New("couldn't find resource")} + return + } + } else { + // Reset the counter for when a resource isn't found + notfoundTick = 0 + + if currentState == conf.Target { + result <- waitResult{i, nil} + return + } + + found := false + for _, allowed := range conf.Pending { + if currentState == allowed { + found = true + break + } + } + + if !found { + result <- waitResult{nil, fmt.Errorf("unexpected state '%s', wanted target '%s'", currentState, conf.Target)} + return + } + } + } + + // Wait between refreshes + time.Sleep(2 * time.Second) + }() + + select { + case waitResult := <-result: + err := waitResult.err + i = waitResult.obj + return i, err + case <-time.After(conf.Timeout): + err := fmt.Errorf("timeout while waiting for state to become '%s'", conf.Target) + i = nil + return i, err + } +} diff --git a/helper/resource/wait_test.go b/helper/resource/wait_test.go new file mode 100644 index 000000000..964b1bb7c --- /dev/null +++ b/helper/resource/wait_test.go @@ -0,0 +1,74 @@ +package resource + +import ( + "errors" + "testing" + "time" +) + +type nullObject struct{} + +func FailedStateRefreshFunc() StateRefreshFunc { + return func() (interface{}, string, error) { + return nil, "", errors.New("failed") + } +} + +func TimeoutStateRefreshFunc() StateRefreshFunc { + return func() (interface{}, string, error) { + time.Sleep(100 * time.Second) + return nil, "", errors.New("failed") + } +} + +func SuccessfulStateRefreshFunc() StateRefreshFunc { + return func() (interface{}, string, error) { + return &nullObject{}, "running", nil + } +} + +func TestWaitForState_timeout(t *testing.T) { + conf := &StateChangeConf{ + Pending: []string{"pending", "incomplete"}, + Target: "running", + Refresh: TimeoutStateRefreshFunc(), + Timeout: 1 * time.Millisecond, + } + + _, err := conf.WaitForState() + + if err == nil && err.Error() != "timeout while waiting for state to become 'running'" { + t.Fatalf("err: %s", err) + } + +} + +func TestWaitForState_success(t *testing.T) { + conf := &StateChangeConf{ + Pending: []string{"pending", "incomplete"}, + Target: "running", + Refresh: SuccessfulStateRefreshFunc(), + Timeout: 200 * time.Second, + } + + _, err := conf.WaitForState() + + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestWaitForState_failure(t *testing.T) { + conf := &StateChangeConf{ + Pending: []string{"pending", "incomplete"}, + Target: "running", + Refresh: FailedStateRefreshFunc(), + Timeout: 200 * time.Second, + } + + _, err := conf.WaitForState() + + if err == nil && err.Error() != "failed" { + t.Fatalf("err: %s", err) + } +}