Fix races in WaitForState

The WaitForState method can't read the result values in a timeout
because they are still owned by the running goroutine. Keep all values
scoped inside the goroutine, and save them into an atomic.Value to be
returned.

Fixes race introduced in #8510
This commit is contained in:
James Bardin 2016-08-30 16:22:21 -04:00
parent dea2860735
commit 82be35a797
1 changed files with 32 additions and 16 deletions

View File

@ -2,6 +2,7 @@ package resource
import ( import (
"log" "log"
"sync/atomic"
"time" "time"
) )
@ -61,9 +62,15 @@ func (conf *StateChangeConf) WaitForState() (interface{}, error) {
conf.ContinuousTargetOccurence = 1 conf.ContinuousTargetOccurence = 1
} }
var result interface{} // We can't safely read the result values if we timeout, so store them in
var resulterr error // an atomic.Value
var currentState string type Result struct {
Result interface{}
State string
Error error
}
var lastResult atomic.Value
lastResult.Store(Result{})
doneCh := make(chan struct{}) doneCh := make(chan struct{})
go func() { go func() {
@ -74,7 +81,6 @@ func (conf *StateChangeConf) WaitForState() (interface{}, error) {
wait := 100 * time.Millisecond wait := 100 * time.Millisecond
var err error
for first := true; ; first = false { for first := true; ; first = false {
if !first { if !first {
// If a poll interval has been specified, choose that interval. // If a poll interval has been specified, choose that interval.
@ -99,14 +105,20 @@ func (conf *StateChangeConf) WaitForState() (interface{}, error) {
} }
} }
result, currentState, err = conf.Refresh() res, currentState, err := conf.Refresh()
result := Result{
Result: res,
State: currentState,
Error: err,
}
lastResult.Store(result)
if err != nil { if err != nil {
resulterr = err
return return
} }
// If we're waiting for the absence of a thing, then return // If we're waiting for the absence of a thing, then return
if result == nil && len(conf.Target) == 0 { if res == nil && len(conf.Target) == 0 {
targetOccurence += 1 targetOccurence += 1
if conf.ContinuousTargetOccurence == targetOccurence { if conf.ContinuousTargetOccurence == targetOccurence {
return return
@ -115,14 +127,15 @@ func (conf *StateChangeConf) WaitForState() (interface{}, error) {
} }
} }
if result == nil { if res == nil {
// If we didn't find the resource, check if we have been // If we didn't find the resource, check if we have been
// not finding it for awhile, and if so, report an error. // not finding it for awhile, and if so, report an error.
notfoundTick += 1 notfoundTick += 1
if notfoundTick > conf.NotFoundChecks { if notfoundTick > conf.NotFoundChecks {
resulterr = &NotFoundError{ result.Error = &NotFoundError{
LastError: resulterr, LastError: err,
} }
lastResult.Store(result)
return return
} }
} else { } else {
@ -151,11 +164,12 @@ func (conf *StateChangeConf) WaitForState() (interface{}, error) {
} }
if !found { if !found {
resulterr = &UnexpectedStateError{ result.Error = &UnexpectedStateError{
LastError: resulterr, LastError: err,
State: currentState, State: result.State,
ExpectedState: conf.Target, ExpectedState: conf.Target,
} }
lastResult.Store(result)
return return
} }
} }
@ -164,11 +178,13 @@ func (conf *StateChangeConf) WaitForState() (interface{}, error) {
select { select {
case <-doneCh: case <-doneCh:
return result, resulterr r := lastResult.Load().(Result)
return r.Result, r.Error
case <-time.After(conf.Timeout): case <-time.After(conf.Timeout):
r := lastResult.Load().(Result)
return nil, &TimeoutError{ return nil, &TimeoutError{
LastError: resulterr, LastError: r.Error,
LastState: currentState, LastState: r.State,
ExpectedState: conf.Target, ExpectedState: conf.Target,
} }
} }