Merge pull request #20437 from hashicorp/jbardin/ssh-keepalive
add ssh keepalive messages to communicator
This commit is contained in:
commit
929231a2e9
|
@ -3,6 +3,7 @@ package ssh
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -26,6 +27,9 @@ import (
|
||||||
const (
|
const (
|
||||||
// DefaultShebang is added at the top of a SSH script file
|
// DefaultShebang is added at the top of a SSH script file
|
||||||
DefaultShebang = "#!/bin/sh\n"
|
DefaultShebang = "#!/bin/sh\n"
|
||||||
|
|
||||||
|
// enable ssh keeplive probes by default
|
||||||
|
keepAliveInterval = 2 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// randShared is a global random generator object that is shared.
|
// randShared is a global random generator object that is shared.
|
||||||
|
@ -42,6 +46,7 @@ type Communicator struct {
|
||||||
config *sshConfig
|
config *sshConfig
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
address string
|
address string
|
||||||
|
cancelKeepAlive context.CancelFunc
|
||||||
|
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
}
|
}
|
||||||
|
@ -205,11 +210,39 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if o != nil {
|
if o != nil {
|
||||||
o.Output("Connected!")
|
o.Output("Connected!")
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
ctx, cancelKeepAlive := context.WithCancel(context.TODO())
|
||||||
|
c.cancelKeepAlive = cancelKeepAlive
|
||||||
|
|
||||||
|
// Start a keepalive goroutine to help maintain the connection for
|
||||||
|
// long-running commands.
|
||||||
|
log.Printf("[DEBUG] starting ssh KeepAlives")
|
||||||
|
go func() {
|
||||||
|
t := time.NewTicker(keepAliveInterval)
|
||||||
|
defer t.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-t.C:
|
||||||
|
// there's no useful response to these, just abort when there's
|
||||||
|
// an error.
|
||||||
|
_, _, err := c.client.SendRequest("keepalive@terraform.io", true, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disconnect implementation of communicator.Communicator interface
|
// Disconnect implementation of communicator.Communicator interface
|
||||||
|
@ -217,6 +250,10 @@ func (c *Communicator) Disconnect() error {
|
||||||
c.lock.Lock()
|
c.lock.Lock()
|
||||||
defer c.lock.Unlock()
|
defer c.lock.Unlock()
|
||||||
|
|
||||||
|
if c.cancelKeepAlive != nil {
|
||||||
|
c.cancelKeepAlive()
|
||||||
|
}
|
||||||
|
|
||||||
if c.config.sshAgent != nil {
|
if c.config.sshAgent != nil {
|
||||||
if err := c.config.sshAgent.Close(); err != nil {
|
if err := c.config.sshAgent.Close(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -179,6 +179,48 @@ func TestStart(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestKeepAlives verifies that the keepalive messages don't interfere with
|
||||||
|
// normal operation of the client.
|
||||||
|
func TestKeepAlives(t *testing.T) {
|
||||||
|
address := newMockLineServer(t, nil)
|
||||||
|
parts := strings.Split(address, ":")
|
||||||
|
|
||||||
|
r := &terraform.InstanceState{
|
||||||
|
Ephemeral: terraform.EphemeralState{
|
||||||
|
ConnInfo: map[string]string{
|
||||||
|
"type": "ssh",
|
||||||
|
"user": "user",
|
||||||
|
"password": "pass",
|
||||||
|
"host": parts[0],
|
||||||
|
"port": parts[1],
|
||||||
|
"timeout": "30s",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := New(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error creating communicator: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.Connect(nil); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd remote.Cmd
|
||||||
|
stdout := new(bytes.Buffer)
|
||||||
|
cmd.Command = "echo foo"
|
||||||
|
cmd.Stdout = stdout
|
||||||
|
|
||||||
|
// wait a bit before executing the command, so that at least 1 keepalive is sent
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
|
err = c.Start(&cmd)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error executing remote command: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestLostConnection(t *testing.T) {
|
func TestLostConnection(t *testing.T) {
|
||||||
address := newMockLineServer(t, nil, testClientPublicKey)
|
address := newMockLineServer(t, nil, testClientPublicKey)
|
||||||
parts := strings.Split(address, ":")
|
parts := strings.Split(address, ":")
|
||||||
|
|
Loading…
Reference in New Issue