use keepalive replies to detect dead connections
An ssh server should always send a reply packet to the keepalive request. If we miss those replies for over 2min, consider the connection dead and abort, rather than block the provisioner indefinitely.
This commit is contained in:
parent
b7bbb942ff
commit
780ca17884
|
@ -27,17 +27,23 @@ import (
|
||||||
const (
|
const (
|
||||||
// DefaultShebang is added at the top of a SSH script file
|
// DefaultShebang is added at the top of a SSH script file
|
||||||
DefaultShebang = "#!/bin/sh\n"
|
DefaultShebang = "#!/bin/sh\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// randShared is a global random generator object that is shared. This must be
|
||||||
|
// shared since it is seeded by the current time and creating multiple can
|
||||||
|
// result in the same values. By using a shared RNG we assure different numbers
|
||||||
|
// per call.
|
||||||
|
randLock sync.Mutex
|
||||||
|
randShared *rand.Rand
|
||||||
|
|
||||||
// enable ssh keeplive probes by default
|
// enable ssh keeplive probes by default
|
||||||
keepAliveInterval = 2 * time.Second
|
keepAliveInterval = 2 * time.Second
|
||||||
)
|
|
||||||
|
|
||||||
// randShared is a global random generator object that is shared.
|
// max time to wait for for a KeepAlive response before considering the
|
||||||
// This must be shared since it is seeded by the current time and
|
// connection to be dead.
|
||||||
// creating multiple can result in the same values. By using a shared
|
maxKeepAliveDelay = 120 * time.Second
|
||||||
// RNG we assure different numbers per call.
|
)
|
||||||
var randLock sync.Mutex
|
|
||||||
var randShared *rand.Rand
|
|
||||||
|
|
||||||
// Communicator represents the SSH communicator
|
// Communicator represents the SSH communicator
|
||||||
type Communicator struct {
|
type Communicator struct {
|
||||||
|
@ -224,24 +230,54 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) {
|
||||||
// Start a keepalive goroutine to help maintain the connection for
|
// Start a keepalive goroutine to help maintain the connection for
|
||||||
// long-running commands.
|
// long-running commands.
|
||||||
log.Printf("[DEBUG] starting ssh KeepAlives")
|
log.Printf("[DEBUG] starting ssh KeepAlives")
|
||||||
|
go func() {
|
||||||
|
defer cancelKeepAlive()
|
||||||
|
// Along with the KeepAlives generating packets to keep the tcp
|
||||||
|
// connection open, we will use the replies to verify liveness of the
|
||||||
|
// connection. This will prevent dead connections from blocking the
|
||||||
|
// provisioner indefinitely.
|
||||||
|
respCh := make(chan error, 1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
t := time.NewTicker(keepAliveInterval)
|
t := time.NewTicker(keepAliveInterval)
|
||||||
defer t.Stop()
|
defer t.Stop()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-t.C:
|
case <-t.C:
|
||||||
// there's no useful response to these, just abort when there's
|
|
||||||
// an error.
|
|
||||||
_, _, err := c.client.SendRequest("keepalive@terraform.io", true, nil)
|
_, _, err := c.client.SendRequest("keepalive@terraform.io", true, nil)
|
||||||
if err != nil {
|
respCh <- err
|
||||||
return
|
|
||||||
}
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
after := time.NewTimer(maxKeepAliveDelay)
|
||||||
|
defer after.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case err := <-respCh:
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[ERROR] ssh keepalive: %s", err)
|
||||||
|
sshConn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-after.C:
|
||||||
|
// abort after too many missed keepalives
|
||||||
|
log.Println("[ERROR] no reply from ssh server")
|
||||||
|
sshConn.Close()
|
||||||
|
return
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !after.Stop() {
|
||||||
|
<-after.C
|
||||||
|
}
|
||||||
|
after.Reset(maxKeepAliveDelay)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -100,16 +100,19 @@ func newMockLineServer(t *testing.T, signer ssh.Signer, pubKey string) string {
|
||||||
|
|
||||||
go func(in <-chan *ssh.Request) {
|
go func(in <-chan *ssh.Request) {
|
||||||
for req := range in {
|
for req := range in {
|
||||||
|
// since this channel's requests are serviced serially,
|
||||||
|
// this will block keepalive probes, and can simulate a
|
||||||
|
// hung connection.
|
||||||
|
if bytes.Contains(req.Payload, []byte("sleep")) {
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
if req.WantReply {
|
if req.WantReply {
|
||||||
req.Reply(true, nil)
|
req.Reply(true, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(requests)
|
}(requests)
|
||||||
|
|
||||||
go func(newChannel ssh.NewChannel) {
|
|
||||||
conn.OpenChannel(newChannel.ChannelType(), nil)
|
|
||||||
}(newChannel)
|
|
||||||
|
|
||||||
defer channel.Close()
|
defer channel.Close()
|
||||||
}
|
}
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
@ -182,6 +185,10 @@ func TestStart(t *testing.T) {
|
||||||
// TestKeepAlives verifies that the keepalive messages don't interfere with
|
// TestKeepAlives verifies that the keepalive messages don't interfere with
|
||||||
// normal operation of the client.
|
// normal operation of the client.
|
||||||
func TestKeepAlives(t *testing.T) {
|
func TestKeepAlives(t *testing.T) {
|
||||||
|
ivl := keepAliveInterval
|
||||||
|
keepAliveInterval = 250 * time.Millisecond
|
||||||
|
defer func() { keepAliveInterval = ivl }()
|
||||||
|
|
||||||
address := newMockLineServer(t, nil, testClientPublicKey)
|
address := newMockLineServer(t, nil, testClientPublicKey)
|
||||||
parts := strings.Split(address, ":")
|
parts := strings.Split(address, ":")
|
||||||
|
|
||||||
|
@ -193,7 +200,6 @@ func TestKeepAlives(t *testing.T) {
|
||||||
"password": "pass",
|
"password": "pass",
|
||||||
"host": parts[0],
|
"host": parts[0],
|
||||||
"port": parts[1],
|
"port": parts[1],
|
||||||
"timeout": "30s",
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -209,11 +215,11 @@ func TestKeepAlives(t *testing.T) {
|
||||||
|
|
||||||
var cmd remote.Cmd
|
var cmd remote.Cmd
|
||||||
stdout := new(bytes.Buffer)
|
stdout := new(bytes.Buffer)
|
||||||
cmd.Command = "echo foo"
|
cmd.Command = "sleep"
|
||||||
cmd.Stdout = stdout
|
cmd.Stdout = stdout
|
||||||
|
|
||||||
// wait a bit before executing the command, so that at least 1 keepalive is sent
|
// wait a bit before executing the command, so that at least 1 keepalive is sent
|
||||||
time.Sleep(3 * time.Second)
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
err = c.Start(&cmd)
|
err = c.Start(&cmd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -221,6 +227,52 @@ func TestKeepAlives(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestDeadConnection verifies that failed keepalive messages will eventually
|
||||||
|
// kill the connection.
|
||||||
|
func TestFailedKeepAlives(t *testing.T) {
|
||||||
|
ivl := keepAliveInterval
|
||||||
|
del := maxKeepAliveDelay
|
||||||
|
maxKeepAliveDelay = 500 * time.Millisecond
|
||||||
|
keepAliveInterval = 250 * time.Millisecond
|
||||||
|
defer func() {
|
||||||
|
keepAliveInterval = ivl
|
||||||
|
maxKeepAliveDelay = del
|
||||||
|
}()
|
||||||
|
|
||||||
|
address := newMockLineServer(t, nil, testClientPublicKey)
|
||||||
|
parts := strings.Split(address, ":")
|
||||||
|
|
||||||
|
r := &terraform.InstanceState{
|
||||||
|
Ephemeral: terraform.EphemeralState{
|
||||||
|
ConnInfo: map[string]string{
|
||||||
|
"type": "ssh",
|
||||||
|
"user": "user",
|
||||||
|
"password": "pass",
|
||||||
|
"host": parts[0],
|
||||||
|
"port": parts[1],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := New(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error creating communicator: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.Connect(nil); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var cmd remote.Cmd
|
||||||
|
stdout := new(bytes.Buffer)
|
||||||
|
cmd.Command = "sleep"
|
||||||
|
cmd.Stdout = stdout
|
||||||
|
|
||||||
|
err = c.Start(&cmd)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected connection error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestLostConnection(t *testing.T) {
|
func TestLostConnection(t *testing.T) {
|
||||||
address := newMockLineServer(t, nil, testClientPublicKey)
|
address := newMockLineServer(t, nil, testClientPublicKey)
|
||||||
parts := strings.Split(address, ":")
|
parts := strings.Split(address, ":")
|
||||||
|
|
Loading…
Reference in New Issue