diff --git a/builtin/providers/aws/resource_aws_s3_bucket.go b/builtin/providers/aws/resource_aws_s3_bucket.go index d832190b0..cb32d5fa3 100644 --- a/builtin/providers/aws/resource_aws_s3_bucket.go +++ b/builtin/providers/aws/resource_aws_s3_bucket.go @@ -14,6 +14,7 @@ func resourceAwsS3Bucket() *schema.Resource { return &schema.Resource{ Create: resourceAwsS3BucketCreate, Read: resourceAwsS3BucketRead, + Update: resourceAwsS3BucketUpdate, Delete: resourceAwsS3BucketDelete, Schema: map[string]*schema.Schema{ @@ -29,6 +30,8 @@ func resourceAwsS3Bucket() *schema.Resource { Optional: true, ForceNew: true, }, + + "tags": tagsSchema(), }, } } @@ -64,7 +67,15 @@ func resourceAwsS3BucketCreate(d *schema.ResourceData, meta interface{}) error { // Assign the bucket name as the resource ID d.SetId(bucket) - return nil + return resourceAwsS3BucketUpdate(d, meta) +} + +func resourceAwsS3BucketUpdate(d *schema.ResourceData, meta interface{}) error { + s3conn := meta.(*AWSClient).s3conn + if err := setTagsS3(s3conn, d); err != nil { + return err + } + return resourceAwsS3BucketRead(d, meta) } func resourceAwsS3BucketRead(d *schema.ResourceData, meta interface{}) error { @@ -76,6 +87,18 @@ func resourceAwsS3BucketRead(d *schema.ResourceData, meta interface{}) error { if err != nil { return err } + + resp, err := s3conn.GetBucketTagging(&s3.GetBucketTaggingRequest{ + Bucket: aws.String(d.Id()), + }) + if err != nil { + return err + } + + if err := d.Set("tags", tagsToMapS3(resp.TagSet)); err != nil { + return err + } + return nil } diff --git a/builtin/providers/aws/s3_tags.go b/builtin/providers/aws/s3_tags.go new file mode 100644 index 000000000..43678952b --- /dev/null +++ b/builtin/providers/aws/s3_tags.go @@ -0,0 +1,112 @@ +package aws + +import ( + "crypto/md5" + "encoding/base64" + "encoding/xml" + "log" + + "github.com/hashicorp/aws-sdk-go/aws" + "github.com/hashicorp/aws-sdk-go/gen/s3" + "github.com/hashicorp/terraform/helper/schema" +) + +// setTags is a helper to set the tags for a resource. It expects the +// tags field to be named "tags" +func setTagsS3(conn *s3.S3, d *schema.ResourceData) error { + if d.HasChange("tags") { + oraw, nraw := d.GetChange("tags") + o := oraw.(map[string]interface{}) + n := nraw.(map[string]interface{}) + create, remove := diffTagsS3(tagsFromMapS3(o), tagsFromMapS3(n)) + + // Set tags + if len(remove) > 0 { + log.Printf("[DEBUG] Removing tags: %#v", remove) + err := conn.DeleteBucketTagging(&s3.DeleteBucketTaggingRequest{ + Bucket: aws.String(d.Get("bucket").(string)), + }) + if err != nil { + return err + } + } + if len(create) > 0 { + log.Printf("[DEBUG] Creating tags: %#v", create) + tagging := s3.Tagging{ + TagSet: create, + XMLName: xml.Name{ + Space: "http://s3.amazonaws.com/doc/2006-03-01/", + Local: "Tagging", + }, + } + // AWS S3 API requires us to send a base64 encoded md5 hash of the + // content, which we need to build ourselves since aws-sdk-go does not. + b, err := xml.Marshal(tagging) + if err != nil { + return err + } + h := md5.New() + h.Write(b) + base := base64.StdEncoding.EncodeToString(h.Sum(nil)) + + req := &s3.PutBucketTaggingRequest{ + Bucket: aws.String(d.Get("bucket").(string)), + ContentMD5: aws.String(base), + Tagging: &tagging, + } + + err = conn.PutBucketTagging(req) + if err != nil { + return err + } + } + } + + return nil +} + +// diffTags takes our tags locally and the ones remotely and returns +// the set of tags that must be created, and the set of tags that must +// be destroyed. +func diffTagsS3(oldTags, newTags []s3.Tag) ([]s3.Tag, []s3.Tag) { + // First, we're creating everything we have + create := make(map[string]interface{}) + for _, t := range newTags { + create[*t.Key] = *t.Value + } + + // Build the list of what to remove + var remove []s3.Tag + for _, t := range oldTags { + old, ok := create[*t.Key] + if !ok || old != *t.Value { + // Delete it! + remove = append(remove, t) + } + } + + return tagsFromMapS3(create), remove +} + +// tagsFromMap returns the tags for the given map of data. +func tagsFromMapS3(m map[string]interface{}) []s3.Tag { + result := make([]s3.Tag, 0, len(m)) + for k, v := range m { + result = append(result, s3.Tag{ + Key: aws.String(k), + Value: aws.String(v.(string)), + }) + } + + return result +} + +// tagsToMap turns the list of tags into a map. +func tagsToMapS3(ts []s3.Tag) map[string]string { + result := make(map[string]string) + for _, t := range ts { + result[*t.Key] = *t.Value + } + + return result +} diff --git a/builtin/providers/aws/s3_tags_test.go b/builtin/providers/aws/s3_tags_test.go new file mode 100644 index 000000000..9b082c6e4 --- /dev/null +++ b/builtin/providers/aws/s3_tags_test.go @@ -0,0 +1,85 @@ +package aws + +import ( + "fmt" + "reflect" + "testing" + + "github.com/hashicorp/aws-sdk-go/gen/s3" + "github.com/hashicorp/terraform/helper/resource" + "github.com/hashicorp/terraform/terraform" +) + +func TestDiffTagsS3(t *testing.T) { + cases := []struct { + Old, New map[string]interface{} + Create, Remove map[string]string + }{ + // Basic add/remove + { + Old: map[string]interface{}{ + "foo": "bar", + }, + New: map[string]interface{}{ + "bar": "baz", + }, + Create: map[string]string{ + "bar": "baz", + }, + Remove: map[string]string{ + "foo": "bar", + }, + }, + + // Modify + { + Old: map[string]interface{}{ + "foo": "bar", + }, + New: map[string]interface{}{ + "foo": "baz", + }, + Create: map[string]string{ + "foo": "baz", + }, + Remove: map[string]string{ + "foo": "bar", + }, + }, + } + + for i, tc := range cases { + c, r := diffTagsS3(tagsFromMapS3(tc.Old), tagsFromMapS3(tc.New)) + cm := tagsToMapS3(c) + rm := tagsToMapS3(r) + if !reflect.DeepEqual(cm, tc.Create) { + t.Fatalf("%d: bad create: %#v", i, cm) + } + if !reflect.DeepEqual(rm, tc.Remove) { + t.Fatalf("%d: bad remove: %#v", i, rm) + } + } +} + +// testAccCheckTags can be used to check the tags on a resource. +func testAccCheckTagsS3( + ts *[]s3.Tag, key string, value string) resource.TestCheckFunc { + return func(s *terraform.State) error { + m := tagsToMapS3(*ts) + v, ok := m[key] + if value != "" && !ok { + return fmt.Errorf("Missing tag: %s", key) + } else if value == "" && ok { + return fmt.Errorf("Extra tag: %s", key) + } + if value == "" { + return nil + } + + if v != value { + return fmt.Errorf("%s: bad value: %s", key, v) + } + + return nil + } +}