communicator/ssh: bastion host support

* adds `bastion_*` fields to `connection` which add configuration for a
   bastion host
 * if `bastion_host` is set, connect to that host first, then jump
   through it to make the SSH connection to `host`
 * enables SSH Agent forwarding by default
This commit is contained in:
Paul Hinze 2015-06-22 11:34:02 -05:00
parent e4931771af
commit a7cbbbd258
4 changed files with 263 additions and 63 deletions

View File

@ -19,6 +19,7 @@ import (
"github.com/hashicorp/terraform/communicator/remote"
"github.com/hashicorp/terraform/terraform"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
const (
@ -47,9 +48,9 @@ type sshConfig struct {
// noPty, if true, will not request a pty from the remote end.
noPty bool
// sshAgentConn is a pointer to the UNIX connection for talking with the
// ssh-agent.
sshAgentConn net.Conn
// sshAgent is a struct surrounding the agent.Agent client and the net.Conn
// to the SSH Agent. It is nil if no SSH agent is configured
sshAgent *sshAgent
}
// New creates a new communicator implementation over SSH.
@ -122,6 +123,26 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) {
c.client = ssh.NewClient(sshConn, sshChan, req)
if c.config.sshAgent != nil {
log.Printf("[DEBUG] Telling SSH config to foward to agent")
if err := c.config.sshAgent.ForwardToAgent(c.client); err != nil {
return err
}
log.Printf("[DEBUG] Setting up a session to request agent forwarding")
session, err := c.newSession()
if err != nil {
return err
}
defer session.Close()
if err = agent.RequestAgentForwarding(session); err != nil {
return err
}
log.Printf("[INFO] agent forwarding enabled")
}
if o != nil {
o.Output("Connected!")
}
@ -131,8 +152,8 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) {
// Disconnect implementation of communicator.Communicator interface
func (c *Communicator) Disconnect() error {
if c.config.sshAgentConn != nil {
return c.config.sshAgentConn.Close()
if c.config.sshAgent != nil {
return c.config.sshAgent.Close()
}
return nil
@ -563,3 +584,43 @@ func ConnectFunc(network, addr string) func() (net.Conn, error) {
return c, nil
}
}
// BastionConnectFunc is a convenience method for returning a function
// that connects to a host over a bastion connection.
func BastionConnectFunc(
bProto string,
bAddr string,
bConf *ssh.ClientConfig,
proto string,
addr string) func() (net.Conn, error) {
return func() (net.Conn, error) {
log.Printf("[DEBUG] Connecting to bastion: %s", bAddr)
bastion, err := ssh.Dial(bProto, bAddr, bConf)
if err != nil {
return nil, fmt.Errorf("Error connecting to bastion: %s", err)
}
log.Printf("[DEBUG] Connecting via bastion (%s) to host: %s", bAddr, addr)
conn, err := bastion.Dial(proto, addr)
if err != nil {
bastion.Close()
return nil, err
}
// Wrap it up so we close both things properly
return &bastionConn{
Conn: conn,
Bastion: bastion,
}, nil
}
}
type bastionConn struct {
net.Conn
Bastion *ssh.Client
}
func (c *bastionConn) Close() error {
c.Conn.Close()
return c.Bastion.Close()
}

View File

@ -44,6 +44,12 @@ type connectionInfo struct {
Timeout string
ScriptPath string `mapstructure:"script_path"`
TimeoutVal time.Duration `mapstructure:"-"`
BastionUser string `mapstructure:"bastion_user"`
BastionPassword string `mapstructure:"bastion_password"`
BastionKeyFile string `mapstructure:"bastion_key_file"`
BastionHost string `mapstructure:"bastion_host"`
BastionPort int `mapstructure:"bastion_port"`
}
// parseConnectionInfo is used to convert the ConnInfo of the InstanceState into
@ -86,6 +92,22 @@ func parseConnectionInfo(s *terraform.InstanceState) (*connectionInfo, error) {
connInfo.TimeoutVal = DefaultTimeout
}
// Default all bastion config attrs to their non-bastion counterparts
if connInfo.BastionHost != "" {
if connInfo.BastionUser == "" {
connInfo.BastionUser = connInfo.User
}
if connInfo.BastionPassword == "" {
connInfo.BastionPassword = connInfo.Password
}
if connInfo.BastionKeyFile == "" {
connInfo.BastionKeyFile = connInfo.KeyFile
}
if connInfo.BastionPort == 0 {
connInfo.BastionPort = connInfo.Port
}
}
return connInfo, nil
}
@ -102,73 +124,152 @@ func safeDuration(dur string, defaultDur time.Duration) time.Duration {
// prepareSSHConfig is used to turn the *ConnectionInfo provided into a
// usable *SSHConfig for client initialization.
func prepareSSHConfig(connInfo *connectionInfo) (*sshConfig, error) {
var conn net.Conn
var err error
sshConf := &ssh.ClientConfig{
User: connInfo.User,
sshAgent, err := connectToAgent(connInfo)
if err != nil {
return nil, err
}
if connInfo.Agent {
sshAuthSock := os.Getenv("SSH_AUTH_SOCK")
if sshAuthSock == "" {
return nil, fmt.Errorf("SSH Requested but SSH_AUTH_SOCK not-specified")
}
conn, err = net.Dial("unix", sshAuthSock)
if err != nil {
return nil, fmt.Errorf("Error connecting to SSH_AUTH_SOCK: %v", err)
}
// I need to close this but, later after all connections have been made
// defer conn.Close()
signers, err := agent.NewClient(conn).Signers()
if err != nil {
return nil, fmt.Errorf("Error getting keys from ssh agent: %v", err)
}
sshConf.Auth = append(sshConf.Auth, ssh.PublicKeys(signers...))
sshConf, err := buildSSHClientConfig(sshClientConfigOpts{
user: connInfo.User,
keyFile: connInfo.KeyFile,
password: connInfo.Password,
sshAgent: sshAgent,
})
if err != nil {
return nil, err
}
if connInfo.KeyFile != "" {
fullPath, err := homedir.Expand(connInfo.KeyFile)
if err != nil {
return nil, fmt.Errorf("Failed to expand home directory: %v", err)
}
key, err := ioutil.ReadFile(fullPath)
if err != nil {
return nil, fmt.Errorf("Failed to read key file '%s': %v", connInfo.KeyFile, err)
}
// We parse the private key on our own first so that we can
// show a nicer error if the private key has a password.
block, _ := pem.Decode(key)
if block == nil {
return nil, fmt.Errorf(
"Failed to read key '%s': no key found", connInfo.KeyFile)
}
if block.Headers["Proc-Type"] == "4,ENCRYPTED" {
return nil, fmt.Errorf(
"Failed to read key '%s': password protected keys are\n"+
"not supported. Please decrypt the key prior to use.", connInfo.KeyFile)
}
signer, err := ssh.ParsePrivateKey(key)
var bastionConf *ssh.ClientConfig
if connInfo.BastionHost != "" {
bastionConf, err = buildSSHClientConfig(sshClientConfigOpts{
user: connInfo.BastionUser,
keyFile: connInfo.BastionKeyFile,
password: connInfo.BastionPassword,
sshAgent: sshAgent,
})
if err != nil {
return nil, fmt.Errorf("Failed to parse key file '%s': %v", connInfo.KeyFile, err)
return nil, err
}
sshConf.Auth = append(sshConf.Auth, ssh.PublicKeys(signer))
}
if connInfo.Password != "" {
sshConf.Auth = append(sshConf.Auth,
ssh.Password(connInfo.Password))
sshConf.Auth = append(sshConf.Auth,
ssh.KeyboardInteractive(PasswordKeyboardInteractive(connInfo.Password)))
}
host := fmt.Sprintf("%s:%d", connInfo.Host, connInfo.Port)
connectFunc := ConnectFunc("tcp", host)
if bastionConf != nil {
bastionHost := fmt.Sprintf("%s:%d", connInfo.BastionHost, connInfo.BastionPort)
connectFunc = BastionConnectFunc("tcp", bastionHost, bastionConf, "tcp", host)
}
config := &sshConfig{
config: sshConf,
connection: ConnectFunc("tcp", host),
sshAgentConn: conn,
config: sshConf,
connection: connectFunc,
sshAgent: sshAgent,
}
return config, nil
}
type sshClientConfigOpts struct {
keyFile string
password string
sshAgent *sshAgent
user string
}
func buildSSHClientConfig(opts sshClientConfigOpts) (*ssh.ClientConfig, error) {
conf := &ssh.ClientConfig{
User: opts.user,
}
if opts.sshAgent != nil {
conf.Auth = append(conf.Auth, opts.sshAgent.Auth())
}
if opts.keyFile != "" {
pubKeyAuth, err := readPublicKeyFromPath(opts.keyFile)
if err != nil {
return nil, err
}
conf.Auth = append(conf.Auth, pubKeyAuth)
}
if opts.password != "" {
conf.Auth = append(conf.Auth, ssh.Password(opts.password))
conf.Auth = append(conf.Auth, ssh.KeyboardInteractive(
PasswordKeyboardInteractive(opts.password)))
}
return conf, nil
}
func readPublicKeyFromPath(path string) (ssh.AuthMethod, error) {
fullPath, err := homedir.Expand(path)
if err != nil {
return nil, fmt.Errorf("Failed to expand home directory: %s", err)
}
key, err := ioutil.ReadFile(fullPath)
if err != nil {
return nil, fmt.Errorf("Failed to read key file %q: %s", path, err)
}
// We parse the private key on our own first so that we can
// show a nicer error if the private key has a password.
block, _ := pem.Decode(key)
if block == nil {
return nil, fmt.Errorf("Failed to read key %q: no key found", path)
}
if block.Headers["Proc-Type"] == "4,ENCRYPTED" {
return nil, fmt.Errorf(
"Failed to read key %q: password protected keys are\n"+
"not supported. Please decrypt the key prior to use.", path)
}
signer, err := ssh.ParsePrivateKey(key)
if err != nil {
return nil, fmt.Errorf("Failed to parse key file %q: %s", path, err)
}
return ssh.PublicKeys(signer), nil
}
func connectToAgent(connInfo *connectionInfo) (*sshAgent, error) {
if connInfo.Agent != true {
// No agent configured
return nil, nil
}
sshAuthSock := os.Getenv("SSH_AUTH_SOCK")
if sshAuthSock == "" {
return nil, fmt.Errorf("SSH Requested but SSH_AUTH_SOCK not-specified")
}
conn, err := net.Dial("unix", sshAuthSock)
if err != nil {
return nil, fmt.Errorf("Error connecting to SSH_AUTH_SOCK: %v", err)
}
// connection close is handled over in Communicator
return &sshAgent{
agent: agent.NewClient(conn),
conn: conn,
}, nil
}
// A tiny wrapper around an agent.Agent to expose the ability to close its
// associated connection on request.
type sshAgent struct {
agent agent.Agent
conn net.Conn
}
func (a *sshAgent) Close() error {
return a.conn.Close()
}
func (a *sshAgent) Auth() ssh.AuthMethod {
return ssh.PublicKeysCallback(a.agent.Signers)
}
func (a *sshAgent) ForwardToAgent(client *ssh.Client) error {
return agent.ForwardToAgent(client, a.agent)
}

View File

@ -17,6 +17,8 @@ func TestProvisioner_connInfo(t *testing.T) {
"host": "127.0.0.1",
"port": "22",
"timeout": "30s",
"bastion_host": "127.0.1.1",
},
},
}
@ -47,4 +49,19 @@ func TestProvisioner_connInfo(t *testing.T) {
if conf.ScriptPath != DefaultScriptPath {
t.Fatalf("bad: %v", conf)
}
if conf.BastionHost != "127.0.1.1" {
t.Fatalf("bad: %v", conf)
}
if conf.BastionPort != 22 {
t.Fatalf("bad: %v", conf)
}
if conf.BastionUser != "root" {
t.Fatalf("bad: %v", conf)
}
if conf.BastionPassword != "supersecret" {
t.Fatalf("bad: %v", conf)
}
if conf.BastionKeyFile != "/my/key/file.pem" {
t.Fatalf("bad: %v", conf)
}
}

View File

@ -80,3 +80,24 @@ provisioner "file" {
* `insecure` - Set to true to not validate the HTTPS certificate chain.
* `cacert` - The CA certificate to validate against.
<a id="bastion"></a>
## Connecting through a Bastion Host with SSH
The `ssh` connection additionally supports the following fields to facilitate a
[bastion host](https://en.wikipedia.org/wiki/Bastion_host) connection.
* `bastion_host` - Setting this enables the bastion Host connection. This host
will be connected to first, and the `host` connection will be made from there.
* `bastion_port` - The port to use connect to the bastion host. Defaults to the
value of `port`.
* `bastion_user` - The user to use to connect to the bastion host. Defaults to
the value of `user`.
* `bastion_password` - The password we should use for the bastion host.
Defaults to the value of `password`.
* `bastion_key_file` - The SSH key to use for the bastion host. Defaults to the
value of `key_file`.