add ssh keepalive messages to communicator

Long running remote-exec commands with no output may be cutoff during
execution. Enable ssh keepalives for all ssh connections.
This commit is contained in:
James Bardin 2019-02-22 13:58:07 -05:00
parent ef24b18b25
commit b5384100a6
2 changed files with 85 additions and 6 deletions

View File

@ -3,6 +3,7 @@ package ssh
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -26,6 +27,9 @@ import (
const ( const (
// DefaultShebang is added at the top of a SSH script file // DefaultShebang is added at the top of a SSH script file
DefaultShebang = "#!/bin/sh\n" DefaultShebang = "#!/bin/sh\n"
// enable ssh keeplive probes by default
keepAliveInterval = 2 * time.Second
) )
// randShared is a global random generator object that is shared. // randShared is a global random generator object that is shared.
@ -42,6 +46,7 @@ type Communicator struct {
config *sshConfig config *sshConfig
conn net.Conn conn net.Conn
address string address string
cancelKeepAlive context.CancelFunc
lock sync.Mutex lock sync.Mutex
} }
@ -203,11 +208,39 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) {
} }
} }
if err != nil {
return err
}
if o != nil { if o != nil {
o.Output("Connected!") o.Output("Connected!")
} }
return err ctx, cancelKeepAlive := context.WithCancel(context.TODO())
c.cancelKeepAlive = cancelKeepAlive
// Start a keepalive goroutine to help maintain the connection for
// long-running commands.
log.Printf("[DEBUG] starting ssh KeepAlives")
go func() {
t := time.NewTicker(keepAliveInterval)
defer t.Stop()
for {
select {
case <-t.C:
// there's no useful response to these, just abort when there's
// an error.
_, _, err := c.client.SendRequest("keepalive@terraform.io", true, nil)
if err != nil {
return
}
case <-ctx.Done():
return
}
}
}()
return nil
} }
// Disconnect implementation of communicator.Communicator interface // Disconnect implementation of communicator.Communicator interface
@ -215,6 +248,10 @@ func (c *Communicator) Disconnect() error {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
if c.cancelKeepAlive != nil {
c.cancelKeepAlive()
}
if c.config.sshAgent != nil { if c.config.sshAgent != nil {
if err := c.config.sshAgent.Close(); err != nil { if err := c.config.sshAgent.Close(); err != nil {
return err return err

View File

@ -179,6 +179,48 @@ func TestStart(t *testing.T) {
} }
} }
// TestKeepAlives verifies that the keepalive messages don't interfere with
// normal operation of the client.
func TestKeepAlives(t *testing.T) {
address := newMockLineServer(t, nil)
parts := strings.Split(address, ":")
r := &terraform.InstanceState{
Ephemeral: terraform.EphemeralState{
ConnInfo: map[string]string{
"type": "ssh",
"user": "user",
"password": "pass",
"host": parts[0],
"port": parts[1],
"timeout": "30s",
},
},
}
c, err := New(r)
if err != nil {
t.Fatalf("error creating communicator: %s", err)
}
if err := c.Connect(nil); err != nil {
t.Fatal(err)
}
var cmd remote.Cmd
stdout := new(bytes.Buffer)
cmd.Command = "echo foo"
cmd.Stdout = stdout
// wait a bit before executing the command, so that at least 1 keepalive is sent
time.Sleep(3 * time.Second)
err = c.Start(&cmd)
if err != nil {
t.Fatalf("error executing remote command: %s", err)
}
}
func TestLostConnection(t *testing.T) { func TestLostConnection(t *testing.T) {
address := newMockLineServer(t, nil) address := newMockLineServer(t, nil)
parts := strings.Split(address, ":") parts := strings.Split(address, ":")