diff --git a/builtin/provisioners/chef/resource_provisioner.go b/builtin/provisioners/chef/resource_provisioner.go index 8fc45e59e..6edf973f3 100644 --- a/builtin/provisioners/chef/resource_provisioner.go +++ b/builtin/provisioners/chef/resource_provisioner.go @@ -15,7 +15,6 @@ import ( "strings" "sync" "text/template" - "time" "github.com/hashicorp/terraform/communicator" "github.com/hashicorp/terraform/communicator/remote" @@ -307,8 +306,11 @@ func applyFn(ctx context.Context) error { return err } + ctx, cancel := context.WithTimeout(ctx, comm.Timeout()) + defer cancel() + // Wait and retry until we establish the connection - err = retryFunc(comm.Timeout(), func() error { + err = communicator.Retry(ctx, func() error { return comm.Connect(o) }) if err != nil { @@ -717,24 +719,6 @@ func (p *provisioner) copyOutput(o terraform.UIOutput, r io.Reader, doneCh chan< } } -// retryFunc is used to retry a function for a given duration -func retryFunc(timeout time.Duration, f func() error) error { - finish := time.After(timeout) - for { - err := f() - if err == nil { - return nil - } - log.Printf("Retryable error: %v", err) - - select { - case <-finish: - return err - case <-time.After(3 * time.Second): - } - } -} - func decodeConfig(d *schema.ResourceData) (*provisioner, error) { p := &provisioner{ ClientOptions: getStringList(d.Get("client_options")), diff --git a/builtin/provisioners/file/resource_provisioner.go b/builtin/provisioners/file/resource_provisioner.go index 9b9e8a97b..5514250d7 100644 --- a/builtin/provisioners/file/resource_provisioner.go +++ b/builtin/provisioners/file/resource_provisioner.go @@ -4,9 +4,7 @@ import ( "context" "fmt" "io/ioutil" - "log" "os" - "time" "github.com/hashicorp/terraform/communicator" "github.com/hashicorp/terraform/helper/schema" @@ -50,6 +48,9 @@ func applyFn(ctx context.Context) error { return err } + ctx, cancel := context.WithTimeout(ctx, comm.Timeout()) + defer cancel() + // Get the source src, deleteSource, err := getSrc(data) if err != nil { @@ -61,21 +62,11 @@ func applyFn(ctx context.Context) error { // Begin the file copy dst := data.Get("destination").(string) - resultCh := make(chan error, 1) - go func() { - resultCh <- copyFiles(comm, src, dst) - }() - // Allow the file copy to complete unless there is an interrupt. - // If there is an interrupt we make no attempt to cleanly close - // the connection currently. We just abruptly exit. Because Terraform - // taints the resource, this is fine. - select { - case err := <-resultCh: + if err := copyFiles(ctx, comm, src, dst); err != nil { return err - case <-ctx.Done(): - return fmt.Errorf("file transfer interrupted") } + return nil } func validateFn(c *terraform.ResourceConfig) (ws []string, es []error) { @@ -107,16 +98,21 @@ func getSrc(data *schema.ResourceData) (string, bool, error) { } // copyFiles is used to copy the files from a source to a destination -func copyFiles(comm communicator.Communicator, src, dst string) error { +func copyFiles(ctx context.Context, comm communicator.Communicator, src, dst string) error { // Wait and retry until we establish the connection - err := retryFunc(comm.Timeout(), func() error { - err := comm.Connect(nil) - return err + err := communicator.Retry(ctx, func() error { + return comm.Connect(nil) }) if err != nil { return err } - defer comm.Disconnect() + + // disconnect when the context is canceled, which will close this after + // Apply as well. + go func() { + <-ctx.Done() + comm.Disconnect() + }() info, err := os.Stat(src) if err != nil { @@ -144,21 +140,3 @@ func copyFiles(comm communicator.Communicator, src, dst string) error { } return err } - -// retryFunc is used to retry a function for a given duration -func retryFunc(timeout time.Duration, f func() error) error { - finish := time.After(timeout) - for { - err := f() - if err == nil { - return nil - } - log.Printf("Retryable error: %v", err) - - select { - case <-finish: - return err - case <-time.After(3 * time.Second): - } - } -} diff --git a/builtin/provisioners/habitat/resource_provisioner.go b/builtin/provisioners/habitat/resource_provisioner.go index f9d47ff07..aa404dae1 100644 --- a/builtin/provisioners/habitat/resource_provisioner.go +++ b/builtin/provisioners/habitat/resource_provisioner.go @@ -6,12 +6,10 @@ import ( "errors" "fmt" "io" - "log" "net/url" "path" "strings" "text/template" - "time" "github.com/hashicorp/terraform/communicator" "github.com/hashicorp/terraform/communicator/remote" @@ -233,10 +231,13 @@ func applyFn(ctx context.Context) error { return err } - err = retryFunc(comm.Timeout(), func() error { - err = comm.Connect(o) - return err + ctx, cancel := context.WithTimeout(ctx, comm.Timeout()) + defer cancel() + + err = communicator.Retry(ctx, func() error { + return comm.Connect(o) }) + if err != nil { return err } @@ -728,24 +729,6 @@ func (p *provisioner) uploadUserTOML(o terraform.UIOutput, comm communicator.Com } -func retryFunc(timeout time.Duration, f func() error) error { - finish := time.After(timeout) - - for { - err := f() - if err == nil { - return nil - } - log.Printf("Retryable error: %v", err) - - select { - case <-finish: - return err - case <-time.After(3 * time.Second): - } - } -} - func (p *provisioner) copyOutput(o terraform.UIOutput, r io.Reader, doneCh chan<- struct{}) { defer close(doneCh) lr := linereader.New(r) diff --git a/builtin/provisioners/remote-exec/resource_provisioner.go b/builtin/provisioners/remote-exec/resource_provisioner.go index ba811dafe..378a282ed 100644 --- a/builtin/provisioners/remote-exec/resource_provisioner.go +++ b/builtin/provisioners/remote-exec/resource_provisioner.go @@ -9,7 +9,6 @@ import ( "log" "os" "strings" - "sync/atomic" "time" "github.com/hashicorp/terraform/communicator" @@ -159,7 +158,7 @@ func runScripts( scripts []io.ReadCloser) error { // Wrap out context in a cancelation function that we use to // kill the connection. - ctx, cancelFunc := context.WithCancel(ctx) + ctx, cancelFunc := context.WithTimeout(ctx, comm.Timeout()) defer cancelFunc() // Wait for the context to end and then disconnect @@ -169,9 +168,8 @@ func runScripts( }() // Wait and retry until we establish the connection - err := retryFunc(ctx, comm.Timeout(), func() error { - err := comm.Connect(o) - return err + err := communicator.Retry(ctx, func() error { + return comm.Connect(o) }) if err != nil { return err @@ -179,49 +177,34 @@ func runScripts( for _, script := range scripts { var cmd *remote.Cmd + outR, outW := io.Pipe() errR, errW := io.Pipe() - outDoneCh := make(chan struct{}) - errDoneCh := make(chan struct{}) - go copyOutput(o, outR, outDoneCh) - go copyOutput(o, errR, errDoneCh) + defer outW.Close() + defer errW.Close() + + go copyOutput(o, outR) + go copyOutput(o, errR) remotePath := comm.ScriptPath() - err = retryFunc(ctx, comm.Timeout(), func() error { - if err := comm.UploadScript(remotePath, script); err != nil { - return fmt.Errorf("Failed to upload script: %v", err) - } - cmd = &remote.Cmd{ - Command: remotePath, - Stdout: outW, - Stderr: errW, - } - if err := comm.Start(cmd); err != nil { - return fmt.Errorf("Error starting script: %v", err) - } - - return nil - }) - if err == nil { - cmd.Wait() - if cmd.ExitStatus != 0 { - err = fmt.Errorf("Script exited with non-zero exit status: %d", cmd.ExitStatus) - } + if err := comm.UploadScript(remotePath, script); err != nil { + return fmt.Errorf("Failed to upload script: %v", err) } - // If we have an error, end our context so the disconnect happens. - // This has to happen before the output cleanup below since during - // an interrupt this will cause the outputs to end. - if err != nil { - cancelFunc() + cmd = &remote.Cmd{ + Command: remotePath, + Stdout: outW, + Stderr: errW, + } + if err := comm.Start(cmd); err != nil { + return fmt.Errorf("Error starting script: %v", err) } - // Wait for output to clean up - outW.Close() - errW.Close() - <-outDoneCh - <-errDoneCh + cmd.Wait() + if cmd.ExitStatus != 0 { + err = fmt.Errorf("Script exited with non-zero exit status: %d", cmd.ExitStatus) + } // Upload a blank follow up file in the same path to prevent residual // script contents from remaining on remote machine @@ -230,93 +213,15 @@ func runScripts( // This feature is best-effort. log.Printf("[WARN] Failed to upload empty follow up script: %v", err) } - - // If we have an error, return it out now that we've cleaned up - if err != nil { - return err - } } return nil } func copyOutput( - o terraform.UIOutput, r io.Reader, doneCh chan<- struct{}) { - defer close(doneCh) + o terraform.UIOutput, r io.Reader) { lr := linereader.New(r) for line := range lr.Ch { o.Output(line) } } - -// retryFunc is used to retry a function for a given duration -func retryFunc(ctx context.Context, timeout time.Duration, f func() error) error { - // Build a new context with the timeout - ctx, done := context.WithTimeout(ctx, timeout) - defer done() - - // container for atomic error value - type errWrap struct { - E error - } - - // Try the function in a goroutine - var errVal atomic.Value - doneCh := make(chan struct{}) - go func() { - defer close(doneCh) - - delay := time.Duration(0) - for { - // If our context ended, we want to exit right away. - select { - case <-ctx.Done(): - return - case <-time.After(delay): - } - - // Try the function call - err := f() - errVal.Store(&errWrap{err}) - - if err == nil { - return - } - - log.Printf("[WARN] retryable error: %v", err) - - delay *= 2 - - if delay == 0 { - delay = initialBackoffDelay - } - - if delay > maxBackoffDelay { - delay = maxBackoffDelay - } - - log.Printf("[INFO] sleeping for %s", delay) - } - }() - - // Wait for completion - select { - case <-ctx.Done(): - case <-doneCh: - } - - // Check if we have a context error to check if we're interrupted or timeout - switch ctx.Err() { - case context.Canceled: - return fmt.Errorf("interrupted") - case context.DeadlineExceeded: - return fmt.Errorf("timeout") - } - - // Check if we got an error executing - if ev, ok := errVal.Load().(errWrap); ok { - return ev.E - } - - return nil -} diff --git a/builtin/provisioners/remote-exec/resource_provisioner_test.go b/builtin/provisioners/remote-exec/resource_provisioner_test.go index 8c447788d..a6e024fe5 100644 --- a/builtin/provisioners/remote-exec/resource_provisioner_test.go +++ b/builtin/provisioners/remote-exec/resource_provisioner_test.go @@ -2,12 +2,8 @@ package remoteexec import ( "bytes" - "context" - "errors" "io" - "net" "testing" - "time" "strings" @@ -210,64 +206,6 @@ func TestResourceProvider_CollectScripts_scriptsEmpty(t *testing.T) { } } -func TestRetryFunc(t *testing.T) { - origMax := maxBackoffDelay - maxBackoffDelay = time.Second - origStart := initialBackoffDelay - initialBackoffDelay = 10 * time.Millisecond - - defer func() { - maxBackoffDelay = origMax - initialBackoffDelay = origStart - }() - - // succeed on the third try - errs := []error{io.EOF, &net.OpError{Err: errors.New("ERROR")}, nil} - count := 0 - - err := retryFunc(context.Background(), time.Second, func() error { - if count >= len(errs) { - return errors.New("failed to stop after nil error") - } - - err := errs[count] - count++ - - return err - }) - - if count != 3 { - t.Fatal("retry func should have been called 3 times") - } - - if err != nil { - t.Fatal(err) - } -} - -func TestRetryFuncBackoff(t *testing.T) { - origMax := maxBackoffDelay - maxBackoffDelay = time.Second - origStart := initialBackoffDelay - initialBackoffDelay = 100 * time.Millisecond - - defer func() { - maxBackoffDelay = origMax - initialBackoffDelay = origStart - }() - - count := 0 - - retryFunc(context.Background(), time.Second, func() error { - count++ - return io.EOF - }) - - if count > 4 { - t.Fatalf("retry func failed to backoff. called %d times", count) - } -} - func testConfig(t *testing.T, c map[string]interface{}) *terraform.ResourceConfig { r, err := config.NewRawConfig(c) if err != nil { diff --git a/communicator/communicator.go b/communicator/communicator.go index 5fa2631a4..440c1c727 100644 --- a/communicator/communicator.go +++ b/communicator/communicator.go @@ -1,8 +1,11 @@ package communicator import ( + "context" "fmt" "io" + "log" + "sync/atomic" "time" "github.com/hashicorp/terraform/communicator/remote" @@ -51,3 +54,96 @@ func New(s *terraform.InstanceState) (Communicator, error) { return nil, fmt.Errorf("connection type '%s' not supported", connType) } } + +// maxBackoffDelay is the maximum delay between retry attempts +var maxBackoffDelay = 20 * time.Second +var initialBackoffDelay = time.Second + +// Fatal is an interface that error values can return to halt Retry +type Fatal interface { + FatalError() error +} + +// Retry retries the function f until it returns a nil error, a Fatal error, or +// the context expires. +func Retry(ctx context.Context, f func() error) error { + // container for atomic error value + type errWrap struct { + E error + } + + // Try the function in a goroutine + var errVal atomic.Value + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + + delay := time.Duration(0) + for { + // If our context ended, we want to exit right away. + select { + case <-ctx.Done(): + return + case <-time.After(delay): + } + + // Try the function call + err := f() + + // return if we have no error, or a FatalError + done := false + switch e := err.(type) { + case nil: + done = true + case Fatal: + err = e.FatalError() + done = true + } + + errVal.Store(errWrap{err}) + + if done { + return + } + + log.Printf("[WARN] retryable error: %v", err) + + delay *= 2 + + if delay == 0 { + delay = initialBackoffDelay + } + + if delay > maxBackoffDelay { + delay = maxBackoffDelay + } + + log.Printf("[INFO] sleeping for %s", delay) + } + }() + + // Wait for completion + select { + case <-ctx.Done(): + case <-doneCh: + } + + var lastErr error + // Check if we got an error executing + if ev, ok := errVal.Load().(errWrap); ok { + lastErr = ev.E + } + + // Check if we have a context error to check if we're interrupted or timeout + switch ctx.Err() { + case context.Canceled: + return fmt.Errorf("interrupted - last error: %v", lastErr) + case context.DeadlineExceeded: + return fmt.Errorf("timeout - last error: %v", lastErr) + } + + if lastErr != nil { + return lastErr + } + return nil +} diff --git a/communicator/communicator_test.go b/communicator/communicator_test.go index 33a91cd6f..659222421 100644 --- a/communicator/communicator_test.go +++ b/communicator/communicator_test.go @@ -1,7 +1,12 @@ package communicator import ( + "context" + "errors" + "io" + "net" "testing" + "time" "github.com/hashicorp/terraform/terraform" ) @@ -28,3 +33,66 @@ func TestCommunicator_new(t *testing.T) { t.Fatalf("err: %v", err) } } +func TestRetryFunc(t *testing.T) { + origMax := maxBackoffDelay + maxBackoffDelay = time.Second + origStart := initialBackoffDelay + initialBackoffDelay = 10 * time.Millisecond + + defer func() { + maxBackoffDelay = origMax + initialBackoffDelay = origStart + }() + + // succeed on the third try + errs := []error{io.EOF, &net.OpError{Err: errors.New("ERROR")}, nil} + count := 0 + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + err := Retry(ctx, func() error { + if count >= len(errs) { + return errors.New("failed to stop after nil error") + } + + err := errs[count] + count++ + + return err + }) + + if count != 3 { + t.Fatal("retry func should have been called 3 times") + } + + if err != nil { + t.Fatal(err) + } +} + +func TestRetryFuncBackoff(t *testing.T) { + origMax := maxBackoffDelay + maxBackoffDelay = time.Second + origStart := initialBackoffDelay + initialBackoffDelay = 100 * time.Millisecond + + defer func() { + maxBackoffDelay = origMax + initialBackoffDelay = origStart + }() + + count := 0 + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + Retry(ctx, func() error { + count++ + return io.EOF + }) + + if count > 4 { + t.Fatalf("retry func failed to backoff. called %d times", count) + } +} diff --git a/communicator/ssh/communicator.go b/communicator/ssh/communicator.go index 4ad67aefc..85dabb6b5 100644 --- a/communicator/ssh/communicator.go +++ b/communicator/ssh/communicator.go @@ -63,6 +63,14 @@ type sshConfig struct { sshAgent *sshAgent } +type fatalError struct { + error +} + +func (e fatalError) FatalError() error { + return e.error +} + // New creates a new communicator implementation over SSH. func New(s *terraform.InstanceState) (*Communicator, error) { connInfo, err := parseConnectionInfo(s) @@ -159,8 +167,8 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) { host := fmt.Sprintf("%s:%d", c.connInfo.Host, c.connInfo.Port) sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, host, c.config.config) if err != nil { - log.Printf("handshake error: %s", err) - return err + log.Printf("fatal handshake error: %s", err) + return fatalError{err} } c.client = ssh.NewClient(sshConn, sshChan, req) @@ -168,7 +176,7 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) { if c.config.sshAgent != nil { log.Printf("[DEBUG] Telling SSH config to forward to agent") if err := c.config.sshAgent.ForwardToAgent(c.client); err != nil { - return err + return fatalError{err} } log.Printf("[DEBUG] Setting up a session to request agent forwarding")