From 48a92cfb1f6f0aea100c7f55608653cb25093bbb Mon Sep 17 00:00:00 2001 From: Jakub Janczak Date: Thu, 5 Feb 2015 09:33:56 +0100 Subject: [PATCH] remote/s3: s3 remote state storage support --- command/remote.go | 27 +++++-- remote/README.md | 11 +++ remote/client.go | 2 + remote/s3.go | 192 ++++++++++++++++++++++++++++++++++++++++++++++ remote/s3_test.go | 134 ++++++++++++++++++++++++++++++++ 5 files changed, 359 insertions(+), 7 deletions(-) create mode 100644 remote/README.md create mode 100644 remote/s3.go create mode 100644 remote/s3_test.go diff --git a/command/remote.go b/command/remote.go index d9a773704..3b9e8c57b 100644 --- a/command/remote.go +++ b/command/remote.go @@ -32,7 +32,7 @@ type RemoteCommand struct { func (c *RemoteCommand) Run(args []string) int { args = c.Meta.process(args, false) - var address, accessToken, name, path string + var address, accessToken, name, path, region, securityToken, bucket string cmdFlags := flag.NewFlagSet("remote", flag.ContinueOnError) cmdFlags.BoolVar(&c.conf.disableRemote, "disable", false, "") cmdFlags.BoolVar(&c.conf.pullOnDisable, "pull", true, "") @@ -41,6 +41,9 @@ func (c *RemoteCommand) Run(args []string) int { cmdFlags.StringVar(&c.remoteConf.Type, "backend", "atlas", "") cmdFlags.StringVar(&address, "address", "", "") cmdFlags.StringVar(&accessToken, "access-token", "", "") + cmdFlags.StringVar(&securityToken, "security-token", "", "") + cmdFlags.StringVar(&bucket, "bucket", "", "") + cmdFlags.StringVar(®ion, "region", "", "") cmdFlags.StringVar(&name, "name", "", "") cmdFlags.StringVar(&path, "path", "", "") cmdFlags.Usage = func() { c.Ui.Error(c.Help()) } @@ -57,10 +60,13 @@ func (c *RemoteCommand) Run(args []string) int { // Populate the various configurations c.remoteConf.Config = map[string]string{ - "address": address, - "access_token": accessToken, - "name": name, - "path": path, + "address": address, + "access_token": accessToken, + "security_token": securityToken, + "name": name, + "path": path, + "bucket": bucket, + "region": region, } // Check if have an existing local state file @@ -329,13 +335,17 @@ Options: -access-token=token Authentication token for state storage server. Required for Atlas backend, optional for Consul. + -security-token=token Security token. Specific to S3 (required). + -backend=Atlas Specifies the type of remote backend. Must be one - of Atlas, Consul, or HTTP. Defaults to Atlas. + of Atlas, Consul,HTTP or S3. Defaults to Atlas. -backup=path Path to backup the existing state file before modifying. Defaults to the "-state" path with ".backup" extension. Set to "-" to disable backup. + -bucket=bucket S3 bucket name. Specific to S3 (required). + -disable Disables remote state management and migrates the state to the -state path. @@ -343,12 +353,15 @@ Options: Required for Atlas backend. -path=path Path of the remote state in Consul. Required for the - Consul backend. + Consul and S3 backend. -pull=true Controls if the remote state is pulled before disabling. This defaults to true to ensure the latest state is cached before disabling. + -region=region AWS region to use. Specific for S3 (not required if AWS_DEFAULT_REGION + env variable is set). + -state=path Path to read state. Defaults to "terraform.tfstate" unless remote state is enabled. diff --git a/remote/README.md b/remote/README.md new file mode 100644 index 000000000..60d53a1e2 --- /dev/null +++ b/remote/README.md @@ -0,0 +1,11 @@ +## How to test + +### S3 remote state storage +To run S3 integration tests you need following env variables to be set: + * AWS_ACCESS_KEY + * AWS_SECRET_KEY + * AWS_DEFAULT_REGION + * TERRAFORM_STATE_BUCKET + +Additionally specified bucket should exist in the defined region and should be accessible +using specified credentials. diff --git a/remote/client.go b/remote/client.go index c57e40c78..56849553e 100644 --- a/remote/client.go +++ b/remote/client.go @@ -59,6 +59,8 @@ func NewClientByType(ctype string, conf map[string]string) (RemoteClient, error) return NewConsulRemoteClient(conf) case "http": return NewHTTPRemoteClient(conf) + case "s3": + return NewS3RemoteClient(conf) default: return nil, fmt.Errorf("Unknown remote client type '%s'", ctype) } diff --git a/remote/s3.go b/remote/s3.go new file mode 100644 index 000000000..04b202094 --- /dev/null +++ b/remote/s3.go @@ -0,0 +1,192 @@ +package remote + +import ( + "bytes" + "crypto/md5" + "encoding/base64" + "fmt" + "io" + "net/http" + "os" + "time" + + "github.com/goamz/goamz/aws" + "github.com/goamz/goamz/s3" +) + +type S3RemoteClient struct { + Bucket *s3.Bucket + Path string +} + +func GetRegion(conf map[string]string) (aws.Region, error) { + regionName, ok := conf["region"] + if !ok || regionName == "" { + regionName = os.Getenv("AWS_DEFAULT_REGION") + if regionName == "" { + return aws.Region{}, fmt.Errorf("AWS region not set") + } + } + + region, ok := aws.Regions[regionName] + if !ok { + return aws.Region{}, fmt.Errorf("AWS region set in configuration '%v' doesn't exist", regionName) + } + return region, nil +} + +func NewS3RemoteClient(conf map[string]string) (*S3RemoteClient, error) { + client := &S3RemoteClient{} + + auth, err := aws.GetAuth(conf["access_token"], conf["secret_token"], "", time.Now()) + if err != nil { + return nil, err + } + + region, err := GetRegion(conf) + if err != nil { + return nil, err + } + + bucketName, ok := conf["bucket"] + if !ok { + return nil, fmt.Errorf("Missing 'bucket_name' configuration") + } + + client.Bucket = s3.New(auth, region).Bucket(bucketName) + + path, ok := conf["path"] + if !ok { + return nil, fmt.Errorf("Missing 'path' configuration") + } + client.Path = path + + return client, nil +} + +func (c *S3RemoteClient) GetState() (*RemoteStatePayload, error) { + resp, err := c.Bucket.GetResponse(c.Path) + defer func() { + if resp != nil && resp.Body != nil { + resp.Body.Close() + } + }() + + if err != nil { + switch err.(type) { + case *s3.Error: + s3Err := err.(*s3.Error) + + // FIXME copied from Atlas + // Handle the common status codes + switch s3Err.StatusCode { + case http.StatusOK: + // Handled after + case http.StatusNoContent: + return nil, nil + case http.StatusNotFound: + return nil, nil + case http.StatusUnauthorized: + return nil, ErrRequireAuth + case http.StatusForbidden: + return nil, ErrInvalidAuth + case http.StatusInternalServerError: + return nil, ErrRemoteInternal + default: + return nil, fmt.Errorf("Unexpected HTTP response code %d", s3Err.StatusCode) + } + default: + return nil, err + } + } + + // Read in the body + buf := bytes.NewBuffer(nil) + if _, err := io.Copy(buf, resp.Body); err != nil { + return nil, fmt.Errorf("Failed to read remote state: %v", err) + } + + // Create the payload + payload := &RemoteStatePayload{ + State: buf.Bytes(), + } + + // Check for the MD5 + if raw := resp.Header.Get("Content-MD5"); raw != "" { + md5, err := base64.StdEncoding.DecodeString(raw) + if err != nil { + return nil, fmt.Errorf("Failed to decode Content-MD5 '%s': %v", raw, err) + } + payload.MD5 = md5 + + } else { + // Generate the MD5 + hash := md5.Sum(payload.State) + payload.MD5 = hash[:md5.Size] + } + + return payload, nil +} + +func (c *S3RemoteClient) PutState(state []byte, force bool) error { + // Generate the MD5 + hash := md5.Sum(state) + b64 := base64.StdEncoding.EncodeToString(hash[:md5.Size]) + + options := s3.Options{ + ContentMD5: b64, + } + + err := c.Bucket.Put(c.Path, state, "application/json", s3.Private, options) + switch err.(type) { + case *s3.Error: + s3Err := err.(*s3.Error) + + // Handle the error codes + switch s3Err.StatusCode { + case http.StatusOK: + return nil + case http.StatusConflict: + return ErrConflict + case http.StatusPreconditionFailed: + return ErrServerNewer + case http.StatusUnauthorized: + return ErrRequireAuth + case http.StatusForbidden: + return ErrInvalidAuth + case http.StatusInternalServerError: + return ErrRemoteInternal + default: + return fmt.Errorf("Unexpected HTTP response code %d", s3Err.StatusCode) + } + default: + return err + } +} + +func (c *S3RemoteClient) DeleteState() error { + err := c.Bucket.Del(c.Path) + switch err.(type) { + case *s3.Error: + s3Err := err.(*s3.Error) + // Handle the error codes + switch s3Err.StatusCode { + case http.StatusOK: + return nil + case http.StatusNoContent: + return nil + case http.StatusNotFound: + return nil + case http.StatusUnauthorized: + return ErrRequireAuth + case http.StatusForbidden: + return ErrInvalidAuth + case http.StatusInternalServerError: + return ErrRemoteInternal + default: + return fmt.Errorf("Unexpected HTTP response code %d", s3Err.StatusCode) + } + default: + return err + } +} diff --git a/remote/s3_test.go b/remote/s3_test.go new file mode 100644 index 000000000..41309ae3f --- /dev/null +++ b/remote/s3_test.go @@ -0,0 +1,134 @@ +package remote + +import ( + "bytes" + "crypto/md5" + "os" + "testing" + + "github.com/hashicorp/terraform/terraform" +) + +func TestS3Remote_NewClient(t *testing.T) { + conf := map[string]string{} + if _, err := NewS3RemoteClient(conf); err == nil { + t.Fatalf("expect error") + } + + conf["access_token"] = "test" + conf["secret_token"] = "test" + conf["path"] = "hashicorp/test-state" + conf["bucket"] = "plan3-test" + conf["region"] = "eu-west-1" + if _, err := NewS3RemoteClient(conf); err != nil { + t.Fatalf("err: %v", err) + } +} + +func TestS3Remote_Validate_envVar(t *testing.T) { + conf := map[string]string{} + if _, err := NewS3RemoteClient(conf); err == nil { + t.Fatalf("expect error") + } + + defer os.Setenv("AWS_ACCESS_KEY", os.Getenv("AWS_ACCESS_KEY")) + os.Setenv("AWS_ACCESS_KEY", "foo") + + defer os.Setenv("AWS_SECRET_KEY", os.Getenv("AWS_SECRET_KEY")) + os.Setenv("AWS_SECRET_KEY", "foo") + + defer os.Setenv("AWS_DEFAULT_REGION", os.Getenv("AWS_DEFAULT_REGION")) + os.Setenv("AWS_DEFAULT_REGION", "eu-west-1") + + conf["path"] = "hashicorp/test-state" + conf["bucket"] = "plan3-test" + if _, err := NewS3RemoteClient(conf); err != nil { + t.Fatalf("err: %v", err) + } +} + +func checkS3(t *testing.T) { + if os.Getenv("AWS_ACCESS_KEY") == "" || os.Getenv("AWS_SECRET_KEY") == "" || os.Getenv("AWS_DEFAULT_REGION") == "" || os.Getenv("TERRAFORM_STATE_BUCKET") == "" { + t.SkipNow() + } +} + +func TestS3Remote(t *testing.T) { + checkS3(t) + remote := &terraform.RemoteState{ + Type: "atlas", + Config: map[string]string{ + "access_token": "some-access-token", + "name": "hashicorp/test-remote-state", + }, + } + r, err := NewClientByType("s3", map[string]string{ + "bucket": os.Getenv("TERRAFORM_STATE_BUCKET"), + "path": "test-remote-state", + }) + if err != nil { + t.Fatalf("Err: %v", err) + } + + // Get a valid input + inp, err := blankState(remote) + if err != nil { + t.Fatalf("Err: %v", err) + } + inpMD5 := md5.Sum(inp) + hash := inpMD5[:16] + + // Delete the state, should be none + err = r.DeleteState() + if err != nil { + t.Fatalf("err: %v", err) + } + + // Ensure no state + payload, err := r.GetState() + if err != nil { + t.Fatalf("Err: %v", err) + } + if payload != nil { + t.Fatalf("unexpected payload") + } + + // Put the state + err = r.PutState(inp, false) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Get it back + payload, err = r.GetState() + if err != nil { + t.Fatalf("Err: %v", err) + } + if payload == nil { + t.Fatalf("unexpected payload") + } + + // Check the payload + if !bytes.Equal(payload.MD5, hash) { + t.Fatalf("bad hash: %x %x", payload.MD5, hash) + } + if !bytes.Equal(payload.State, inp) { + t.Errorf("inp: %s", inp) + t.Fatalf("bad response: %s", payload.State) + } + + // Delete the state + err = r.DeleteState() + if err != nil { + t.Fatalf("err: %v", err) + } + + // Should be gone + payload, err = r.GetState() + if err != nil { + t.Fatalf("Err: %v", err) + } + if payload != nil { + t.Fatalf("unexpected payload") + } +}