diff --git a/communicator/ssh/communicator.go b/communicator/ssh/communicator.go index 402331da1..672dc78af 100644 --- a/communicator/ssh/communicator.go +++ b/communicator/ssh/communicator.go @@ -19,6 +19,7 @@ import ( "github.com/hashicorp/terraform/communicator/remote" "github.com/hashicorp/terraform/terraform" "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" ) const ( @@ -47,9 +48,9 @@ type sshConfig struct { // noPty, if true, will not request a pty from the remote end. noPty bool - // sshAgentConn is a pointer to the UNIX connection for talking with the - // ssh-agent. - sshAgentConn net.Conn + // sshAgent is a struct surrounding the agent.Agent client and the net.Conn + // to the SSH Agent. It is nil if no SSH agent is configured + sshAgent *sshAgent } // New creates a new communicator implementation over SSH. @@ -122,6 +123,26 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) { c.client = ssh.NewClient(sshConn, sshChan, req) + if c.config.sshAgent != nil { + log.Printf("[DEBUG] Telling SSH config to foward to agent") + if err := c.config.sshAgent.ForwardToAgent(c.client); err != nil { + return err + } + + log.Printf("[DEBUG] Setting up a session to request agent forwarding") + session, err := c.newSession() + if err != nil { + return err + } + defer session.Close() + + if err = agent.RequestAgentForwarding(session); err != nil { + return err + } + + log.Printf("[INFO] agent forwarding enabled") + } + if o != nil { o.Output("Connected!") } @@ -131,8 +152,8 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) { // Disconnect implementation of communicator.Communicator interface func (c *Communicator) Disconnect() error { - if c.config.sshAgentConn != nil { - return c.config.sshAgentConn.Close() + if c.config.sshAgent != nil { + return c.config.sshAgent.Close() } return nil @@ -563,3 +584,43 @@ func ConnectFunc(network, addr string) func() (net.Conn, error) { return c, nil } } + +// BastionConnectFunc is a convenience method for returning a function +// that connects to a host over a bastion connection. +func BastionConnectFunc( + bProto string, + bAddr string, + bConf *ssh.ClientConfig, + proto string, + addr string) func() (net.Conn, error) { + return func() (net.Conn, error) { + log.Printf("[DEBUG] Connecting to bastion: %s", bAddr) + bastion, err := ssh.Dial(bProto, bAddr, bConf) + if err != nil { + return nil, fmt.Errorf("Error connecting to bastion: %s", err) + } + + log.Printf("[DEBUG] Connecting via bastion (%s) to host: %s", bAddr, addr) + conn, err := bastion.Dial(proto, addr) + if err != nil { + bastion.Close() + return nil, err + } + + // Wrap it up so we close both things properly + return &bastionConn{ + Conn: conn, + Bastion: bastion, + }, nil + } +} + +type bastionConn struct { + net.Conn + Bastion *ssh.Client +} + +func (c *bastionConn) Close() error { + c.Conn.Close() + return c.Bastion.Close() +} diff --git a/communicator/ssh/provisioner.go b/communicator/ssh/provisioner.go index 6facfcf52..cda7fc2f3 100644 --- a/communicator/ssh/provisioner.go +++ b/communicator/ssh/provisioner.go @@ -44,6 +44,12 @@ type connectionInfo struct { Timeout string ScriptPath string `mapstructure:"script_path"` TimeoutVal time.Duration `mapstructure:"-"` + + BastionUser string `mapstructure:"bastion_user"` + BastionPassword string `mapstructure:"bastion_password"` + BastionKeyFile string `mapstructure:"bastion_key_file"` + BastionHost string `mapstructure:"bastion_host"` + BastionPort int `mapstructure:"bastion_port"` } // parseConnectionInfo is used to convert the ConnInfo of the InstanceState into @@ -86,6 +92,22 @@ func parseConnectionInfo(s *terraform.InstanceState) (*connectionInfo, error) { connInfo.TimeoutVal = DefaultTimeout } + // Default all bastion config attrs to their non-bastion counterparts + if connInfo.BastionHost != "" { + if connInfo.BastionUser == "" { + connInfo.BastionUser = connInfo.User + } + if connInfo.BastionPassword == "" { + connInfo.BastionPassword = connInfo.Password + } + if connInfo.BastionKeyFile == "" { + connInfo.BastionKeyFile = connInfo.KeyFile + } + if connInfo.BastionPort == 0 { + connInfo.BastionPort = connInfo.Port + } + } + return connInfo, nil } @@ -102,73 +124,152 @@ func safeDuration(dur string, defaultDur time.Duration) time.Duration { // prepareSSHConfig is used to turn the *ConnectionInfo provided into a // usable *SSHConfig for client initialization. func prepareSSHConfig(connInfo *connectionInfo) (*sshConfig, error) { - var conn net.Conn - var err error - - sshConf := &ssh.ClientConfig{ - User: connInfo.User, + sshAgent, err := connectToAgent(connInfo) + if err != nil { + return nil, err } - if connInfo.Agent { - sshAuthSock := os.Getenv("SSH_AUTH_SOCK") - if sshAuthSock == "" { - return nil, fmt.Errorf("SSH Requested but SSH_AUTH_SOCK not-specified") - } - - conn, err = net.Dial("unix", sshAuthSock) - if err != nil { - return nil, fmt.Errorf("Error connecting to SSH_AUTH_SOCK: %v", err) - } - // I need to close this but, later after all connections have been made - // defer conn.Close() - signers, err := agent.NewClient(conn).Signers() - if err != nil { - return nil, fmt.Errorf("Error getting keys from ssh agent: %v", err) - } - - sshConf.Auth = append(sshConf.Auth, ssh.PublicKeys(signers...)) + sshConf, err := buildSSHClientConfig(sshClientConfigOpts{ + user: connInfo.User, + keyFile: connInfo.KeyFile, + password: connInfo.Password, + sshAgent: sshAgent, + }) + if err != nil { + return nil, err } - if connInfo.KeyFile != "" { - fullPath, err := homedir.Expand(connInfo.KeyFile) - if err != nil { - return nil, fmt.Errorf("Failed to expand home directory: %v", err) - } - key, err := ioutil.ReadFile(fullPath) - if err != nil { - return nil, fmt.Errorf("Failed to read key file '%s': %v", connInfo.KeyFile, err) - } - // We parse the private key on our own first so that we can - // show a nicer error if the private key has a password. - block, _ := pem.Decode(key) - if block == nil { - return nil, fmt.Errorf( - "Failed to read key '%s': no key found", connInfo.KeyFile) - } - if block.Headers["Proc-Type"] == "4,ENCRYPTED" { - return nil, fmt.Errorf( - "Failed to read key '%s': password protected keys are\n"+ - "not supported. Please decrypt the key prior to use.", connInfo.KeyFile) - } - - signer, err := ssh.ParsePrivateKey(key) + var bastionConf *ssh.ClientConfig + if connInfo.BastionHost != "" { + bastionConf, err = buildSSHClientConfig(sshClientConfigOpts{ + user: connInfo.BastionUser, + keyFile: connInfo.BastionKeyFile, + password: connInfo.BastionPassword, + sshAgent: sshAgent, + }) if err != nil { - return nil, fmt.Errorf("Failed to parse key file '%s': %v", connInfo.KeyFile, err) + return nil, err } - - sshConf.Auth = append(sshConf.Auth, ssh.PublicKeys(signer)) - } - if connInfo.Password != "" { - sshConf.Auth = append(sshConf.Auth, - ssh.Password(connInfo.Password)) - sshConf.Auth = append(sshConf.Auth, - ssh.KeyboardInteractive(PasswordKeyboardInteractive(connInfo.Password))) } + host := fmt.Sprintf("%s:%d", connInfo.Host, connInfo.Port) + connectFunc := ConnectFunc("tcp", host) + + if bastionConf != nil { + bastionHost := fmt.Sprintf("%s:%d", connInfo.BastionHost, connInfo.BastionPort) + connectFunc = BastionConnectFunc("tcp", bastionHost, bastionConf, "tcp", host) + } + config := &sshConfig{ - config: sshConf, - connection: ConnectFunc("tcp", host), - sshAgentConn: conn, + config: sshConf, + connection: connectFunc, + sshAgent: sshAgent, } return config, nil } + +type sshClientConfigOpts struct { + keyFile string + password string + sshAgent *sshAgent + user string +} + +func buildSSHClientConfig(opts sshClientConfigOpts) (*ssh.ClientConfig, error) { + conf := &ssh.ClientConfig{ + User: opts.user, + } + + if opts.sshAgent != nil { + conf.Auth = append(conf.Auth, opts.sshAgent.Auth()) + } + + if opts.keyFile != "" { + pubKeyAuth, err := readPublicKeyFromPath(opts.keyFile) + if err != nil { + return nil, err + } + conf.Auth = append(conf.Auth, pubKeyAuth) + } + + if opts.password != "" { + conf.Auth = append(conf.Auth, ssh.Password(opts.password)) + conf.Auth = append(conf.Auth, ssh.KeyboardInteractive( + PasswordKeyboardInteractive(opts.password))) + } + + return conf, nil +} + +func readPublicKeyFromPath(path string) (ssh.AuthMethod, error) { + fullPath, err := homedir.Expand(path) + if err != nil { + return nil, fmt.Errorf("Failed to expand home directory: %s", err) + } + key, err := ioutil.ReadFile(fullPath) + if err != nil { + return nil, fmt.Errorf("Failed to read key file %q: %s", path, err) + } + + // We parse the private key on our own first so that we can + // show a nicer error if the private key has a password. + block, _ := pem.Decode(key) + if block == nil { + return nil, fmt.Errorf("Failed to read key %q: no key found", path) + } + if block.Headers["Proc-Type"] == "4,ENCRYPTED" { + return nil, fmt.Errorf( + "Failed to read key %q: password protected keys are\n"+ + "not supported. Please decrypt the key prior to use.", path) + } + + signer, err := ssh.ParsePrivateKey(key) + if err != nil { + return nil, fmt.Errorf("Failed to parse key file %q: %s", path, err) + } + + return ssh.PublicKeys(signer), nil +} + +func connectToAgent(connInfo *connectionInfo) (*sshAgent, error) { + if connInfo.Agent != true { + // No agent configured + return nil, nil + } + + sshAuthSock := os.Getenv("SSH_AUTH_SOCK") + + if sshAuthSock == "" { + return nil, fmt.Errorf("SSH Requested but SSH_AUTH_SOCK not-specified") + } + + conn, err := net.Dial("unix", sshAuthSock) + if err != nil { + return nil, fmt.Errorf("Error connecting to SSH_AUTH_SOCK: %v", err) + } + + // connection close is handled over in Communicator + return &sshAgent{ + agent: agent.NewClient(conn), + conn: conn, + }, nil +} + +// A tiny wrapper around an agent.Agent to expose the ability to close its +// associated connection on request. +type sshAgent struct { + agent agent.Agent + conn net.Conn +} + +func (a *sshAgent) Close() error { + return a.conn.Close() +} + +func (a *sshAgent) Auth() ssh.AuthMethod { + return ssh.PublicKeysCallback(a.agent.Signers) +} + +func (a *sshAgent) ForwardToAgent(client *ssh.Client) error { + return agent.ForwardToAgent(client, a.agent) +} diff --git a/communicator/ssh/provisioner_test.go b/communicator/ssh/provisioner_test.go index 33c2b7b7b..fc6b686fb 100644 --- a/communicator/ssh/provisioner_test.go +++ b/communicator/ssh/provisioner_test.go @@ -17,6 +17,8 @@ func TestProvisioner_connInfo(t *testing.T) { "host": "127.0.0.1", "port": "22", "timeout": "30s", + + "bastion_host": "127.0.1.1", }, }, } @@ -47,4 +49,19 @@ func TestProvisioner_connInfo(t *testing.T) { if conf.ScriptPath != DefaultScriptPath { t.Fatalf("bad: %v", conf) } + if conf.BastionHost != "127.0.1.1" { + t.Fatalf("bad: %v", conf) + } + if conf.BastionPort != 22 { + t.Fatalf("bad: %v", conf) + } + if conf.BastionUser != "root" { + t.Fatalf("bad: %v", conf) + } + if conf.BastionPassword != "supersecret" { + t.Fatalf("bad: %v", conf) + } + if conf.BastionKeyFile != "/my/key/file.pem" { + t.Fatalf("bad: %v", conf) + } } diff --git a/website/source/docs/provisioners/connection.html.markdown b/website/source/docs/provisioners/connection.html.markdown index 57524486c..6efd73e83 100644 --- a/website/source/docs/provisioners/connection.html.markdown +++ b/website/source/docs/provisioners/connection.html.markdown @@ -80,3 +80,24 @@ provisioner "file" { * `insecure` - Set to true to not validate the HTTPS certificate chain. * `cacert` - The CA certificate to validate against. + + +## Connecting through a Bastion Host with SSH + +The `ssh` connection additionally supports the following fields to facilitate a +[bastion host](https://en.wikipedia.org/wiki/Bastion_host) connection. + +* `bastion_host` - Setting this enables the bastion Host connection. This host + will be connected to first, and the `host` connection will be made from there. + +* `bastion_port` - The port to use connect to the bastion host. Defaults to the + value of `port`. + +* `bastion_user` - The user to use to connect to the bastion host. Defaults to + the value of `user`. + +* `bastion_password` - The password we should use for the bastion host. + Defaults to the value of `password`. + +* `bastion_key_file` - The SSH key to use for the bastion host. Defaults to the + value of `key_file`.