diff --git a/communicator/ssh/communicator_test.go b/communicator/ssh/communicator_test.go index c1ed00e2e..3f9d4f883 100644 --- a/communicator/ssh/communicator_test.go +++ b/communicator/ssh/communicator_test.go @@ -3,9 +3,15 @@ package ssh import ( + "bufio" "bytes" "fmt" + "io" + "io/ioutil" + "math/rand" "net" + "os" + "path/filepath" "regexp" "strings" "testing" @@ -165,6 +171,117 @@ func TestStart(t *testing.T) { } } +func TestAccUploadFile(t *testing.T) { + // use the local ssh server and scp binary to check uploads + if ok := os.Getenv("SSH_UPLOAD_TEST"); ok == "" { + t.Log("Skipping Upload Acceptance without SSH_UPLOAD_TEST set") + t.Skip() + } + + r := &terraform.InstanceState{ + Ephemeral: terraform.EphemeralState{ + ConnInfo: map[string]string{ + "type": "ssh", + "user": os.Getenv("USER"), + "host": "127.0.0.1", + "port": "22", + "timeout": "30s", + }, + }, + } + + c, err := New(r) + if err != nil { + t.Fatalf("error creating communicator: %s", err) + } + + tmpDir, err := ioutil.TempDir("", "communicator") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + content := []byte("this is the file content") + source := bytes.NewReader(content) + tmpFile := filepath.Join(tmpDir, "tempFile.out") + err = c.Upload(tmpFile, source) + if err != nil { + t.Fatalf("error uploading file: %s", err) + } + + data, err := ioutil.ReadFile(tmpFile) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(data, content) { + t.Fatalf("bad: %s", data) + } +} + +func TestAccHugeUploadFile(t *testing.T) { + // use the local ssh server and scp binary to check uploads + if ok := os.Getenv("SSH_UPLOAD_TEST"); ok == "" { + t.Log("Skipping Upload Acceptance without SSH_UPLOAD_TEST set") + t.Skip() + } + + r := &terraform.InstanceState{ + Ephemeral: terraform.EphemeralState{ + ConnInfo: map[string]string{ + "type": "ssh", + "user": os.Getenv("USER"), + "host": "127.0.0.1", + "port": "22", + "timeout": "30s", + }, + }, + } + + c, err := New(r) + if err != nil { + t.Fatalf("error creating communicator: %s", err) + } + + // copy 4GB of data, random to prevent compression. + size := int64(1 << 32) + source := io.LimitReader(rand.New(rand.NewSource(0)), size) + + dest, err := ioutil.TempFile("", "communicator") + if err != nil { + t.Fatal(err) + } + destName := dest.Name() + dest.Close() + defer os.Remove(destName) + + t.Log("Uploading to", destName) + + // bypass the Upload method so we can directly supply the file size + // preventing the extra copy of the huge file. + targetDir := filepath.Dir(destName) + targetFile := filepath.Base(destName) + + scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { + return scpUploadFile(targetFile, source, w, stdoutR, size) + } + + err = c.scpSession("scp -vt "+targetDir, scpFunc) + if err != nil { + t.Fatal(err) + } + + // check the final file size + fs, err := os.Stat(destName) + if err != nil { + t.Fatal(err) + } + + if fs.Size() != size { + t.Fatalf("expected file size of %d, got %d", size, fs.Size()) + } +} + func TestScriptPath(t *testing.T) { cases := []struct { Input string