terraform/builtin/provisioners/remote-exec/resource_provisioner_test.go

257 lines
5.2 KiB
Go

package remoteexec
import (
"bytes"
"context"
"io"
"log"
"testing"
"time"
"strings"
"github.com/hashicorp/terraform/communicator"
"github.com/hashicorp/terraform/communicator/remote"
"github.com/hashicorp/terraform/provisioners"
"github.com/mitchellh/cli"
"github.com/zclconf/go-cty/cty"
)
func TestResourceProvider_Validate_good(t *testing.T) {
c := cty.ObjectVal(map[string]cty.Value{
"inline": cty.ListVal([]cty.Value{cty.StringVal("echo foo")}),
})
resp := New().ValidateProvisionerConfig(provisioners.ValidateProvisionerConfigRequest{
Config: c,
})
if len(resp.Diagnostics) > 0 {
t.Fatal(resp.Diagnostics.ErrWithWarnings())
}
}
func TestResourceProvider_Validate_bad(t *testing.T) {
c := cty.ObjectVal(map[string]cty.Value{
"invalid": cty.StringVal("nope"),
})
resp := New().ValidateProvisionerConfig(provisioners.ValidateProvisionerConfigRequest{
Config: c,
})
if !resp.Diagnostics.HasErrors() {
t.Fatalf("Should have errors")
}
}
var expectedScriptOut = `cd /tmp
wget http://foobar
exit 0
`
func TestResourceProvider_generateScript(t *testing.T) {
inline := cty.ListVal([]cty.Value{
cty.StringVal("cd /tmp"),
cty.StringVal("wget http://foobar"),
cty.StringVal("exit 0"),
})
out, err := generateScripts(inline)
if err != nil {
t.Fatalf("err: %v", err)
}
if len(out) != 1 {
t.Fatal("expected 1 out")
}
if out[0] != expectedScriptOut {
t.Fatalf("bad: %v", out)
}
}
func TestResourceProvider_generateScriptEmptyInline(t *testing.T) {
inline := cty.ListVal([]cty.Value{cty.StringVal("")})
_, err := generateScripts(inline)
if err == nil {
t.Fatal("expected error, got none")
}
if !strings.Contains(err.Error(), "empty string") {
t.Fatalf("expected empty string error, got: %s", err)
}
}
func TestResourceProvider_CollectScripts_inline(t *testing.T) {
conf := map[string]cty.Value{
"inline": cty.ListVal([]cty.Value{
cty.StringVal("cd /tmp"),
cty.StringVal("wget http://foobar"),
cty.StringVal("exit 0"),
}),
}
scripts, err := collectScripts(cty.ObjectVal(conf))
if err != nil {
t.Fatalf("err: %v", err)
}
if len(scripts) != 1 {
t.Fatalf("bad: %v", scripts)
}
var out bytes.Buffer
_, err = io.Copy(&out, scripts[0])
if err != nil {
t.Fatalf("err: %v", err)
}
if out.String() != expectedScriptOut {
t.Fatalf("bad: %v", out.String())
}
}
func TestResourceProvider_CollectScripts_script(t *testing.T) {
p := New()
schema := p.GetSchema().Provisioner
conf, err := schema.CoerceValue(cty.ObjectVal(map[string]cty.Value{
"scripts": cty.ListVal([]cty.Value{
cty.StringVal("testdata/script1.sh"),
}),
}))
if err != nil {
t.Fatal(err)
}
scripts, err := collectScripts(conf)
if err != nil {
t.Fatalf("err: %v", err)
}
if len(scripts) != 1 {
t.Fatalf("bad: %v", scripts)
}
var out bytes.Buffer
_, err = io.Copy(&out, scripts[0])
if err != nil {
t.Fatalf("err: %v", err)
}
if out.String() != expectedScriptOut {
t.Fatalf("bad: %v", out.String())
}
}
func TestResourceProvider_CollectScripts_scripts(t *testing.T) {
p := New()
schema := p.GetSchema().Provisioner
conf, err := schema.CoerceValue(cty.ObjectVal(map[string]cty.Value{
"scripts": cty.ListVal([]cty.Value{
cty.StringVal("testdata/script1.sh"),
cty.StringVal("testdata/script1.sh"),
cty.StringVal("testdata/script1.sh"),
}),
}))
if err != nil {
log.Fatal(err)
}
scripts, err := collectScripts(conf)
if err != nil {
t.Fatalf("err: %v", err)
}
if len(scripts) != 3 {
t.Fatalf("bad: %v", scripts)
}
for idx := range scripts {
var out bytes.Buffer
_, err = io.Copy(&out, scripts[idx])
if err != nil {
t.Fatalf("err: %v", err)
}
if out.String() != expectedScriptOut {
t.Fatalf("bad: %v", out.String())
}
}
}
func TestResourceProvider_CollectScripts_scriptsEmpty(t *testing.T) {
p := New()
schema := p.GetSchema().Provisioner
conf, err := schema.CoerceValue(cty.ObjectVal(map[string]cty.Value{
"scripts": cty.ListVal([]cty.Value{cty.StringVal("")}),
}))
if err != nil {
t.Fatal(err)
}
_, err = collectScripts(conf)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "empty string") {
t.Fatalf("Expected empty string error, got: %s", err)
}
}
func TestProvisionerTimeout(t *testing.T) {
o := cli.NewMockUi()
c := new(communicator.MockCommunicator)
disconnected := make(chan struct{})
c.DisconnectFunc = func() error {
close(disconnected)
return nil
}
completed := make(chan struct{})
c.CommandFunc = func(cmd *remote.Cmd) error {
defer close(completed)
cmd.Init()
time.Sleep(2 * time.Second)
cmd.SetExitStatus(0, nil)
return nil
}
c.ConnTimeout = time.Second
c.UploadScripts = map[string]string{"hello": "echo hello"}
c.RemoteScriptPath = "hello"
conf := map[string]cty.Value{
"inline": cty.ListVal([]cty.Value{cty.StringVal("echo hello")}),
}
scripts, err := collectScripts(cty.ObjectVal(conf))
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
done := make(chan struct{})
var runErr error
go func() {
defer close(done)
runErr = runScripts(ctx, o, c, scripts)
}()
select {
case <-disconnected:
t.Fatal("communicator disconnected before command completed")
case <-completed:
}
<-done
if runErr != nil {
t.Fatal(err)
}
}