use keepalive replies to detect dead connections

An ssh server should always send a reply packet to the keepalive
request. If we miss those replies for over 2min, consider the connection
dead and abort, rather than block the provisioner indefinitely.
This commit is contained in:
James Bardin 2019-07-10 21:28:58 -04:00
parent b7bbb942ff
commit 780ca17884
2 changed files with 111 additions and 23 deletions

View File

@ -27,17 +27,23 @@ import (
const ( const (
// DefaultShebang is added at the top of a SSH script file // DefaultShebang is added at the top of a SSH script file
DefaultShebang = "#!/bin/sh\n" DefaultShebang = "#!/bin/sh\n"
)
var (
// randShared is a global random generator object that is shared. This must be
// shared since it is seeded by the current time and creating multiple can
// result in the same values. By using a shared RNG we assure different numbers
// per call.
randLock sync.Mutex
randShared *rand.Rand
// enable ssh keeplive probes by default // enable ssh keeplive probes by default
keepAliveInterval = 2 * time.Second keepAliveInterval = 2 * time.Second
)
// randShared is a global random generator object that is shared. // max time to wait for for a KeepAlive response before considering the
// This must be shared since it is seeded by the current time and // connection to be dead.
// creating multiple can result in the same values. By using a shared maxKeepAliveDelay = 120 * time.Second
// RNG we assure different numbers per call. )
var randLock sync.Mutex
var randShared *rand.Rand
// Communicator represents the SSH communicator // Communicator represents the SSH communicator
type Communicator struct { type Communicator struct {
@ -225,20 +231,50 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) {
// long-running commands. // long-running commands.
log.Printf("[DEBUG] starting ssh KeepAlives") log.Printf("[DEBUG] starting ssh KeepAlives")
go func() { go func() {
t := time.NewTicker(keepAliveInterval) defer cancelKeepAlive()
defer t.Stop() // Along with the KeepAlives generating packets to keep the tcp
for { // connection open, we will use the replies to verify liveness of the
select { // connection. This will prevent dead connections from blocking the
case <-t.C: // provisioner indefinitely.
// there's no useful response to these, just abort when there's respCh := make(chan error, 1)
// an error.
_, _, err := c.client.SendRequest("keepalive@terraform.io", true, nil) go func() {
if err != nil { t := time.NewTicker(keepAliveInterval)
defer t.Stop()
for {
select {
case <-t.C:
_, _, err := c.client.SendRequest("keepalive@terraform.io", true, nil)
respCh <- err
case <-ctx.Done():
return return
} }
}
}()
after := time.NewTimer(maxKeepAliveDelay)
defer after.Stop()
for {
select {
case err := <-respCh:
if err != nil {
log.Printf("[ERROR] ssh keepalive: %s", err)
sshConn.Close()
return
}
case <-after.C:
// abort after too many missed keepalives
log.Println("[ERROR] no reply from ssh server")
sshConn.Close()
return
case <-ctx.Done(): case <-ctx.Done():
return return
} }
if !after.Stop() {
<-after.C
}
after.Reset(maxKeepAliveDelay)
} }
}() }()

View File

@ -100,16 +100,19 @@ func newMockLineServer(t *testing.T, signer ssh.Signer, pubKey string) string {
go func(in <-chan *ssh.Request) { go func(in <-chan *ssh.Request) {
for req := range in { for req := range in {
// since this channel's requests are serviced serially,
// this will block keepalive probes, and can simulate a
// hung connection.
if bytes.Contains(req.Payload, []byte("sleep")) {
time.Sleep(time.Second)
}
if req.WantReply { if req.WantReply {
req.Reply(true, nil) req.Reply(true, nil)
} }
} }
}(requests) }(requests)
go func(newChannel ssh.NewChannel) {
conn.OpenChannel(newChannel.ChannelType(), nil)
}(newChannel)
defer channel.Close() defer channel.Close()
} }
conn.Close() conn.Close()
@ -182,6 +185,10 @@ func TestStart(t *testing.T) {
// TestKeepAlives verifies that the keepalive messages don't interfere with // TestKeepAlives verifies that the keepalive messages don't interfere with
// normal operation of the client. // normal operation of the client.
func TestKeepAlives(t *testing.T) { func TestKeepAlives(t *testing.T) {
ivl := keepAliveInterval
keepAliveInterval = 250 * time.Millisecond
defer func() { keepAliveInterval = ivl }()
address := newMockLineServer(t, nil, testClientPublicKey) address := newMockLineServer(t, nil, testClientPublicKey)
parts := strings.Split(address, ":") parts := strings.Split(address, ":")
@ -193,7 +200,6 @@ func TestKeepAlives(t *testing.T) {
"password": "pass", "password": "pass",
"host": parts[0], "host": parts[0],
"port": parts[1], "port": parts[1],
"timeout": "30s",
}, },
}, },
} }
@ -209,11 +215,11 @@ func TestKeepAlives(t *testing.T) {
var cmd remote.Cmd var cmd remote.Cmd
stdout := new(bytes.Buffer) stdout := new(bytes.Buffer)
cmd.Command = "echo foo" cmd.Command = "sleep"
cmd.Stdout = stdout cmd.Stdout = stdout
// wait a bit before executing the command, so that at least 1 keepalive is sent // wait a bit before executing the command, so that at least 1 keepalive is sent
time.Sleep(3 * time.Second) time.Sleep(500 * time.Millisecond)
err = c.Start(&cmd) err = c.Start(&cmd)
if err != nil { if err != nil {
@ -221,6 +227,52 @@ func TestKeepAlives(t *testing.T) {
} }
} }
// TestDeadConnection verifies that failed keepalive messages will eventually
// kill the connection.
func TestFailedKeepAlives(t *testing.T) {
ivl := keepAliveInterval
del := maxKeepAliveDelay
maxKeepAliveDelay = 500 * time.Millisecond
keepAliveInterval = 250 * time.Millisecond
defer func() {
keepAliveInterval = ivl
maxKeepAliveDelay = del
}()
address := newMockLineServer(t, nil, testClientPublicKey)
parts := strings.Split(address, ":")
r := &terraform.InstanceState{
Ephemeral: terraform.EphemeralState{
ConnInfo: map[string]string{
"type": "ssh",
"user": "user",
"password": "pass",
"host": parts[0],
"port": parts[1],
},
},
}
c, err := New(r)
if err != nil {
t.Fatalf("error creating communicator: %s", err)
}
if err := c.Connect(nil); err != nil {
t.Fatal(err)
}
var cmd remote.Cmd
stdout := new(bytes.Buffer)
cmd.Command = "sleep"
cmd.Stdout = stdout
err = c.Start(&cmd)
if err == nil {
t.Fatal("expected connection error")
}
}
func TestLostConnection(t *testing.T) { func TestLostConnection(t *testing.T) {
address := newMockLineServer(t, nil, testClientPublicKey) address := newMockLineServer(t, nil, testClientPublicKey)
parts := strings.Split(address, ":") parts := strings.Split(address, ":")