use keepalive replies to detect dead connections
An ssh server should always send a reply packet to the keepalive request. If we miss those replies for over 2min, consider the connection dead and abort, rather than block the provisioner indefinitely.
This commit is contained in:
parent
b7bbb942ff
commit
780ca17884
|
@ -27,17 +27,23 @@ import (
|
|||
const (
|
||||
// DefaultShebang is added at the top of a SSH script file
|
||||
DefaultShebang = "#!/bin/sh\n"
|
||||
)
|
||||
|
||||
var (
|
||||
// randShared is a global random generator object that is shared. This must be
|
||||
// shared since it is seeded by the current time and creating multiple can
|
||||
// result in the same values. By using a shared RNG we assure different numbers
|
||||
// per call.
|
||||
randLock sync.Mutex
|
||||
randShared *rand.Rand
|
||||
|
||||
// enable ssh keeplive probes by default
|
||||
keepAliveInterval = 2 * time.Second
|
||||
)
|
||||
|
||||
// randShared is a global random generator object that is shared.
|
||||
// This must be shared since it is seeded by the current time and
|
||||
// creating multiple can result in the same values. By using a shared
|
||||
// RNG we assure different numbers per call.
|
||||
var randLock sync.Mutex
|
||||
var randShared *rand.Rand
|
||||
// max time to wait for for a KeepAlive response before considering the
|
||||
// connection to be dead.
|
||||
maxKeepAliveDelay = 120 * time.Second
|
||||
)
|
||||
|
||||
// Communicator represents the SSH communicator
|
||||
type Communicator struct {
|
||||
|
@ -225,20 +231,50 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) {
|
|||
// long-running commands.
|
||||
log.Printf("[DEBUG] starting ssh KeepAlives")
|
||||
go func() {
|
||||
t := time.NewTicker(keepAliveInterval)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
// there's no useful response to these, just abort when there's
|
||||
// an error.
|
||||
_, _, err := c.client.SendRequest("keepalive@terraform.io", true, nil)
|
||||
if err != nil {
|
||||
defer cancelKeepAlive()
|
||||
// Along with the KeepAlives generating packets to keep the tcp
|
||||
// connection open, we will use the replies to verify liveness of the
|
||||
// connection. This will prevent dead connections from blocking the
|
||||
// provisioner indefinitely.
|
||||
respCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
t := time.NewTicker(keepAliveInterval)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
_, _, err := c.client.SendRequest("keepalive@terraform.io", true, nil)
|
||||
respCh <- err
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
after := time.NewTimer(maxKeepAliveDelay)
|
||||
defer after.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case err := <-respCh:
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] ssh keepalive: %s", err)
|
||||
sshConn.Close()
|
||||
return
|
||||
}
|
||||
case <-after.C:
|
||||
// abort after too many missed keepalives
|
||||
log.Println("[ERROR] no reply from ssh server")
|
||||
sshConn.Close()
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
if !after.Stop() {
|
||||
<-after.C
|
||||
}
|
||||
after.Reset(maxKeepAliveDelay)
|
||||
}
|
||||
}()
|
||||
|
||||
|
|
|
@ -100,16 +100,19 @@ func newMockLineServer(t *testing.T, signer ssh.Signer, pubKey string) string {
|
|||
|
||||
go func(in <-chan *ssh.Request) {
|
||||
for req := range in {
|
||||
// since this channel's requests are serviced serially,
|
||||
// this will block keepalive probes, and can simulate a
|
||||
// hung connection.
|
||||
if bytes.Contains(req.Payload, []byte("sleep")) {
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
if req.WantReply {
|
||||
req.Reply(true, nil)
|
||||
}
|
||||
}
|
||||
}(requests)
|
||||
|
||||
go func(newChannel ssh.NewChannel) {
|
||||
conn.OpenChannel(newChannel.ChannelType(), nil)
|
||||
}(newChannel)
|
||||
|
||||
defer channel.Close()
|
||||
}
|
||||
conn.Close()
|
||||
|
@ -182,6 +185,10 @@ func TestStart(t *testing.T) {
|
|||
// TestKeepAlives verifies that the keepalive messages don't interfere with
|
||||
// normal operation of the client.
|
||||
func TestKeepAlives(t *testing.T) {
|
||||
ivl := keepAliveInterval
|
||||
keepAliveInterval = 250 * time.Millisecond
|
||||
defer func() { keepAliveInterval = ivl }()
|
||||
|
||||
address := newMockLineServer(t, nil, testClientPublicKey)
|
||||
parts := strings.Split(address, ":")
|
||||
|
||||
|
@ -193,7 +200,6 @@ func TestKeepAlives(t *testing.T) {
|
|||
"password": "pass",
|
||||
"host": parts[0],
|
||||
"port": parts[1],
|
||||
"timeout": "30s",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -209,11 +215,11 @@ func TestKeepAlives(t *testing.T) {
|
|||
|
||||
var cmd remote.Cmd
|
||||
stdout := new(bytes.Buffer)
|
||||
cmd.Command = "echo foo"
|
||||
cmd.Command = "sleep"
|
||||
cmd.Stdout = stdout
|
||||
|
||||
// wait a bit before executing the command, so that at least 1 keepalive is sent
|
||||
time.Sleep(3 * time.Second)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
err = c.Start(&cmd)
|
||||
if err != nil {
|
||||
|
@ -221,6 +227,52 @@ func TestKeepAlives(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestDeadConnection verifies that failed keepalive messages will eventually
|
||||
// kill the connection.
|
||||
func TestFailedKeepAlives(t *testing.T) {
|
||||
ivl := keepAliveInterval
|
||||
del := maxKeepAliveDelay
|
||||
maxKeepAliveDelay = 500 * time.Millisecond
|
||||
keepAliveInterval = 250 * time.Millisecond
|
||||
defer func() {
|
||||
keepAliveInterval = ivl
|
||||
maxKeepAliveDelay = del
|
||||
}()
|
||||
|
||||
address := newMockLineServer(t, nil, testClientPublicKey)
|
||||
parts := strings.Split(address, ":")
|
||||
|
||||
r := &terraform.InstanceState{
|
||||
Ephemeral: terraform.EphemeralState{
|
||||
ConnInfo: map[string]string{
|
||||
"type": "ssh",
|
||||
"user": "user",
|
||||
"password": "pass",
|
||||
"host": parts[0],
|
||||
"port": parts[1],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
c, err := New(r)
|
||||
if err != nil {
|
||||
t.Fatalf("error creating communicator: %s", err)
|
||||
}
|
||||
|
||||
if err := c.Connect(nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cmd remote.Cmd
|
||||
stdout := new(bytes.Buffer)
|
||||
cmd.Command = "sleep"
|
||||
cmd.Stdout = stdout
|
||||
|
||||
err = c.Start(&cmd)
|
||||
if err == nil {
|
||||
t.Fatal("expected connection error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLostConnection(t *testing.T) {
|
||||
address := newMockLineServer(t, nil, testClientPublicKey)
|
||||
parts := strings.Split(address, ":")
|
||||
|
|
Loading…
Reference in New Issue