diff --git a/builtin/bins/provisioner-remote-exec/main.go b/builtin/bins/provisioner-remote-exec/main.go new file mode 100644 index 000000000..b7b2c53b4 --- /dev/null +++ b/builtin/bins/provisioner-remote-exec/main.go @@ -0,0 +1,10 @@ +package main + +import ( + "github.com/hashicorp/terraform/builtin/provisioners/remote-exec" + "github.com/hashicorp/terraform/plugin" +) + +func main() { + plugin.Serve(new(remoteexec.ResourceProvisioner)) +} diff --git a/builtin/bins/provisioner-remote-exec/main_test.go b/builtin/bins/provisioner-remote-exec/main_test.go new file mode 100644 index 000000000..06ab7d0f9 --- /dev/null +++ b/builtin/bins/provisioner-remote-exec/main_test.go @@ -0,0 +1 @@ +package main diff --git a/builtin/providers/aws/resource_aws_instance.go b/builtin/providers/aws/resource_aws_instance.go index df00804a2..a7fdb7edc 100644 --- a/builtin/providers/aws/resource_aws_instance.go +++ b/builtin/providers/aws/resource_aws_instance.go @@ -88,6 +88,10 @@ func resource_aws_instance_create( instance = instanceRaw.(*ec2.Instance) + // Initialize the connection info + rs.ConnInfo["type"] = "ssh" + rs.ConnInfo["host"] = instance.PublicIpAddress + // Set our attributes rs, err = resource_aws_instance_update_state(rs, instance) if err != nil { diff --git a/builtin/provisioners/remote-exec/resource_provisioner.go b/builtin/provisioners/remote-exec/resource_provisioner.go new file mode 100644 index 000000000..7fdf37cb3 --- /dev/null +++ b/builtin/provisioners/remote-exec/resource_provisioner.go @@ -0,0 +1,359 @@ +package remoteexec + +import ( + "bufio" + "bytes" + "fmt" + "io" + "io/ioutil" + "log" + "os" + "strings" + "time" + + "code.google.com/p/go.crypto/ssh" + helper "github.com/hashicorp/terraform/helper/ssh" + "github.com/hashicorp/terraform/terraform" + "github.com/mitchellh/mapstructure" +) + +const ( + // DefaultUser is used if there is no default user given + DefaultUser = "root" + + // DefaultPort is used if there is no port given + DefaultPort = 22 + + // DefaultScriptPath is used as the path to copy the file to + // for remote execution if not provided otherwise. + DefaultScriptPath = "/tmp/script.sh" + + // DefaultTimeout is used if there is no timeout given + DefaultTimeout = 5 * time.Minute + + // DefaultShebang is added at the top of the script file + DefaultShebang = "#!/bin/sh" +) + +type ResourceProvisioner struct{} + +// SSHConfig is decoded from the ConnInfo of the resource. These +// are the only keys we look at. If a KeyFile is given, that is used +// instead of a password. +type SSHConfig struct { + User string + Password string + KeyFile string `mapstructure:"key_file"` + Host string + Port int + Timeout string + ScriptPath string `mapstructure:"script_path"` + TimeoutVal time.Duration `mapstructure:"-"` +} + +func (p *ResourceProvisioner) Apply(s *terraform.ResourceState, + c *terraform.ResourceConfig) (*terraform.ResourceState, error) { + // Ensure the connection type is SSH + if err := p.verifySSH(s); err != nil { + return s, err + } + + // Get the SSH configuration + conf, err := p.sshConfig(s) + if err != nil { + return s, err + } + + // Collect the scripts + scripts, err := p.collectScripts(c) + if err != nil { + return s, err + } + for _, s := range scripts { + defer s.Close() + } + + // Copy and execute each script + if err := p.runScripts(conf, scripts); err != nil { + return s, err + } + return s, nil +} + +func (p *ResourceProvisioner) Validate(c *terraform.ResourceConfig) (ws []string, es []error) { + num := 0 + for name := range c.Raw { + switch name { + case "scripts": + fallthrough + case "script": + fallthrough + case "inline": + num++ + default: + es = append(es, fmt.Errorf("Unknown configuration '%s'", name)) + } + } + if num != 1 { + es = append(es, fmt.Errorf("Must provide one of 'scripts', 'script' or 'inline' to remote-exec")) + } + return +} + +// verifySSH is used to verify the ConnInfo is usable by remote-exec +func (p *ResourceProvisioner) verifySSH(s *terraform.ResourceState) error { + connType := s.ConnInfo["type"] + switch connType { + case "": + case "ssh": + default: + return fmt.Errorf("Connection type '%s' not supported", connType) + } + return nil +} + +// sshConfig is used to convert the ConnInfo of the ResourceState into +// a SSHConfig struct +func (p *ResourceProvisioner) sshConfig(s *terraform.ResourceState) (*SSHConfig, error) { + sshConf := &SSHConfig{} + decConf := &mapstructure.DecoderConfig{ + WeaklyTypedInput: true, + Result: sshConf, + } + dec, err := mapstructure.NewDecoder(decConf) + if err != nil { + return nil, err + } + if err := dec.Decode(s.ConnInfo); err != nil { + return nil, err + } + if sshConf.User == "" { + sshConf.User = DefaultUser + } + if sshConf.Port == 0 { + sshConf.Port = DefaultPort + } + if sshConf.ScriptPath == "" { + sshConf.ScriptPath = DefaultScriptPath + } + if sshConf.Timeout != "" { + sshConf.TimeoutVal = safeDuration(sshConf.Timeout, DefaultTimeout) + } else { + sshConf.TimeoutVal = DefaultTimeout + } + return sshConf, nil +} + +// generateScript takes the configuration and creates a script to be executed +// from the inline configs +func (p *ResourceProvisioner) generateScript(c *terraform.ResourceConfig) (string, error) { + lines := []string{DefaultShebang} + command, ok := c.Config["inline"] + if ok { + switch cmd := command.(type) { + case string: + lines = append(lines, cmd) + case []string: + lines = append(lines, cmd...) + case []interface{}: + for _, l := range cmd { + lStr, ok := l.(string) + if ok { + lines = append(lines, lStr) + } else { + return "", fmt.Errorf("Unsupported 'inline' type! Must be string, or list of strings.") + } + } + default: + return "", fmt.Errorf("Unsupported 'inline' type! Must be string, or list of strings.") + } + } + lines = append(lines, "") + return strings.Join(lines, "\n"), nil +} + +// collectScripts is used to collect all the scripts we need +// to execute in preperation for copying them. +func (p *ResourceProvisioner) collectScripts(c *terraform.ResourceConfig) ([]io.ReadCloser, error) { + // Check if inline + _, ok := c.Config["inline"] + if ok { + script, err := p.generateScript(c) + if err != nil { + return nil, err + } + rc := ioutil.NopCloser(bytes.NewReader([]byte(script))) + return []io.ReadCloser{rc}, nil + } + + // Collect scripts + var scripts []string + s, ok := c.Config["script"] + if ok { + sStr, ok := s.(string) + if !ok { + return nil, fmt.Errorf("Unsupported 'script' type! Must be a string.") + } + scripts = append(scripts, sStr) + } + + sl, ok := c.Config["scripts"] + if ok { + switch slt := sl.(type) { + case []string: + scripts = append(scripts, slt...) + case []interface{}: + for _, l := range slt { + lStr, ok := l.(string) + if ok { + scripts = append(scripts, lStr) + } else { + return nil, fmt.Errorf("Unsupported 'scripts' type! Must be list of strings.") + } + } + default: + return nil, fmt.Errorf("Unsupported 'scripts' type! Must be list of strings.") + } + } + + // Open all the scripts + var fhs []io.ReadCloser + for _, s := range scripts { + fh, err := os.Open(s) + if err != nil { + for _, fh := range fhs { + fh.Close() + } + return nil, fmt.Errorf("Failed to open script '%s': %v", s, err) + } + fhs = append(fhs, fh) + } + + // Done, return the file handles + return fhs, nil +} + +// runScripts is used to copy and execute a set of scripts +func (p *ResourceProvisioner) runScripts(conf *SSHConfig, scripts []io.ReadCloser) error { + sshConf := &ssh.ClientConfig{ + User: conf.User, + } + if conf.KeyFile != "" { + key, err := ioutil.ReadFile(conf.KeyFile) + if err != nil { + return fmt.Errorf("Failed to read key file '%s': %v", conf.KeyFile, err) + } + signer, err := ssh.ParsePrivateKey(key) + if err != nil { + return fmt.Errorf("Failed to parse key file '%s': %v", conf.KeyFile, err) + } + sshConf.Auth = append(sshConf.Auth, ssh.PublicKeys(signer)) + } + if conf.Password != "" { + sshConf.Auth = append(sshConf.Auth, + ssh.Password(conf.Password)) + sshConf.Auth = append(sshConf.Auth, + ssh.KeyboardInteractive(helper.PasswordKeyboardInteractive(conf.Password))) + } + host := fmt.Sprintf("%s:%d", conf.Host, conf.Port) + config := &helper.Config{ + SSHConfig: sshConf, + Connection: helper.ConnectFunc("tcp", host), + } + + // Wait and retry until we establish the SSH connection + var comm *helper.SSHCommunicator + err := retryFunc(conf.TimeoutVal, func() error { + var err error + comm, err = helper.New(host, config) + return err + }) + if err != nil { + return err + } + + for _, script := range scripts { + var cmd *helper.RemoteCmd + err := retryFunc(conf.TimeoutVal, func() error { + if err := comm.Upload(conf.ScriptPath, script); err != nil { + return fmt.Errorf("Failed to upload script: %v", err) + } + cmd = &helper.RemoteCmd{ + Command: fmt.Sprintf("chmod 0777 %s", conf.ScriptPath), + } + if err := comm.Start(cmd); err != nil { + return fmt.Errorf( + "Error chmodding script file to 0777 in remote "+ + "machine: %s", err) + } + cmd.Wait() + + rPipe1, wPipe1 := io.Pipe() + rPipe2, wPipe2 := io.Pipe() + go streamLogs(rPipe1, "stdout") + go streamLogs(rPipe2, "stderr") + + cmd = &helper.RemoteCmd{ + Command: conf.ScriptPath, + Stdout: wPipe1, + Stderr: wPipe2, + } + if err := comm.Start(cmd); err != nil { + return fmt.Errorf("Error starting script: %v", err) + } + return nil + }) + if err != nil { + return err + } + + cmd.Wait() + if cmd.ExitStatus != 0 { + return fmt.Errorf("Script exited with non-zero exit status: %d", cmd.ExitStatus) + } + } + + return nil +} + +// retryFunc is used to retry a function for a given duration +func retryFunc(timeout time.Duration, f func() error) error { + finish := time.After(timeout) + for { + err := f() + if err == nil { + return nil + } + log.Printf("Retryable error: %v", err) + + select { + case <-finish: + return err + case <-time.After(3 * time.Second): + } + } +} + +// safeDuration returns either the parsed duration or a default value +func safeDuration(dur string, defaultDur time.Duration) time.Duration { + d, err := time.ParseDuration(dur) + if err != nil { + log.Printf("Invalid duration '%s' for remote-exec, using default", dur) + return defaultDur + } + return d +} + +// streamLogs is used to stream lines from stdout/stderr +// of a remote command to log output for users. +func streamLogs(r io.ReadCloser, name string) { + defer r.Close() + bufR := bufio.NewReader(r) + for { + line, err := bufR.ReadString('\n') + if err != nil { + return + } + log.Printf("remote-exec: %s: %s", name, line) + } +} diff --git a/builtin/provisioners/remote-exec/resource_provisioner_test.go b/builtin/provisioners/remote-exec/resource_provisioner_test.go new file mode 100644 index 000000000..5d6dca377 --- /dev/null +++ b/builtin/provisioners/remote-exec/resource_provisioner_test.go @@ -0,0 +1,224 @@ +package remoteexec + +import ( + "bytes" + "io" + "testing" + + "github.com/hashicorp/terraform/config" + "github.com/hashicorp/terraform/terraform" +) + +func TestResourceProvisioner_impl(t *testing.T) { + var _ terraform.ResourceProvisioner = new(ResourceProvisioner) +} + +func TestResourceProvider_Validate_good(t *testing.T) { + c := testConfig(t, map[string]interface{}{ + "inline": "echo foo", + }) + p := new(ResourceProvisioner) + warn, errs := p.Validate(c) + if len(warn) > 0 { + t.Fatalf("Warnings: %v", warn) + } + if len(errs) > 0 { + t.Fatalf("Errors: %v", errs) + } +} + +func TestResourceProvider_Validate_bad(t *testing.T) { + c := testConfig(t, map[string]interface{}{ + "invalid": "nope", + }) + p := new(ResourceProvisioner) + warn, errs := p.Validate(c) + if len(warn) > 0 { + t.Fatalf("Warnings: %v", warn) + } + if len(errs) == 0 { + t.Fatalf("Should have errors") + } +} + +func TestResourceProvider_verifySSH(t *testing.T) { + p := new(ResourceProvisioner) + r := &terraform.ResourceState{ + ConnInfo: map[string]string{ + "type": "telnet", + }, + } + if err := p.verifySSH(r); err == nil { + t.Fatalf("expected error with telnet") + } + r.ConnInfo["type"] = "ssh" + if err := p.verifySSH(r); err != nil { + t.Fatalf("err: %v", err) + } +} + +func TestResourceProvider_sshConfig(t *testing.T) { + p := new(ResourceProvisioner) + r := &terraform.ResourceState{ + ConnInfo: map[string]string{ + "type": "ssh", + "user": "root", + "password": "supersecret", + "key_file": "/my/key/file.pem", + "host": "127.0.0.1", + "port": "22", + "timeout": "30s", + }, + } + + conf, err := p.sshConfig(r) + if err != nil { + t.Fatalf("err: %v", err) + } + + if conf.User != "root" { + t.Fatalf("bad: %v", conf) + } + if conf.Password != "supersecret" { + t.Fatalf("bad: %v", conf) + } + if conf.KeyFile != "/my/key/file.pem" { + t.Fatalf("bad: %v", conf) + } + if conf.Host != "127.0.0.1" { + t.Fatalf("bad: %v", conf) + } + if conf.Port != 22 { + t.Fatalf("bad: %v", conf) + } + if conf.Timeout != "30s" { + t.Fatalf("bad: %v", conf) + } + if conf.ScriptPath != DefaultScriptPath { + t.Fatalf("bad: %v", conf) + } +} + +func TestResourceProvider_generateScript(t *testing.T) { + p := new(ResourceProvisioner) + conf := testConfig(t, map[string]interface{}{ + "inline": []string{ + "cd /tmp", + "wget http://foobar", + "exit 0", + }, + }) + out, err := p.generateScript(conf) + if err != nil { + t.Fatalf("err: %v", err) + } + + if out != expectedScriptOut { + t.Fatalf("bad: %v", out) + } +} + +var expectedScriptOut = `#!/bin/sh +cd /tmp +wget http://foobar +exit 0 +` + +func TestResourceProvider_CollectScripts_inline(t *testing.T) { + p := new(ResourceProvisioner) + conf := testConfig(t, map[string]interface{}{ + "inline": []string{ + "cd /tmp", + "wget http://foobar", + "exit 0", + }, + }) + + scripts, err := p.collectScripts(conf) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(scripts) != 1 { + t.Fatalf("bad: %v", scripts) + } + + var out bytes.Buffer + _, err = io.Copy(&out, scripts[0]) + if err != nil { + t.Fatalf("err: %v", err) + } + + if string(out.Bytes()) != expectedScriptOut { + t.Fatalf("bad: %v", out.Bytes()) + } +} + +func TestResourceProvider_CollectScripts_script(t *testing.T) { + p := new(ResourceProvisioner) + conf := testConfig(t, map[string]interface{}{ + "script": "test-fixtures/script1.sh", + }) + + scripts, err := p.collectScripts(conf) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(scripts) != 1 { + t.Fatalf("bad: %v", scripts) + } + + var out bytes.Buffer + _, err = io.Copy(&out, scripts[0]) + if err != nil { + t.Fatalf("err: %v", err) + } + + if string(out.Bytes()) != expectedScriptOut { + t.Fatalf("bad: %v", out.Bytes()) + } +} + +func TestResourceProvider_CollectScripts_scripts(t *testing.T) { + p := new(ResourceProvisioner) + conf := testConfig(t, map[string]interface{}{ + "scripts": []interface{}{ + "test-fixtures/script1.sh", + "test-fixtures/script1.sh", + "test-fixtures/script1.sh", + }, + }) + + scripts, err := p.collectScripts(conf) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(scripts) != 3 { + t.Fatalf("bad: %v", scripts) + } + + for idx := range scripts { + var out bytes.Buffer + _, err = io.Copy(&out, scripts[idx]) + if err != nil { + t.Fatalf("err: %v", err) + } + + if string(out.Bytes()) != expectedScriptOut { + t.Fatalf("bad: %v", out.Bytes()) + } + } +} + +func testConfig( + t *testing.T, + c map[string]interface{}) *terraform.ResourceConfig { + r, err := config.NewRawConfig(c) + if err != nil { + t.Fatalf("bad: %s", err) + } + + return terraform.NewResourceConfig(r) +} diff --git a/builtin/provisioners/remote-exec/test-fixtures/script1.sh b/builtin/provisioners/remote-exec/test-fixtures/script1.sh new file mode 100755 index 000000000..cd22f3e63 --- /dev/null +++ b/builtin/provisioners/remote-exec/test-fixtures/script1.sh @@ -0,0 +1,4 @@ +#!/bin/sh +cd /tmp +wget http://foobar +exit 0 diff --git a/command/apply_test.go b/command/apply_test.go index 48a1a1c70..05a5c299e 100644 --- a/command/apply_test.go +++ b/command/apply_test.go @@ -383,8 +383,9 @@ func TestApply_state(t *testing.T) { originalState := &terraform.State{ Resources: map[string]*terraform.ResourceState{ "test_instance.foo": &terraform.ResourceState{ - ID: "bar", - Type: "test_instance", + ID: "bar", + Type: "test_instance", + ConnInfo: make(map[string]string), }, }, } diff --git a/config.go b/config.go index b616414d7..77bf01a07 100644 --- a/config.go +++ b/config.go @@ -36,7 +36,8 @@ func init() { "aws": "terraform-provider-aws", } BuiltinConfig.Provisioners = map[string]string{ - "local-exec": "terraform-provisioner-local-exec", + "local-exec": "terraform-provisioner-local-exec", + "remote-exec": "terraform-provisioner-remote-exec", } } diff --git a/config/config.go b/config/config.go index 1dd04227c..4e785b37c 100644 --- a/config/config.go +++ b/config/config.go @@ -42,6 +42,7 @@ type Resource struct { type Provisioner struct { Type string RawConfig *RawConfig + ConnInfo *RawConfig } // Variable is a variable defined within the configuration. diff --git a/config/loader_libucl.go b/config/loader_libucl.go index b95190e4e..78baaf164 100644 --- a/config/loader_libucl.go +++ b/config/loader_libucl.go @@ -316,6 +316,10 @@ func loadResourcesLibucl(o *libucl.Object) ([]*Resource, error) { // that is treated specially. delete(config, "provisioner") + // Delete the "connection" section since we handle that + // seperately + delete(config, "connection") + rawConfig, err := NewRawConfig(config) if err != nil { return nil, fmt.Errorf( @@ -339,11 +343,26 @@ func loadResourcesLibucl(o *libucl.Object) ([]*Resource, error) { } } + // If we have connection info, then parse those out + var connInfo map[string]interface{} + if conn := r.Get("connection"); conn != nil { + var err error + connInfo, err = loadConnInfoLibucl(conn) + conn.Close() + if err != nil { + return nil, fmt.Errorf( + "Error reading connection info for %s[%s]: %s", + t.Key(), + r.Key(), + err) + } + } + // If we have provisioners, then parse those out var provisioners []*Provisioner if po := r.Get("provisioner"); po != nil { var err error - provisioners, err = loadProvisionersLibucl(po) + provisioners, err = loadProvisionersLibucl(po, connInfo) po.Close() if err != nil { return nil, fmt.Errorf( @@ -367,7 +386,7 @@ func loadResourcesLibucl(o *libucl.Object) ([]*Resource, error) { return result, nil } -func loadProvisionersLibucl(o *libucl.Object) ([]*Provisioner, error) { +func loadProvisionersLibucl(o *libucl.Object, connInfo map[string]interface{}) ([]*Provisioner, error) { pos := make([]*libucl.Object, 0, int(o.Len())) // Accumulate all the actual provisioner configuration objects. We @@ -409,16 +428,58 @@ func loadProvisionersLibucl(o *libucl.Object) ([]*Provisioner, error) { return nil, err } + // Delete the "connection" section, handle seperately + delete(config, "connection") + rawConfig, err := NewRawConfig(config) if err != nil { return nil, err } + // Check if we have a provisioner-level connection + // block that overrides the resource-level + var subConnInfo map[string]interface{} + if conn := po.Get("connection"); conn != nil { + var err error + subConnInfo, err = loadConnInfoLibucl(conn) + conn.Close() + if err != nil { + return nil, err + } + } + + // Inherit from the resource connInfo any keys + // that are not explicitly overriden. + if connInfo != nil && subConnInfo != nil { + for k, v := range connInfo { + if _, ok := subConnInfo[k]; !ok { + subConnInfo[k] = v + } + } + } else if subConnInfo == nil { + subConnInfo = connInfo + } + + // Parse the connInfo + connRaw, err := NewRawConfig(subConnInfo) + if err != nil { + return nil, err + } + result = append(result, &Provisioner{ Type: po.Key(), RawConfig: rawConfig, + ConnInfo: connRaw, }) } return result, nil } + +func loadConnInfoLibucl(o *libucl.Object) (map[string]interface{}, error) { + var config map[string]interface{} + if err := o.Decode(&config); err != nil { + return nil, err + } + return config, nil +} diff --git a/config/loader_test.go b/config/loader_test.go index 4181dcb34..378eb9518 100644 --- a/config/loader_test.go +++ b/config/loader_test.go @@ -185,6 +185,44 @@ func TestLoad_provisioners(t *testing.T) { } } +func TestLoad_connections(t *testing.T) { + c, err := Load(filepath.Join(fixtureDir, "connection.tf")) + if err != nil { + t.Fatalf("err: %s", err) + } + + if c == nil { + t.Fatal("config should not be nil") + } + + actual := resourcesStr(c.Resources) + if actual != strings.TrimSpace(connectionResourcesStr) { + t.Fatalf("bad:\n%s", actual) + } + + // Check for the connection info + r := c.Resources[0] + if r.Name != "web" && r.Type != "aws_instance" { + t.Fatalf("Bad: %#v", r) + } + + p1 := r.Provisioners[0] + if p1.ConnInfo == nil || len(p1.ConnInfo.Raw) != 2 { + t.Fatalf("Bad: %#v", p1.ConnInfo) + } + if p1.ConnInfo.Raw["user"] != "nobody" { + t.Fatalf("Bad: %#v", p1.ConnInfo) + } + + p2 := r.Provisioners[1] + if p2.ConnInfo == nil || len(p2.ConnInfo.Raw) != 2 { + t.Fatalf("Bad: %#v", p2.ConnInfo) + } + if p2.ConnInfo.Raw["user"] != "root" { + t.Fatalf("Bad: %#v", p2.ConnInfo) + } +} + // This helper turns a provider configs field into a deterministic // string value for comparison in tests. func providerConfigsStr(pcs map[string]*ProviderConfig) string { @@ -448,6 +486,20 @@ aws_instance[web] (x1) user: var.foo ` +const connectionResourcesStr = ` +aws_instance[web] (x1) + ami + security_groups + provisioners + shell + path + shell + path + vars + resource: aws_security_group.firewall.foo + user: var.foo +` + const variablesVariablesStr = ` bar <> diff --git a/config/test-fixtures/connection.tf b/config/test-fixtures/connection.tf new file mode 100644 index 000000000..7f808f2cb --- /dev/null +++ b/config/test-fixtures/connection.tf @@ -0,0 +1,23 @@ +resource "aws_instance" "web" { + ami = "${var.foo}" + security_groups = [ + "foo", + "${aws_security_group.firewall.foo}" + ] + + connection { + type = "ssh" + user = "root" + } + + provisioner "shell" { + path = "foo" + connection { + user = "nobody" + } + } + + provisioner "shell" { + path = "bar" + } +} diff --git a/helper/ssh/communicator.go b/helper/ssh/communicator.go new file mode 100644 index 000000000..4aa865817 --- /dev/null +++ b/helper/ssh/communicator.go @@ -0,0 +1,543 @@ +package ssh + +import ( + "bufio" + "bytes" + "code.google.com/p/go.crypto/ssh" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "os" + "path/filepath" + "sync" + "time" +) + +// RemoteCmd represents a remote command being prepared or run. +type RemoteCmd struct { + // Command is the command to run remotely. This is executed as if + // it were a shell command, so you are expected to do any shell escaping + // necessary. + Command string + + // Stdin specifies the process's standard input. If Stdin is + // nil, the process reads from an empty bytes.Buffer. + Stdin io.Reader + + // Stdout and Stderr represent the process's standard output and + // error. + // + // If either is nil, it will be set to ioutil.Discard. + Stdout io.Writer + Stderr io.Writer + + // This will be set to true when the remote command has exited. It + // shouldn't be set manually by the user, but there is no harm in + // doing so. + Exited bool + + // Once Exited is true, this will contain the exit code of the process. + ExitStatus int + + // Internal fields + exitCh chan struct{} + + // This thing is a mutex, lock when making modifications concurrently + sync.Mutex +} + +// SetExited is a helper for setting that this process is exited. This +// should be called by communicators who are running a remote command in +// order to set that the command is done. +func (r *RemoteCmd) SetExited(status int) { + r.Lock() + defer r.Unlock() + + if r.exitCh == nil { + r.exitCh = make(chan struct{}) + } + + r.Exited = true + r.ExitStatus = status + close(r.exitCh) +} + +// Wait waits for the remote command to complete. +func (r *RemoteCmd) Wait() { + // Make sure our condition variable is initialized. + r.Lock() + if r.exitCh == nil { + r.exitCh = make(chan struct{}) + } + r.Unlock() + + <-r.exitCh +} + +type SSHCommunicator struct { + client *ssh.Client + config *Config + conn net.Conn + address string +} + +// Config is the structure used to configure the SSH communicator. +type Config struct { + // The configuration of the Go SSH connection + SSHConfig *ssh.ClientConfig + + // Connection returns a new connection. The current connection + // in use will be closed as part of the Close method, or in the + // case an error occurs. + Connection func() (net.Conn, error) + + // NoPty, if true, will not request a pty from the remote end. + NoPty bool +} + +// Creates a new packer.Communicator implementation over SSH. This takes +// an already existing TCP connection and SSH configuration. +func New(address string, config *Config) (result *SSHCommunicator, err error) { + // Establish an initial connection and connect + result = &SSHCommunicator{ + config: config, + address: address, + } + + if err = result.reconnect(); err != nil { + result = nil + return + } + + return +} + +func (c *SSHCommunicator) Start(cmd *RemoteCmd) (err error) { + session, err := c.newSession() + if err != nil { + return + } + + // Setup our session + session.Stdin = cmd.Stdin + session.Stdout = cmd.Stdout + session.Stderr = cmd.Stderr + + if !c.config.NoPty { + // Request a PTY + termModes := ssh.TerminalModes{ + ssh.ECHO: 0, // do not echo + ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud + ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud + } + + if err = session.RequestPty("xterm", 80, 40, termModes); err != nil { + return + } + } + + log.Printf("starting remote command: %s", cmd.Command) + err = session.Start(cmd.Command + "\n") + if err != nil { + return + } + + // A channel to keep track of our done state + doneCh := make(chan struct{}) + sessionLock := new(sync.Mutex) + timedOut := false + + // Start a goroutine to wait for the session to end and set the + // exit boolean and status. + go func() { + defer session.Close() + + err := session.Wait() + exitStatus := 0 + if err != nil { + exitErr, ok := err.(*ssh.ExitError) + if ok { + exitStatus = exitErr.ExitStatus() + } + } + + sessionLock.Lock() + defer sessionLock.Unlock() + + if timedOut { + // We timed out, so set the exit status to -1 + exitStatus = -1 + } + + log.Printf("remote command exited with '%d': %s", exitStatus, cmd.Command) + cmd.SetExited(exitStatus) + close(doneCh) + }() + + return +} + +func (c *SSHCommunicator) Upload(path string, input io.Reader) error { + // The target directory and file for talking the SCP protocol + target_dir := filepath.Dir(path) + target_file := filepath.Base(path) + + // On windows, filepath.Dir uses backslash seperators (ie. "\tmp"). + // This does not work when the target host is unix. Switch to forward slash + // which works for unix and windows + target_dir = filepath.ToSlash(target_dir) + + scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { + return scpUploadFile(target_file, input, w, stdoutR) + } + + return c.scpSession("scp -vt "+target_dir, scpFunc) +} + +func (c *SSHCommunicator) UploadDir(dst string, src string, excl []string) error { + log.Printf("Upload dir '%s' to '%s'", src, dst) + scpFunc := func(w io.Writer, r *bufio.Reader) error { + uploadEntries := func() error { + f, err := os.Open(src) + if err != nil { + return err + } + defer f.Close() + + entries, err := f.Readdir(-1) + if err != nil { + return err + } + + return scpUploadDir(src, entries, w, r) + } + + if src[len(src)-1] != '/' { + log.Printf("No trailing slash, creating the source directory name") + return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries) + } else { + // Trailing slash, so only upload the contents + return uploadEntries() + } + } + + return c.scpSession("scp -rvt "+dst, scpFunc) +} + +func (c *SSHCommunicator) Download(string, io.Writer) error { + panic("not implemented yet") +} + +func (c *SSHCommunicator) newSession() (session *ssh.Session, err error) { + log.Println("opening new ssh session") + if c.client == nil { + err = errors.New("client not available") + } else { + session, err = c.client.NewSession() + } + + if err != nil { + log.Printf("ssh session open error: '%s', attempting reconnect", err) + if err := c.reconnect(); err != nil { + return nil, err + } + + return c.client.NewSession() + } + + return session, nil +} + +func (c *SSHCommunicator) reconnect() (err error) { + if c.conn != nil { + c.conn.Close() + } + + // Set the conn and client to nil since we'll recreate it + c.conn = nil + c.client = nil + + log.Printf("reconnecting to TCP connection for SSH") + c.conn, err = c.config.Connection() + if err != nil { + // Explicitly set this to the REAL nil. Connection() can return + // a nil implementation of net.Conn which will make the + // "if c.conn == nil" check fail above. Read here for more information + // on this psychotic language feature: + // + // http://golang.org/doc/faq#nil_error + c.conn = nil + + log.Printf("reconnection error: %s", err) + return + } + + log.Printf("handshaking with SSH") + sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig) + if err != nil { + log.Printf("handshake error: %s", err) + } + if sshConn != nil { + c.client = ssh.NewClient(sshConn, sshChan, req) + } + + return +} + +func (c *SSHCommunicator) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { + session, err := c.newSession() + if err != nil { + return err + } + defer session.Close() + + // Get a pipe to stdin so that we can send data down + stdinW, err := session.StdinPipe() + if err != nil { + return err + } + + // We only want to close once, so we nil w after we close it, + // and only close in the defer if it hasn't been closed already. + defer func() { + if stdinW != nil { + stdinW.Close() + } + }() + + // Get a pipe to stdout so that we can get responses back + stdoutPipe, err := session.StdoutPipe() + if err != nil { + return err + } + stdoutR := bufio.NewReader(stdoutPipe) + + // Set stderr to a bytes buffer + stderr := new(bytes.Buffer) + session.Stderr = stderr + + // Start the sink mode on the other side + // TODO(mitchellh): There are probably issues with shell escaping the path + log.Println("Starting remote scp process: ", scpCommand) + if err := session.Start(scpCommand); err != nil { + return err + } + + // Call our callback that executes in the context of SCP. We ignore + // EOF errors if they occur because it usually means that SCP prematurely + // ended on the other side. + log.Println("Started SCP session, beginning transfers...") + if err := f(stdinW, stdoutR); err != nil && err != io.EOF { + return err + } + + // Close the stdin, which sends an EOF, and then set w to nil so that + // our defer func doesn't close it again since that is unsafe with + // the Go SSH package. + log.Println("SCP session complete, closing stdin pipe.") + stdinW.Close() + stdinW = nil + + // Wait for the SCP connection to close, meaning it has consumed all + // our data and has completed. Or has errored. + log.Println("Waiting for SSH session to complete.") + err = session.Wait() + if err != nil { + if exitErr, ok := err.(*ssh.ExitError); ok { + // Otherwise, we have an ExitErorr, meaning we can just read + // the exit status + log.Printf("non-zero exit status: %d", exitErr.ExitStatus()) + + // If we exited with status 127, it means SCP isn't available. + // Return a more descriptive error for that. + if exitErr.ExitStatus() == 127 { + return errors.New( + "SCP failed to start. This usually means that SCP is not\n" + + "properly installed on the remote system.") + } + } + + return err + } + + log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String()) + return nil +} + +// checkSCPStatus checks that a prior command sent to SCP completed +// successfully. If it did not complete successfully, an error will +// be returned. +func checkSCPStatus(r *bufio.Reader) error { + code, err := r.ReadByte() + if err != nil { + return err + } + + if code != 0 { + // Treat any non-zero (really 1 and 2) as fatal errors + message, _, err := r.ReadLine() + if err != nil { + return fmt.Errorf("Error reading error message: %s", err) + } + + return errors.New(string(message)) + } + + return nil +} + +func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader) error { + // Create a temporary file where we can copy the contents of the src + // so that we can determine the length, since SCP is length-prefixed. + tf, err := ioutil.TempFile("", "packer-upload") + if err != nil { + return fmt.Errorf("Error creating temporary file for upload: %s", err) + } + defer os.Remove(tf.Name()) + defer tf.Close() + + log.Println("Copying input data into temporary file so we can read the length") + if _, err := io.Copy(tf, src); err != nil { + return err + } + + // Sync the file so that the contents are definitely on disk, then + // read the length of it. + if err := tf.Sync(); err != nil { + return fmt.Errorf("Error creating temporary file for upload: %s", err) + } + + // Seek the file to the beginning so we can re-read all of it + if _, err := tf.Seek(0, 0); err != nil { + return fmt.Errorf("Error creating temporary file for upload: %s", err) + } + + fi, err := tf.Stat() + if err != nil { + return fmt.Errorf("Error creating temporary file for upload: %s", err) + } + + // Start the protocol + log.Println("Beginning file upload...") + fmt.Fprintln(w, "C0644", fi.Size(), dst) + if err := checkSCPStatus(r); err != nil { + return err + } + + if _, err := io.Copy(w, tf); err != nil { + return err + } + + fmt.Fprint(w, "\x00") + if err := checkSCPStatus(r); err != nil { + return err + } + + return nil +} + +func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error) error { + log.Printf("SCP: starting directory upload: %s", name) + fmt.Fprintln(w, "D0755 0", name) + err := checkSCPStatus(r) + if err != nil { + return err + } + + if err := f(); err != nil { + return err + } + + fmt.Fprintln(w, "E") + if err != nil { + return err + } + + return nil +} + +func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error { + for _, fi := range fs { + realPath := filepath.Join(root, fi.Name()) + + // Track if this is actually a symlink to a directory. If it is + // a symlink to a file we don't do any special behavior because uploading + // a file just works. If it is a directory, we need to know so we + // treat it as such. + isSymlinkToDir := false + if fi.Mode()&os.ModeSymlink == os.ModeSymlink { + symPath, err := filepath.EvalSymlinks(realPath) + if err != nil { + return err + } + + symFi, err := os.Lstat(symPath) + if err != nil { + return err + } + + isSymlinkToDir = symFi.IsDir() + } + + if !fi.IsDir() && !isSymlinkToDir { + // It is a regular file (or symlink to a file), just upload it + f, err := os.Open(realPath) + if err != nil { + return err + } + + err = func() error { + defer f.Close() + return scpUploadFile(fi.Name(), f, w, r) + }() + + if err != nil { + return err + } + + continue + } + + // It is a directory, recursively upload + err := scpUploadDirProtocol(fi.Name(), w, r, func() error { + f, err := os.Open(realPath) + if err != nil { + return err + } + defer f.Close() + + entries, err := f.Readdir(-1) + if err != nil { + return err + } + + return scpUploadDir(realPath, entries, w, r) + }) + if err != nil { + return err + } + } + + return nil +} + +// ConnectFunc is a convenience method for returning a function +// that just uses net.Dial to communicate with the remote end that +// is suitable for use with the SSH communicator configuration. +func ConnectFunc(network, addr string) func() (net.Conn, error) { + return func() (net.Conn, error) { + c, err := net.DialTimeout(network, addr, 15*time.Second) + if err != nil { + return nil, err + } + + if tcpConn, ok := c.(*net.TCPConn); ok { + tcpConn.SetKeepAlive(true) + } + + return c, nil + } +} diff --git a/helper/ssh/communicator_test.go b/helper/ssh/communicator_test.go new file mode 100644 index 000000000..53e2f538d --- /dev/null +++ b/helper/ssh/communicator_test.go @@ -0,0 +1,157 @@ +// +build !race + +package ssh + +import ( + "bytes" + "code.google.com/p/go.crypto/ssh" + "fmt" + "net" + "testing" +) + +// private key for mock server +const testServerPrivateKey = `-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEA19lGVsTqIT5iiNYRgnoY1CwkbETW5cq+Rzk5v/kTlf31XpSU +70HVWkbTERECjaYdXM2gGcbb+sxpq6GtXf1M3kVomycqhxwhPv4Cr6Xp4WT/jkFx +9z+FFzpeodGJWjOH6L2H5uX1Cvr9EDdQp9t9/J32/qBFntY8GwoUI/y/1MSTmMiF +tupdMODN064vd3gyMKTwrlQ8tZM6aYuyOPsutLlUY7M5x5FwMDYvnPDSeyT/Iw0z +s3B+NCyqeeMd2T7YzQFnRATj0M7rM5LoSs7DVqVriOEABssFyLj31PboaoLhOKgc +qoM9khkNzr7FHVvi+DhYM2jD0DwvqZLN6NmnLwIDAQABAoIBAQCGVj+kuSFOV1lT ++IclQYA6bM6uY5mroqcSBNegVxCNhWU03BxlW//BE9tA/+kq53vWylMeN9mpGZea +riEMIh25KFGWXqXlOOioH8bkMsqA8S7sBmc7jljyv+0toQ9vCCtJ+sueNPhxQQxH +D2YvUjfzBQ04I9+wn30BByDJ1QA/FoPsunxIOUCcRBE/7jxuLYcpR+JvEF68yYIh +atXRld4W4in7T65YDR8jK1Uj9XAcNeDYNpT/M6oFLx1aPIlkG86aCWRO19S1jLPT +b1ZAKHHxPMCVkSYW0RqvIgLXQOR62D0Zne6/2wtzJkk5UCjkSQ2z7ZzJpMkWgDgN +ifCULFPBAoGBAPoMZ5q1w+zB+knXUD33n1J+niN6TZHJulpf2w5zsW+m2K6Zn62M +MXndXlVAHtk6p02q9kxHdgov34Uo8VpuNjbS1+abGFTI8NZgFo+bsDxJdItemwC4 +KJ7L1iz39hRN/ZylMRLz5uTYRGddCkeIHhiG2h7zohH/MaYzUacXEEy3AoGBANz8 +e/msleB+iXC0cXKwds26N4hyMdAFE5qAqJXvV3S2W8JZnmU+sS7vPAWMYPlERPk1 +D8Q2eXqdPIkAWBhrx4RxD7rNc5qFNcQWEhCIxC9fccluH1y5g2M+4jpMX2CT8Uv+ +3z+NoJ5uDTXZTnLCfoZzgZ4nCZVZ+6iU5U1+YXFJAoGBANLPpIV920n/nJmmquMj +orI1R/QXR9Cy56cMC65agezlGOfTYxk5Cfl5Ve+/2IJCfgzwJyjWUsFx7RviEeGw +64o7JoUom1HX+5xxdHPsyZ96OoTJ5RqtKKoApnhRMamau0fWydH1yeOEJd+TRHhc +XStGfhz8QNa1dVFvENczja1vAoGABGWhsd4VPVpHMc7lUvrf4kgKQtTC2PjA4xoc +QJ96hf/642sVE76jl+N6tkGMzGjnVm4P2j+bOy1VvwQavKGoXqJBRd5Apppv727g +/SM7hBXKFc/zH80xKBBgP/i1DR7kdjakCoeu4ngeGywvu2jTS6mQsqzkK+yWbUxJ +I7mYBsECgYB/KNXlTEpXtz/kwWCHFSYA8U74l7zZbVD8ul0e56JDK+lLcJ0tJffk +gqnBycHj6AhEycjda75cs+0zybZvN4x65KZHOGW/O/7OAWEcZP5TPb3zf9ned3Hl +NsZoFj52ponUM6+99A2CmezFCN16c4mbA//luWF+k3VVqR6BpkrhKw== +-----END RSA PRIVATE KEY-----` + +var serverConfig = &ssh.ServerConfig{ + PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + if c.User() == "user" && string(pass) == "pass" { + return nil, nil + } + return nil, fmt.Errorf("password rejected for %q", c.User()) + }, +} + +func init() { + // Parse and set the private key of the server, required to accept connections + signer, err := ssh.ParsePrivateKey([]byte(testServerPrivateKey)) + if err != nil { + panic("unable to parse private key: " + err.Error()) + } + serverConfig.AddHostKey(signer) +} + +func newMockLineServer(t *testing.T) string { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Unable to listen for connection: %s", err) + } + + go func() { + defer l.Close() + c, err := l.Accept() + if err != nil { + t.Errorf("Unable to accept incoming connection: %s", err) + } + defer c.Close() + conn, chans, _, err := ssh.NewServerConn(c, serverConfig) + if err != nil { + t.Logf("Handshaking error: %v", err) + } + t.Log("Accepted SSH connection") + for newChannel := range chans { + channel, _, err := newChannel.Accept() + if err != nil { + t.Errorf("Unable to accept channel.") + } + t.Log("Accepted channel") + + go func() { + defer channel.Close() + conn.OpenChannel(newChannel.ChannelType(), nil) + }() + } + conn.Close() + }() + + return l.Addr().String() +} + +func TestNew_Invalid(t *testing.T) { + clientConfig := &ssh.ClientConfig{ + User: "user", + Auth: []ssh.AuthMethod{ + ssh.Password("i-am-invalid"), + }, + } + + address := newMockLineServer(t) + conn := func() (net.Conn, error) { + conn, err := net.Dial("tcp", address) + if err != nil { + t.Errorf("Unable to accept incoming connection: %v", err) + } + return conn, err + } + + config := &Config{ + Connection: conn, + SSHConfig: clientConfig, + } + + _, err := New(address, config) + if err == nil { + t.Fatal("should have had an error connecting") + } +} + +func TestStart(t *testing.T) { + clientConfig := &ssh.ClientConfig{ + User: "user", + Auth: []ssh.AuthMethod{ + ssh.Password("pass"), + }, + } + + address := newMockLineServer(t) + conn := func() (net.Conn, error) { + conn, err := net.Dial("tcp", address) + if err != nil { + t.Fatalf("unable to dial to remote side: %s", err) + } + return conn, err + } + + config := &Config{ + Connection: conn, + SSHConfig: clientConfig, + } + + client, err := New(address, config) + if err != nil { + t.Fatalf("error connecting to SSH: %s", err) + } + + var cmd RemoteCmd + stdout := new(bytes.Buffer) + cmd.Command = "echo foo" + cmd.Stdout = stdout + + client.Start(&cmd) +} diff --git a/helper/ssh/password.go b/helper/ssh/password.go new file mode 100644 index 000000000..e5e2a3595 --- /dev/null +++ b/helper/ssh/password.go @@ -0,0 +1,27 @@ +package ssh + +import ( + "code.google.com/p/go.crypto/ssh" + "log" +) + +// An implementation of ssh.KeyboardInteractiveChallenge that simply sends +// back the password for all questions. The questions are logged. +func PasswordKeyboardInteractive(password string) ssh.KeyboardInteractiveChallenge { + return func(user, instruction string, questions []string, echos []bool) ([]string, error) { + log.Printf("Keyboard interactive challenge: ") + log.Printf("-- User: %s", user) + log.Printf("-- Instructions: %s", instruction) + for i, question := range questions { + log.Printf("-- Question %d: %s", i+1, question) + } + + // Just send the password back for all questions + answers := make([]string, len(questions)) + for i, _ := range answers { + answers[i] = string(password) + } + + return answers, nil + } +} diff --git a/helper/ssh/password_test.go b/helper/ssh/password_test.go new file mode 100644 index 000000000..e74b46e06 --- /dev/null +++ b/helper/ssh/password_test.go @@ -0,0 +1,27 @@ +package ssh + +import ( + "code.google.com/p/go.crypto/ssh" + "reflect" + "testing" +) + +func TestPasswordKeyboardInteractive_Impl(t *testing.T) { + var raw interface{} + raw = PasswordKeyboardInteractive("foo") + if _, ok := raw.(ssh.KeyboardInteractiveChallenge); !ok { + t.Fatal("PasswordKeyboardInteractive must implement KeyboardInteractiveChallenge") + } +} + +func TestPasswordKeybardInteractive_Challenge(t *testing.T) { + p := PasswordKeyboardInteractive("foo") + result, err := p("foo", "bar", []string{"one", "two"}, nil) + if err != nil { + t.Fatalf("err not nil: %s", err) + } + + if !reflect.DeepEqual(result, []string{"foo", "foo"}) { + t.Fatalf("invalid password: %#v", result) + } +} diff --git a/terraform/context.go b/terraform/context.go index 72f6a8813..fe2168f37 100644 --- a/terraform/context.go +++ b/terraform/context.go @@ -3,6 +3,7 @@ package terraform import ( "fmt" "log" + "strconv" "strings" "sync" "sync/atomic" @@ -467,6 +468,11 @@ func (c *Context) applyWalkFn() depgraph.WalkFunc { diff.init() } + // If we do not have any connection info, initialize + if r.State.ConnInfo == nil { + r.State.ConnInfo = make(map[string]string) + } + // Remove any output values from the diff for k, ad := range diff.Attributes { if ad.Type == DiffAttrOutput { @@ -555,6 +561,13 @@ func (c *Context) applyWalkFn() depgraph.WalkFunc { // defined after the resource creation has already completed. func (c *Context) applyProvisioners(r *Resource, rs *ResourceState) (*ResourceState, error) { var err error + + // Store the original connection info, restore later + origConnInfo := rs.ConnInfo + defer func() { + rs.ConnInfo = origConnInfo + }() + for _, prov := range r.Provisioners { // Interpolate since we may have variables that depend on the // local resource. @@ -562,6 +575,41 @@ func (c *Context) applyProvisioners(r *Resource, rs *ResourceState) (*ResourceSt return rs, err } + // Interpolate the conn info, since it may contain variables + connInfo := NewResourceConfig(prov.ConnInfo) + if err := connInfo.interpolate(c); err != nil { + return rs, err + } + + // Merge the connection information + overlay := make(map[string]string) + if origConnInfo != nil { + for k, v := range origConnInfo { + overlay[k] = v + } + } + for k, v := range connInfo.Config { + switch vt := v.(type) { + case string: + overlay[k] = vt + case int64: + overlay[k] = strconv.FormatInt(vt, 10) + case int32: + overlay[k] = strconv.FormatInt(int64(vt), 10) + case int: + overlay[k] = strconv.FormatInt(int64(vt), 10) + case float32: + overlay[k] = strconv.FormatFloat(float64(vt), 'f', 3, 32) + case float64: + overlay[k] = strconv.FormatFloat(vt, 'f', 3, 64) + case bool: + overlay[k] = strconv.FormatBool(vt) + default: + overlay[k] = fmt.Sprintf("%v", vt) + } + } + rs.ConnInfo = overlay + // Invoke the Provisioner rs, err = prov.Provisioner.Apply(rs, prov.Config) if err != nil { diff --git a/terraform/context_test.go b/terraform/context_test.go index 4b27e5017..cfe7727a9 100644 --- a/terraform/context_test.go +++ b/terraform/context_test.go @@ -507,6 +507,81 @@ func TestContextApply_outputDiffVars(t *testing.T) { } } +func TestContextApply_Provisioner_ConnInfo(t *testing.T) { + c := testConfig(t, "apply-provisioner-conninfo") + p := testProvider("aws") + pr := testProvisioner() + + p.ApplyFn = func(s *ResourceState, d *ResourceDiff) (*ResourceState, error) { + if s.ConnInfo == nil { + t.Fatalf("ConnInfo not initialized") + } + + result, _ := testApplyFn(s, d) + result.ConnInfo = map[string]string{ + "type": "ssh", + "host": "127.0.0.1", + "port": "22", + } + return result, nil + } + p.DiffFn = testDiffFn + + pr.ApplyFn = func(rs *ResourceState, c *ResourceConfig) (*ResourceState, error) { + conn := rs.ConnInfo + if conn["type"] != "telnet" { + t.Fatalf("Bad: %#v", conn) + } + if conn["host"] != "127.0.0.1" { + t.Fatalf("Bad: %#v", conn) + } + if conn["port"] != "2222" { + t.Fatalf("Bad: %#v", conn) + } + if conn["user"] != "superuser" { + t.Fatalf("Bad: %#v", conn) + } + if conn["pass"] != "test" { + t.Fatalf("Bad: %#v", conn) + } + return rs, nil + } + + ctx := testContext(t, &ContextOpts{ + Config: c, + Providers: map[string]ResourceProviderFactory{ + "aws": testProviderFuncFixed(p), + }, + Provisioners: map[string]ResourceProvisionerFactory{ + "shell": testProvisionerFuncFixed(pr), + }, + Variables: map[string]string{ + "value": "1", + "pass": "test", + }, + }) + + if _, err := ctx.Plan(nil); err != nil { + t.Fatalf("err: %s", err) + } + + state, err := ctx.Apply() + if err != nil { + t.Fatalf("err: %s", err) + } + + actual := strings.TrimSpace(state.String()) + expected := strings.TrimSpace(testTerraformApplyProvisionerStr) + if actual != expected { + t.Fatalf("bad: \n%s", actual) + } + + // Verify apply was invoked + if !pr.ApplyCalled { + t.Fatalf("provisioner not invoked") + } +} + func TestContextApply_destroy(t *testing.T) { c := testConfig(t, "apply-destroy") h := new(HookRecordApplyOrder) diff --git a/terraform/graph.go b/terraform/graph.go index 445d9e99d..e757c3072 100644 --- a/terraform/graph.go +++ b/terraform/graph.go @@ -546,6 +546,9 @@ func graphAddVariableDeps(g *depgraph.Graph) { for _, p := range m.Resource.Provisioners { vars = p.RawConfig.Variables nounAddVariableDeps(g, n, vars) + + vars = p.ConnInfo.Variables + nounAddVariableDeps(g, n, vars) } case *GraphNodeResourceProvider: @@ -774,6 +777,7 @@ func graphMapResourceProvisioners(g *depgraph.Graph, Provisioner: provisioner, Config: NewResourceConfig(p.RawConfig), RawConfig: p.RawConfig, + ConnInfo: p.ConnInfo, }) } } diff --git a/terraform/graph_test.go b/terraform/graph_test.go index d1ef90c3b..4d1795878 100644 --- a/terraform/graph_test.go +++ b/terraform/graph_test.go @@ -195,11 +195,19 @@ func TestGraphProvisioners(t *testing.T) { if prov.RawConfig.Config()["cmd"] != "add ${aws_instance.web.id}" { t.Fatalf("bad: %#v", prov) } + if prov.ConnInfo == nil || len(prov.ConnInfo.Raw) != 2 { + t.Fatalf("bad: %#v", prov) + } // Check that the variable dependency is handled if !depends("aws_load_balancer.weblb", "aws_instance.web") { t.Fatalf("missing dependency from provisioner variable") } + + // Check that the connection variable dependency is handled + if !depends("aws_load_balancer.weblb", "aws_security_group.firewall") { + t.Fatalf("missing dependency from provisioner connection") + } } func TestGraphAddDiff(t *testing.T) { diff --git a/terraform/resource.go b/terraform/resource.go index 5d7aaee4d..d19949840 100644 --- a/terraform/resource.go +++ b/terraform/resource.go @@ -18,6 +18,7 @@ type ResourceProvisionerConfig struct { Provisioner ResourceProvisioner Config *ResourceConfig RawConfig *config.RawConfig + ConnInfo *config.RawConfig } // Resource encapsulates a resource, its configuration, its provider, diff --git a/terraform/state.go b/terraform/state.go index f8972364a..5683f95f6 100644 --- a/terraform/state.go +++ b/terraform/state.go @@ -145,14 +145,14 @@ func (s *State) String() string { // that should not be serialized. This is only used temporarily // and is restored into the state. type sensitiveState struct { - ConnInfo map[string]*ResourceConnectionInfo + ConnInfo map[string]map[string]string once sync.Once } func (s *sensitiveState) init() { s.once.Do(func() { - s.ConnInfo = make(map[string]*ResourceConnectionInfo) + s.ConnInfo = make(map[string]map[string]string) }) } @@ -245,21 +245,6 @@ func WriteState(d *State, dst io.Writer) error { return err } -// ResourceConnectionInfo holds addresses, credentials and configuration -// information require to connect to a resource. This is populated -// by a provider so that provisioners can connect and run on the -// resource. -type ResourceConnectionInfo struct { - // Type is set so that an appropriate connection can be formed. - // As an example, for a Linux machine, the Type may be "ssh" - Type string - - // Raw is used to store any relevant keys for the given Type - // so that a provisioner can connect to the resource. This could - // contain credentials or address information. - Raw map[string]string -} - // ResourceState holds the state of a resource that is used so that // a provider can find and manage an existing resource as well as for // storing attributes that are uesd to populate variables of child @@ -292,7 +277,7 @@ type ResourceState struct { // ConnInfo is used for the providers to export information which is // used to connect to the resource for provisioning. For example, // this could contain SSH or WinRM credentials. - ConnInfo *ResourceConnectionInfo + ConnInfo map[string]string // Extra information that the provider can store about a resource. // This data is opaque, never shown to the user, and is sent back to diff --git a/terraform/state_test.go b/terraform/state_test.go index 53af97908..26ae4381c 100644 --- a/terraform/state_test.go +++ b/terraform/state_test.go @@ -98,12 +98,10 @@ func TestReadWriteState(t *testing.T) { Resources: map[string]*ResourceState{ "foo": &ResourceState{ ID: "bar", - ConnInfo: &ResourceConnectionInfo{ - Type: "ssh", - Raw: map[string]string{ - "user": "root", - "password": "supersecret", - }, + ConnInfo: map[string]string{ + "type": "ssh", + "user": "root", + "password": "supersecret", }, }, }, diff --git a/terraform/test-fixtures/apply-provisioner-conninfo/main.tf b/terraform/test-fixtures/apply-provisioner-conninfo/main.tf new file mode 100644 index 000000000..f0bfe43ab --- /dev/null +++ b/terraform/test-fixtures/apply-provisioner-conninfo/main.tf @@ -0,0 +1,20 @@ +resource "aws_instance" "foo" { + num = "2" + compute = "dynamical" + compute_value = "${var.value}" +} + +resource "aws_instance" "bar" { + connection { + type = "telnet" + } + + provisioner "shell" { + foo = "${aws_instance.foo.dynamical}" + connection { + user = "superuser" + port = 2222 + pass = "${var.pass}" + } + } +} diff --git a/terraform/test-fixtures/graph-provisioners/main.tf b/terraform/test-fixtures/graph-provisioners/main.tf index 96035ecc4..7222c3f98 100644 --- a/terraform/test-fixtures/graph-provisioners/main.tf +++ b/terraform/test-fixtures/graph-provisioners/main.tf @@ -24,5 +24,9 @@ resource "aws_instance" "web" { resource "aws_load_balancer" "weblb" { provisioner "shell" { cmd = "add ${aws_instance.web.id}" + connection { + type = "magic" + user = "${aws_security_group.firewall.id}" + } } }