Merge pull request #20437 from hashicorp/jbardin/ssh-keepalive
add ssh keepalive messages to communicator
This commit is contained in:
commit
929231a2e9
|
@ -3,6 +3,7 @@ package ssh
|
|||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -26,6 +27,9 @@ import (
|
|||
const (
|
||||
// DefaultShebang is added at the top of a SSH script file
|
||||
DefaultShebang = "#!/bin/sh\n"
|
||||
|
||||
// enable ssh keeplive probes by default
|
||||
keepAliveInterval = 2 * time.Second
|
||||
)
|
||||
|
||||
// randShared is a global random generator object that is shared.
|
||||
|
@ -37,11 +41,12 @@ var randShared *rand.Rand
|
|||
|
||||
// Communicator represents the SSH communicator
|
||||
type Communicator struct {
|
||||
connInfo *connectionInfo
|
||||
client *ssh.Client
|
||||
config *sshConfig
|
||||
conn net.Conn
|
||||
address string
|
||||
connInfo *connectionInfo
|
||||
client *ssh.Client
|
||||
config *sshConfig
|
||||
conn net.Conn
|
||||
address string
|
||||
cancelKeepAlive context.CancelFunc
|
||||
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
@ -205,11 +210,39 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) {
|
|||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if o != nil {
|
||||
o.Output("Connected!")
|
||||
}
|
||||
|
||||
return err
|
||||
ctx, cancelKeepAlive := context.WithCancel(context.TODO())
|
||||
c.cancelKeepAlive = cancelKeepAlive
|
||||
|
||||
// Start a keepalive goroutine to help maintain the connection for
|
||||
// long-running commands.
|
||||
log.Printf("[DEBUG] starting ssh KeepAlives")
|
||||
go func() {
|
||||
t := time.NewTicker(keepAliveInterval)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
// there's no useful response to these, just abort when there's
|
||||
// an error.
|
||||
_, _, err := c.client.SendRequest("keepalive@terraform.io", true, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disconnect implementation of communicator.Communicator interface
|
||||
|
@ -217,6 +250,10 @@ func (c *Communicator) Disconnect() error {
|
|||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if c.cancelKeepAlive != nil {
|
||||
c.cancelKeepAlive()
|
||||
}
|
||||
|
||||
if c.config.sshAgent != nil {
|
||||
if err := c.config.sshAgent.Close(); err != nil {
|
||||
return err
|
||||
|
|
|
@ -179,6 +179,48 @@ func TestStart(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestKeepAlives verifies that the keepalive messages don't interfere with
|
||||
// normal operation of the client.
|
||||
func TestKeepAlives(t *testing.T) {
|
||||
address := newMockLineServer(t, nil)
|
||||
parts := strings.Split(address, ":")
|
||||
|
||||
r := &terraform.InstanceState{
|
||||
Ephemeral: terraform.EphemeralState{
|
||||
ConnInfo: map[string]string{
|
||||
"type": "ssh",
|
||||
"user": "user",
|
||||
"password": "pass",
|
||||
"host": parts[0],
|
||||
"port": parts[1],
|
||||
"timeout": "30s",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
c, err := New(r)
|
||||
if err != nil {
|
||||
t.Fatalf("error creating communicator: %s", err)
|
||||
}
|
||||
|
||||
if err := c.Connect(nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var cmd remote.Cmd
|
||||
stdout := new(bytes.Buffer)
|
||||
cmd.Command = "echo foo"
|
||||
cmd.Stdout = stdout
|
||||
|
||||
// wait a bit before executing the command, so that at least 1 keepalive is sent
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
err = c.Start(&cmd)
|
||||
if err != nil {
|
||||
t.Fatalf("error executing remote command: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLostConnection(t *testing.T) {
|
||||
address := newMockLineServer(t, nil, testClientPublicKey)
|
||||
parts := strings.Split(address, ":")
|
||||
|
|
Loading…
Reference in New Issue