diff --git a/communicator/ssh/communicator.go b/communicator/ssh/communicator.go index 5b814e3c3..0c1d4d151 100644 --- a/communicator/ssh/communicator.go +++ b/communicator/ssh/communicator.go @@ -27,17 +27,23 @@ import ( const ( // DefaultShebang is added at the top of a SSH script file DefaultShebang = "#!/bin/sh\n" +) + +var ( + // randShared is a global random generator object that is shared. This must be + // shared since it is seeded by the current time and creating multiple can + // result in the same values. By using a shared RNG we assure different numbers + // per call. + randLock sync.Mutex + randShared *rand.Rand // enable ssh keeplive probes by default keepAliveInterval = 2 * time.Second -) -// randShared is a global random generator object that is shared. -// This must be shared since it is seeded by the current time and -// creating multiple can result in the same values. By using a shared -// RNG we assure different numbers per call. -var randLock sync.Mutex -var randShared *rand.Rand + // max time to wait for for a KeepAlive response before considering the + // connection to be dead. + maxKeepAliveDelay = 120 * time.Second +) // Communicator represents the SSH communicator type Communicator struct { @@ -225,20 +231,50 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) { // long-running commands. log.Printf("[DEBUG] starting ssh KeepAlives") go func() { - t := time.NewTicker(keepAliveInterval) - defer t.Stop() - for { - select { - case <-t.C: - // there's no useful response to these, just abort when there's - // an error. - _, _, err := c.client.SendRequest("keepalive@terraform.io", true, nil) - if err != nil { + defer cancelKeepAlive() + // Along with the KeepAlives generating packets to keep the tcp + // connection open, we will use the replies to verify liveness of the + // connection. This will prevent dead connections from blocking the + // provisioner indefinitely. + respCh := make(chan error, 1) + + go func() { + t := time.NewTicker(keepAliveInterval) + defer t.Stop() + for { + select { + case <-t.C: + _, _, err := c.client.SendRequest("keepalive@terraform.io", true, nil) + respCh <- err + case <-ctx.Done(): return } + } + }() + + after := time.NewTimer(maxKeepAliveDelay) + defer after.Stop() + + for { + select { + case err := <-respCh: + if err != nil { + log.Printf("[ERROR] ssh keepalive: %s", err) + sshConn.Close() + return + } + case <-after.C: + // abort after too many missed keepalives + log.Println("[ERROR] no reply from ssh server") + sshConn.Close() + return case <-ctx.Done(): return } + if !after.Stop() { + <-after.C + } + after.Reset(maxKeepAliveDelay) } }() diff --git a/communicator/ssh/communicator_test.go b/communicator/ssh/communicator_test.go index a84afeb87..bbe821363 100644 --- a/communicator/ssh/communicator_test.go +++ b/communicator/ssh/communicator_test.go @@ -100,16 +100,19 @@ func newMockLineServer(t *testing.T, signer ssh.Signer, pubKey string) string { go func(in <-chan *ssh.Request) { for req := range in { + // since this channel's requests are serviced serially, + // this will block keepalive probes, and can simulate a + // hung connection. + if bytes.Contains(req.Payload, []byte("sleep")) { + time.Sleep(time.Second) + } + if req.WantReply { req.Reply(true, nil) } } }(requests) - go func(newChannel ssh.NewChannel) { - conn.OpenChannel(newChannel.ChannelType(), nil) - }(newChannel) - defer channel.Close() } conn.Close() @@ -182,6 +185,10 @@ func TestStart(t *testing.T) { // TestKeepAlives verifies that the keepalive messages don't interfere with // normal operation of the client. func TestKeepAlives(t *testing.T) { + ivl := keepAliveInterval + keepAliveInterval = 250 * time.Millisecond + defer func() { keepAliveInterval = ivl }() + address := newMockLineServer(t, nil, testClientPublicKey) parts := strings.Split(address, ":") @@ -193,7 +200,6 @@ func TestKeepAlives(t *testing.T) { "password": "pass", "host": parts[0], "port": parts[1], - "timeout": "30s", }, }, } @@ -209,11 +215,11 @@ func TestKeepAlives(t *testing.T) { var cmd remote.Cmd stdout := new(bytes.Buffer) - cmd.Command = "echo foo" + cmd.Command = "sleep" cmd.Stdout = stdout // wait a bit before executing the command, so that at least 1 keepalive is sent - time.Sleep(3 * time.Second) + time.Sleep(500 * time.Millisecond) err = c.Start(&cmd) if err != nil { @@ -221,6 +227,52 @@ func TestKeepAlives(t *testing.T) { } } +// TestDeadConnection verifies that failed keepalive messages will eventually +// kill the connection. +func TestFailedKeepAlives(t *testing.T) { + ivl := keepAliveInterval + del := maxKeepAliveDelay + maxKeepAliveDelay = 500 * time.Millisecond + keepAliveInterval = 250 * time.Millisecond + defer func() { + keepAliveInterval = ivl + maxKeepAliveDelay = del + }() + + address := newMockLineServer(t, nil, testClientPublicKey) + parts := strings.Split(address, ":") + + r := &terraform.InstanceState{ + Ephemeral: terraform.EphemeralState{ + ConnInfo: map[string]string{ + "type": "ssh", + "user": "user", + "password": "pass", + "host": parts[0], + "port": parts[1], + }, + }, + } + + c, err := New(r) + if err != nil { + t.Fatalf("error creating communicator: %s", err) + } + + if err := c.Connect(nil); err != nil { + t.Fatal(err) + } + var cmd remote.Cmd + stdout := new(bytes.Buffer) + cmd.Command = "sleep" + cmd.Stdout = stdout + + err = c.Start(&cmd) + if err == nil { + t.Fatal("expected connection error") + } +} + func TestLostConnection(t *testing.T) { address := newMockLineServer(t, nil, testClientPublicKey) parts := strings.Split(address, ":")