communicator/ssh: bastion host support
* adds `bastion_*` fields to `connection` which add configuration for a bastion host * if `bastion_host` is set, connect to that host first, then jump through it to make the SSH connection to `host` * enables SSH Agent forwarding by default
This commit is contained in:
parent
e4931771af
commit
a7cbbbd258
|
@ -19,6 +19,7 @@ import (
|
|||
"github.com/hashicorp/terraform/communicator/remote"
|
||||
"github.com/hashicorp/terraform/terraform"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -47,9 +48,9 @@ type sshConfig struct {
|
|||
// noPty, if true, will not request a pty from the remote end.
|
||||
noPty bool
|
||||
|
||||
// sshAgentConn is a pointer to the UNIX connection for talking with the
|
||||
// ssh-agent.
|
||||
sshAgentConn net.Conn
|
||||
// sshAgent is a struct surrounding the agent.Agent client and the net.Conn
|
||||
// to the SSH Agent. It is nil if no SSH agent is configured
|
||||
sshAgent *sshAgent
|
||||
}
|
||||
|
||||
// New creates a new communicator implementation over SSH.
|
||||
|
@ -122,6 +123,26 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) {
|
|||
|
||||
c.client = ssh.NewClient(sshConn, sshChan, req)
|
||||
|
||||
if c.config.sshAgent != nil {
|
||||
log.Printf("[DEBUG] Telling SSH config to foward to agent")
|
||||
if err := c.config.sshAgent.ForwardToAgent(c.client); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG] Setting up a session to request agent forwarding")
|
||||
session, err := c.newSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
if err = agent.RequestAgentForwarding(session); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("[INFO] agent forwarding enabled")
|
||||
}
|
||||
|
||||
if o != nil {
|
||||
o.Output("Connected!")
|
||||
}
|
||||
|
@ -131,8 +152,8 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) {
|
|||
|
||||
// Disconnect implementation of communicator.Communicator interface
|
||||
func (c *Communicator) Disconnect() error {
|
||||
if c.config.sshAgentConn != nil {
|
||||
return c.config.sshAgentConn.Close()
|
||||
if c.config.sshAgent != nil {
|
||||
return c.config.sshAgent.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -563,3 +584,43 @@ func ConnectFunc(network, addr string) func() (net.Conn, error) {
|
|||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
// BastionConnectFunc is a convenience method for returning a function
|
||||
// that connects to a host over a bastion connection.
|
||||
func BastionConnectFunc(
|
||||
bProto string,
|
||||
bAddr string,
|
||||
bConf *ssh.ClientConfig,
|
||||
proto string,
|
||||
addr string) func() (net.Conn, error) {
|
||||
return func() (net.Conn, error) {
|
||||
log.Printf("[DEBUG] Connecting to bastion: %s", bAddr)
|
||||
bastion, err := ssh.Dial(bProto, bAddr, bConf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error connecting to bastion: %s", err)
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG] Connecting via bastion (%s) to host: %s", bAddr, addr)
|
||||
conn, err := bastion.Dial(proto, addr)
|
||||
if err != nil {
|
||||
bastion.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Wrap it up so we close both things properly
|
||||
return &bastionConn{
|
||||
Conn: conn,
|
||||
Bastion: bastion,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
type bastionConn struct {
|
||||
net.Conn
|
||||
Bastion *ssh.Client
|
||||
}
|
||||
|
||||
func (c *bastionConn) Close() error {
|
||||
c.Conn.Close()
|
||||
return c.Bastion.Close()
|
||||
}
|
||||
|
|
|
@ -44,6 +44,12 @@ type connectionInfo struct {
|
|||
Timeout string
|
||||
ScriptPath string `mapstructure:"script_path"`
|
||||
TimeoutVal time.Duration `mapstructure:"-"`
|
||||
|
||||
BastionUser string `mapstructure:"bastion_user"`
|
||||
BastionPassword string `mapstructure:"bastion_password"`
|
||||
BastionKeyFile string `mapstructure:"bastion_key_file"`
|
||||
BastionHost string `mapstructure:"bastion_host"`
|
||||
BastionPort int `mapstructure:"bastion_port"`
|
||||
}
|
||||
|
||||
// parseConnectionInfo is used to convert the ConnInfo of the InstanceState into
|
||||
|
@ -86,6 +92,22 @@ func parseConnectionInfo(s *terraform.InstanceState) (*connectionInfo, error) {
|
|||
connInfo.TimeoutVal = DefaultTimeout
|
||||
}
|
||||
|
||||
// Default all bastion config attrs to their non-bastion counterparts
|
||||
if connInfo.BastionHost != "" {
|
||||
if connInfo.BastionUser == "" {
|
||||
connInfo.BastionUser = connInfo.User
|
||||
}
|
||||
if connInfo.BastionPassword == "" {
|
||||
connInfo.BastionPassword = connInfo.Password
|
||||
}
|
||||
if connInfo.BastionKeyFile == "" {
|
||||
connInfo.BastionKeyFile = connInfo.KeyFile
|
||||
}
|
||||
if connInfo.BastionPort == 0 {
|
||||
connInfo.BastionPort = connInfo.Port
|
||||
}
|
||||
}
|
||||
|
||||
return connInfo, nil
|
||||
}
|
||||
|
||||
|
@ -102,73 +124,152 @@ func safeDuration(dur string, defaultDur time.Duration) time.Duration {
|
|||
// prepareSSHConfig is used to turn the *ConnectionInfo provided into a
|
||||
// usable *SSHConfig for client initialization.
|
||||
func prepareSSHConfig(connInfo *connectionInfo) (*sshConfig, error) {
|
||||
var conn net.Conn
|
||||
var err error
|
||||
|
||||
sshConf := &ssh.ClientConfig{
|
||||
User: connInfo.User,
|
||||
sshAgent, err := connectToAgent(connInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if connInfo.Agent {
|
||||
sshAuthSock := os.Getenv("SSH_AUTH_SOCK")
|
||||
|
||||
if sshAuthSock == "" {
|
||||
return nil, fmt.Errorf("SSH Requested but SSH_AUTH_SOCK not-specified")
|
||||
}
|
||||
|
||||
conn, err = net.Dial("unix", sshAuthSock)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error connecting to SSH_AUTH_SOCK: %v", err)
|
||||
}
|
||||
// I need to close this but, later after all connections have been made
|
||||
// defer conn.Close()
|
||||
signers, err := agent.NewClient(conn).Signers()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error getting keys from ssh agent: %v", err)
|
||||
}
|
||||
|
||||
sshConf.Auth = append(sshConf.Auth, ssh.PublicKeys(signers...))
|
||||
sshConf, err := buildSSHClientConfig(sshClientConfigOpts{
|
||||
user: connInfo.User,
|
||||
keyFile: connInfo.KeyFile,
|
||||
password: connInfo.Password,
|
||||
sshAgent: sshAgent,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if connInfo.KeyFile != "" {
|
||||
fullPath, err := homedir.Expand(connInfo.KeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to expand home directory: %v", err)
|
||||
}
|
||||
key, err := ioutil.ReadFile(fullPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to read key file '%s': %v", connInfo.KeyFile, err)
|
||||
}
|
||||
|
||||
// We parse the private key on our own first so that we can
|
||||
// show a nicer error if the private key has a password.
|
||||
block, _ := pem.Decode(key)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf(
|
||||
"Failed to read key '%s': no key found", connInfo.KeyFile)
|
||||
}
|
||||
if block.Headers["Proc-Type"] == "4,ENCRYPTED" {
|
||||
return nil, fmt.Errorf(
|
||||
"Failed to read key '%s': password protected keys are\n"+
|
||||
"not supported. Please decrypt the key prior to use.", connInfo.KeyFile)
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKey(key)
|
||||
var bastionConf *ssh.ClientConfig
|
||||
if connInfo.BastionHost != "" {
|
||||
bastionConf, err = buildSSHClientConfig(sshClientConfigOpts{
|
||||
user: connInfo.BastionUser,
|
||||
keyFile: connInfo.BastionKeyFile,
|
||||
password: connInfo.BastionPassword,
|
||||
sshAgent: sshAgent,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to parse key file '%s': %v", connInfo.KeyFile, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sshConf.Auth = append(sshConf.Auth, ssh.PublicKeys(signer))
|
||||
}
|
||||
if connInfo.Password != "" {
|
||||
sshConf.Auth = append(sshConf.Auth,
|
||||
ssh.Password(connInfo.Password))
|
||||
sshConf.Auth = append(sshConf.Auth,
|
||||
ssh.KeyboardInteractive(PasswordKeyboardInteractive(connInfo.Password)))
|
||||
}
|
||||
|
||||
host := fmt.Sprintf("%s:%d", connInfo.Host, connInfo.Port)
|
||||
connectFunc := ConnectFunc("tcp", host)
|
||||
|
||||
if bastionConf != nil {
|
||||
bastionHost := fmt.Sprintf("%s:%d", connInfo.BastionHost, connInfo.BastionPort)
|
||||
connectFunc = BastionConnectFunc("tcp", bastionHost, bastionConf, "tcp", host)
|
||||
}
|
||||
|
||||
config := &sshConfig{
|
||||
config: sshConf,
|
||||
connection: ConnectFunc("tcp", host),
|
||||
sshAgentConn: conn,
|
||||
config: sshConf,
|
||||
connection: connectFunc,
|
||||
sshAgent: sshAgent,
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
type sshClientConfigOpts struct {
|
||||
keyFile string
|
||||
password string
|
||||
sshAgent *sshAgent
|
||||
user string
|
||||
}
|
||||
|
||||
func buildSSHClientConfig(opts sshClientConfigOpts) (*ssh.ClientConfig, error) {
|
||||
conf := &ssh.ClientConfig{
|
||||
User: opts.user,
|
||||
}
|
||||
|
||||
if opts.sshAgent != nil {
|
||||
conf.Auth = append(conf.Auth, opts.sshAgent.Auth())
|
||||
}
|
||||
|
||||
if opts.keyFile != "" {
|
||||
pubKeyAuth, err := readPublicKeyFromPath(opts.keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conf.Auth = append(conf.Auth, pubKeyAuth)
|
||||
}
|
||||
|
||||
if opts.password != "" {
|
||||
conf.Auth = append(conf.Auth, ssh.Password(opts.password))
|
||||
conf.Auth = append(conf.Auth, ssh.KeyboardInteractive(
|
||||
PasswordKeyboardInteractive(opts.password)))
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func readPublicKeyFromPath(path string) (ssh.AuthMethod, error) {
|
||||
fullPath, err := homedir.Expand(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to expand home directory: %s", err)
|
||||
}
|
||||
key, err := ioutil.ReadFile(fullPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to read key file %q: %s", path, err)
|
||||
}
|
||||
|
||||
// We parse the private key on our own first so that we can
|
||||
// show a nicer error if the private key has a password.
|
||||
block, _ := pem.Decode(key)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("Failed to read key %q: no key found", path)
|
||||
}
|
||||
if block.Headers["Proc-Type"] == "4,ENCRYPTED" {
|
||||
return nil, fmt.Errorf(
|
||||
"Failed to read key %q: password protected keys are\n"+
|
||||
"not supported. Please decrypt the key prior to use.", path)
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKey(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to parse key file %q: %s", path, err)
|
||||
}
|
||||
|
||||
return ssh.PublicKeys(signer), nil
|
||||
}
|
||||
|
||||
func connectToAgent(connInfo *connectionInfo) (*sshAgent, error) {
|
||||
if connInfo.Agent != true {
|
||||
// No agent configured
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
sshAuthSock := os.Getenv("SSH_AUTH_SOCK")
|
||||
|
||||
if sshAuthSock == "" {
|
||||
return nil, fmt.Errorf("SSH Requested but SSH_AUTH_SOCK not-specified")
|
||||
}
|
||||
|
||||
conn, err := net.Dial("unix", sshAuthSock)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error connecting to SSH_AUTH_SOCK: %v", err)
|
||||
}
|
||||
|
||||
// connection close is handled over in Communicator
|
||||
return &sshAgent{
|
||||
agent: agent.NewClient(conn),
|
||||
conn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// A tiny wrapper around an agent.Agent to expose the ability to close its
|
||||
// associated connection on request.
|
||||
type sshAgent struct {
|
||||
agent agent.Agent
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (a *sshAgent) Close() error {
|
||||
return a.conn.Close()
|
||||
}
|
||||
|
||||
func (a *sshAgent) Auth() ssh.AuthMethod {
|
||||
return ssh.PublicKeysCallback(a.agent.Signers)
|
||||
}
|
||||
|
||||
func (a *sshAgent) ForwardToAgent(client *ssh.Client) error {
|
||||
return agent.ForwardToAgent(client, a.agent)
|
||||
}
|
||||
|
|
|
@ -17,6 +17,8 @@ func TestProvisioner_connInfo(t *testing.T) {
|
|||
"host": "127.0.0.1",
|
||||
"port": "22",
|
||||
"timeout": "30s",
|
||||
|
||||
"bastion_host": "127.0.1.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -47,4 +49,19 @@ func TestProvisioner_connInfo(t *testing.T) {
|
|||
if conf.ScriptPath != DefaultScriptPath {
|
||||
t.Fatalf("bad: %v", conf)
|
||||
}
|
||||
if conf.BastionHost != "127.0.1.1" {
|
||||
t.Fatalf("bad: %v", conf)
|
||||
}
|
||||
if conf.BastionPort != 22 {
|
||||
t.Fatalf("bad: %v", conf)
|
||||
}
|
||||
if conf.BastionUser != "root" {
|
||||
t.Fatalf("bad: %v", conf)
|
||||
}
|
||||
if conf.BastionPassword != "supersecret" {
|
||||
t.Fatalf("bad: %v", conf)
|
||||
}
|
||||
if conf.BastionKeyFile != "/my/key/file.pem" {
|
||||
t.Fatalf("bad: %v", conf)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -80,3 +80,24 @@ provisioner "file" {
|
|||
* `insecure` - Set to true to not validate the HTTPS certificate chain.
|
||||
|
||||
* `cacert` - The CA certificate to validate against.
|
||||
|
||||
<a id="bastion"></a>
|
||||
## Connecting through a Bastion Host with SSH
|
||||
|
||||
The `ssh` connection additionally supports the following fields to facilitate a
|
||||
[bastion host](https://en.wikipedia.org/wiki/Bastion_host) connection.
|
||||
|
||||
* `bastion_host` - Setting this enables the bastion Host connection. This host
|
||||
will be connected to first, and the `host` connection will be made from there.
|
||||
|
||||
* `bastion_port` - The port to use connect to the bastion host. Defaults to the
|
||||
value of `port`.
|
||||
|
||||
* `bastion_user` - The user to use to connect to the bastion host. Defaults to
|
||||
the value of `user`.
|
||||
|
||||
* `bastion_password` - The password we should use for the bastion host.
|
||||
Defaults to the value of `password`.
|
||||
|
||||
* `bastion_key_file` - The SSH key to use for the bastion host. Defaults to the
|
||||
value of `key_file`.
|
||||
|
|
Loading…
Reference in New Issue