diff --git a/builtin/providers/aws/auth_helpers.go b/builtin/providers/aws/auth_helpers.go index f7521c230..fab4928b6 100644 --- a/builtin/providers/aws/auth_helpers.go +++ b/builtin/providers/aws/auth_helpers.go @@ -21,7 +21,7 @@ import ( "github.com/hashicorp/go-cleanhttp" ) -func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) { +func GetAccountInfo(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, string, error) { // If we have creds from instance profile, we can use metadata API if authProviderName == ec2rolecreds.ProviderName { log.Println("[DEBUG] Trying to get account ID via AWS Metadata API") @@ -30,7 +30,7 @@ func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) ( setOptionalEndpoint(cfg) sess, err := session.NewSession(cfg) if err != nil { - return "", errwrap.Wrapf("Error creating AWS session: %s", err) + return "", "", errwrap.Wrapf("Error creating AWS session: %s", err) } metadataClient := ec2metadata.New(sess) @@ -38,24 +38,24 @@ func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) ( if err != nil { // This can be triggered when no IAM Role is assigned // or AWS just happens to return invalid response - return "", fmt.Errorf("Failed getting EC2 IAM info: %s", err) + return "", "", fmt.Errorf("Failed getting EC2 IAM info: %s", err) } - return parseAccountIdFromArn(info.InstanceProfileArn) + return parseAccountInfoFromArn(info.InstanceProfileArn) } // Then try IAM GetUser log.Println("[DEBUG] Trying to get account ID via iam:GetUser") outUser, err := iamconn.GetUser(nil) if err == nil { - return parseAccountIdFromArn(*outUser.User.Arn) + return parseAccountInfoFromArn(*outUser.User.Arn) } awsErr, ok := err.(awserr.Error) // AccessDenied and ValidationError can be raised // if credentials belong to federated profile, so we ignore these if !ok || (awsErr.Code() != "AccessDenied" && awsErr.Code() != "ValidationError") { - return "", fmt.Errorf("Failed getting account ID via 'iam:GetUser': %s", err) + return "", "", fmt.Errorf("Failed getting account ID via 'iam:GetUser': %s", err) } log.Printf("[DEBUG] Getting account ID via iam:GetUser failed: %s", err) @@ -63,7 +63,7 @@ func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) ( log.Println("[DEBUG] Trying to get account ID via sts:GetCallerIdentity") outCallerIdentity, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{}) if err == nil { - return *outCallerIdentity.Account, nil + return parseAccountInfoFromArn(*outCallerIdentity.Arn) } log.Printf("[DEBUG] Getting account ID via sts:GetCallerIdentity failed: %s", err) @@ -73,22 +73,22 @@ func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) ( MaxItems: aws.Int64(int64(1)), }) if err != nil { - return "", fmt.Errorf("Failed getting account ID via 'iam:ListRoles': %s", err) + return "", "", fmt.Errorf("Failed getting account ID via 'iam:ListRoles': %s", err) } if len(outRoles.Roles) < 1 { - return "", errors.New("Failed getting account ID via 'iam:ListRoles': No roles available") + return "", "", errors.New("Failed getting account ID via 'iam:ListRoles': No roles available") } - return parseAccountIdFromArn(*outRoles.Roles[0].Arn) + return parseAccountInfoFromArn(*outRoles.Roles[0].Arn) } -func parseAccountIdFromArn(arn string) (string, error) { +func parseAccountInfoFromArn(arn string) (string, string, error) { parts := strings.Split(arn, ":") if len(parts) < 5 { - return "", fmt.Errorf("Unable to parse ID from invalid ARN: %q", arn) + return "", "", fmt.Errorf("Unable to parse ID from invalid ARN: %q", arn) } - return parts[4], nil + return parts[1], parts[4], nil } // This function is responsible for reading credentials from the diff --git a/builtin/providers/aws/auth_helpers_test.go b/builtin/providers/aws/auth_helpers_test.go index 5600e9245..fb7dd6884 100644 --- a/builtin/providers/aws/auth_helpers_test.go +++ b/builtin/providers/aws/auth_helpers_test.go @@ -21,7 +21,7 @@ import ( "github.com/aws/aws-sdk-go/service/sts" ) -func TestAWSGetAccountId_shouldBeValid_fromEC2Role(t *testing.T) { +func TestAWSGetAccountInfo_shouldBeValid_fromEC2Role(t *testing.T) { resetEnv := unsetEnv(t) defer resetEnv() // capture the test server's close method, to call after the test returns @@ -32,18 +32,23 @@ func TestAWSGetAccountId_shouldBeValid_fromEC2Role(t *testing.T) { ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) defer ts() - id, err := GetAccountId(iamConn, stsConn, ec2rolecreds.ProviderName) + part, id, err := GetAccountInfo(iamConn, stsConn, ec2rolecreds.ProviderName) if err != nil { t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err) } + expectedPart := "aws" + if part != expectedPart { + t.Fatalf("Expected partition: %s, given: %s", expectedPart, part) + } + expectedAccountId := "123456789013" if id != expectedAccountId { t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) } } -func TestAWSGetAccountId_shouldBeValid_EC2RoleHasPriority(t *testing.T) { +func TestAWSGetAccountInfo_shouldBeValid_EC2RoleHasPriority(t *testing.T) { resetEnv := unsetEnv(t) defer resetEnv() // capture the test server's close method, to call after the test returns @@ -59,18 +64,23 @@ func TestAWSGetAccountId_shouldBeValid_EC2RoleHasPriority(t *testing.T) { ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) defer ts() - id, err := GetAccountId(iamConn, stsConn, ec2rolecreds.ProviderName) + part, id, err := GetAccountInfo(iamConn, stsConn, ec2rolecreds.ProviderName) if err != nil { t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err) } + expectedPart := "aws" + if part != expectedPart { + t.Fatalf("Expected partition: %s, given: %s", expectedPart, part) + } + expectedAccountId := "123456789013" if id != expectedAccountId { t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) } } -func TestAWSGetAccountId_shouldBeValid_fromIamUser(t *testing.T) { +func TestAWSGetAccountInfo_shouldBeValid_fromIamUser(t *testing.T) { iamEndpoints := []*iamEndpoint{ { Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, @@ -81,18 +91,23 @@ func TestAWSGetAccountId_shouldBeValid_fromIamUser(t *testing.T) { ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) defer ts() - id, err := GetAccountId(iamConn, stsConn, "") + part, id, err := GetAccountInfo(iamConn, stsConn, "") if err != nil { t.Fatalf("Getting account ID via GetUser failed: %s", err) } + expectedPart := "aws" + if part != expectedPart { + t.Fatalf("Expected partition: %s, given: %s", expectedPart, part) + } + expectedAccountId := "123456789012" if id != expectedAccountId { t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) } } -func TestAWSGetAccountId_shouldBeValid_fromGetCallerIdentity(t *testing.T) { +func TestAWSGetAccountInfo_shouldBeValid_fromGetCallerIdentity(t *testing.T) { iamEndpoints := []*iamEndpoint{ { Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, @@ -106,18 +121,23 @@ func TestAWSGetAccountId_shouldBeValid_fromGetCallerIdentity(t *testing.T) { ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) defer ts() - id, err := GetAccountId(iamConn, stsConn, "") + part, id, err := GetAccountInfo(iamConn, stsConn, "") if err != nil { t.Fatalf("Getting account ID via GetUser failed: %s", err) } + expectedPart := "aws" + if part != expectedPart { + t.Fatalf("Expected partition: %s, given: %s", expectedPart, part) + } + expectedAccountId := "123456789012" if id != expectedAccountId { t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) } } -func TestAWSGetAccountId_shouldBeValid_fromIamListRoles(t *testing.T) { +func TestAWSGetAccountInfo_shouldBeValid_fromIamListRoles(t *testing.T) { iamEndpoints := []*iamEndpoint{ { Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, @@ -135,18 +155,23 @@ func TestAWSGetAccountId_shouldBeValid_fromIamListRoles(t *testing.T) { ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) defer ts() - id, err := GetAccountId(iamConn, stsConn, "") + part, id, err := GetAccountInfo(iamConn, stsConn, "") if err != nil { t.Fatalf("Getting account ID via ListRoles failed: %s", err) } + expectedPart := "aws" + if part != expectedPart { + t.Fatalf("Expected partition: %s, given: %s", expectedPart, part) + } + expectedAccountId := "123456789012" if id != expectedAccountId { t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) } } -func TestAWSGetAccountId_shouldBeValid_federatedRole(t *testing.T) { +func TestAWSGetAccountInfo_shouldBeValid_federatedRole(t *testing.T) { iamEndpoints := []*iamEndpoint{ { Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, @@ -160,18 +185,23 @@ func TestAWSGetAccountId_shouldBeValid_federatedRole(t *testing.T) { ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) defer ts() - id, err := GetAccountId(iamConn, stsConn, "") + part, id, err := GetAccountInfo(iamConn, stsConn, "") if err != nil { t.Fatalf("Getting account ID via ListRoles failed: %s", err) } + expectedPart := "aws" + if part != expectedPart { + t.Fatalf("Expected partition: %s, given: %s", expectedPart, part) + } + expectedAccountId := "123456789012" if id != expectedAccountId { t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) } } -func TestAWSGetAccountId_shouldError_unauthorizedFromIam(t *testing.T) { +func TestAWSGetAccountInfo_shouldError_unauthorizedFromIam(t *testing.T) { iamEndpoints := []*iamEndpoint{ { Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, @@ -185,29 +215,37 @@ func TestAWSGetAccountId_shouldError_unauthorizedFromIam(t *testing.T) { ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) defer ts() - id, err := GetAccountId(iamConn, stsConn, "") + part, id, err := GetAccountInfo(iamConn, stsConn, "") if err == nil { t.Fatal("Expected error when getting account ID") } + if part != "" { + t.Fatalf("Expected no partition, given: %s", part) + } + if id != "" { t.Fatalf("Expected no account ID, given: %s", id) } } -func TestAWSParseAccountIdFromArn(t *testing.T) { +func TestAWSParseAccountInfoFromArn(t *testing.T) { validArn := "arn:aws:iam::101636750127:instance-profile/aws-elasticbeanstalk-ec2-role" + expectedPart := "aws" expectedId := "101636750127" - id, err := parseAccountIdFromArn(validArn) + part, id, err := parseAccountInfoFromArn(validArn) if err != nil { t.Fatalf("Expected no error when parsing valid ARN: %s", err) } + if part != expectedPart { + t.Fatalf("Parsed part doesn't match with expected (%q != %q)", part, expectedPart) + } if id != expectedId { t.Fatalf("Parsed id doesn't match with expected (%q != %q)", id, expectedId) } invalidArn := "blablah" - id, err = parseAccountIdFromArn(invalidArn) + part, id, err = parseAccountInfoFromArn(invalidArn) if err == nil { t.Fatalf("Expected error when parsing invalid ARN (%q)", invalidArn) } diff --git a/builtin/providers/aws/config.go b/builtin/providers/aws/config.go index 52c18f094..39c7c065a 100644 --- a/builtin/providers/aws/config.go +++ b/builtin/providers/aws/config.go @@ -118,6 +118,7 @@ type AWSClient struct { stsconn *sts.STS redshiftconn *redshift.Redshift r53conn *route53.Route53 + partition string accountid string region string rdsconn *rds.RDS @@ -226,8 +227,9 @@ func (c *Config) Client() (interface{}, error) { } if !c.SkipRequestingAccountId { - accountId, err := GetAccountId(client.iamconn, client.stsconn, cp.ProviderName) + partition, accountId, err := GetAccountInfo(client.iamconn, client.stsconn, cp.ProviderName) if err == nil { + client.partition = partition client.accountid = accountId } } diff --git a/builtin/providers/aws/resource_aws_db_instance.go b/builtin/providers/aws/resource_aws_db_instance.go index 6f1606f56..dd17cfe76 100644 --- a/builtin/providers/aws/resource_aws_db_instance.go +++ b/builtin/providers/aws/resource_aws_db_instance.go @@ -693,7 +693,7 @@ func resourceAwsDbInstanceRead(d *schema.ResourceData, meta interface{}) error { // list tags for resource // set tags conn := meta.(*AWSClient).rdsconn - arn, err := buildRDSARN(d.Id(), meta.(*AWSClient).accountid, meta.(*AWSClient).region) + arn, err := buildRDSARN(d.Id(), meta.(*AWSClient).partition, meta.(*AWSClient).accountid, meta.(*AWSClient).region) if err != nil { name := "" if v.DBName != nil && *v.DBName != "" { @@ -976,7 +976,7 @@ func resourceAwsDbInstanceUpdate(d *schema.ResourceData, meta interface{}) error } } - if arn, err := buildRDSARN(d.Id(), meta.(*AWSClient).accountid, meta.(*AWSClient).region); err == nil { + if arn, err := buildRDSARN(d.Id(), meta.(*AWSClient).partition, meta.(*AWSClient).accountid, meta.(*AWSClient).region); err == nil { if err := setTagsRDS(conn, d, arn); err != nil { return err } else { @@ -1052,11 +1052,10 @@ func resourceAwsDbInstanceStateRefreshFunc( } } -func buildRDSARN(identifier, accountid, region string) (string, error) { +func buildRDSARN(identifier, partition, accountid, region string) (string, error) { if accountid == "" { return "", fmt.Errorf("Unable to construct RDS ARN because of missing AWS Account ID") } - arn := fmt.Sprintf("arn:aws:rds:%s:%s:db:%s", region, accountid, identifier) + arn := fmt.Sprintf("arn:%s:rds:%s:%s:db:%s", partition, region, accountid, identifier) return arn, nil - } diff --git a/builtin/providers/aws/resource_aws_db_instance_test.go b/builtin/providers/aws/resource_aws_db_instance_test.go index d1fb05f96..91617faf7 100644 --- a/builtin/providers/aws/resource_aws_db_instance_test.go +++ b/builtin/providers/aws/resource_aws_db_instance_test.go @@ -350,7 +350,7 @@ func testAccCheckAWSDBInstanceSnapshot(s *terraform.State) error { } } else { // snapshot was found, // verify we have the tags copied to the snapshot - instanceARN, err := buildRDSARN(snapshot_identifier, testAccProvider.Meta().(*AWSClient).accountid, testAccProvider.Meta().(*AWSClient).region) + instanceARN, err := buildRDSARN(snapshot_identifier, testAccProvider.Meta().(*AWSClient).partition, testAccProvider.Meta().(*AWSClient).accountid, testAccProvider.Meta().(*AWSClient).region) // tags have a different ARN, just swapping :db: for :snapshot: tagsARN := strings.Replace(instanceARN, ":db:", ":snapshot:", 1) if err != nil { diff --git a/builtin/providers/aws/resource_aws_rds_cluster.go b/builtin/providers/aws/resource_aws_rds_cluster.go index f10d2dd41..7351fffb0 100644 --- a/builtin/providers/aws/resource_aws_rds_cluster.go +++ b/builtin/providers/aws/resource_aws_rds_cluster.go @@ -467,7 +467,7 @@ func resourceAwsRDSClusterRead(d *schema.ResourceData, meta interface{}) error { } // Fetch and save tags - arn, err := buildRDSClusterARN(d.Id(), meta.(*AWSClient).accountid, meta.(*AWSClient).region) + arn, err := buildRDSClusterARN(d.Id(), meta.(*AWSClient).partition, meta.(*AWSClient).accountid, meta.(*AWSClient).region) if err != nil { log.Printf("[DEBUG] Error building ARN for RDS Cluster (%s), not setting Tags", *dbc.DBClusterIdentifier) } else { @@ -536,7 +536,7 @@ func resourceAwsRDSClusterUpdate(d *schema.ResourceData, meta interface{}) error } } - if arn, err := buildRDSClusterARN(d.Id(), meta.(*AWSClient).accountid, meta.(*AWSClient).region); err == nil { + if arn, err := buildRDSClusterARN(d.Id(), meta.(*AWSClient).partition, meta.(*AWSClient).accountid, meta.(*AWSClient).region); err == nil { if err := setTagsRDS(conn, d, arn); err != nil { return err } else { @@ -625,12 +625,12 @@ func resourceAwsRDSClusterStateRefreshFunc( } } -func buildRDSClusterARN(identifier, accountid, region string) (string, error) { +func buildRDSClusterARN(identifier, partition, accountid, region string) (string, error) { if accountid == "" { return "", fmt.Errorf("Unable to construct RDS Cluster ARN because of missing AWS Account ID") } - arn := fmt.Sprintf("arn:aws:rds:%s:%s:cluster:%s", region, accountid, identifier) + arn := fmt.Sprintf("arn:%s:rds:%s:%s:cluster:%s", partition, region, accountid, identifier) return arn, nil } diff --git a/builtin/providers/aws/resource_aws_rds_cluster_instance.go b/builtin/providers/aws/resource_aws_rds_cluster_instance.go index 08f64ddd3..027e1b7c0 100644 --- a/builtin/providers/aws/resource_aws_rds_cluster_instance.go +++ b/builtin/providers/aws/resource_aws_rds_cluster_instance.go @@ -245,7 +245,7 @@ func resourceAwsRDSClusterInstanceRead(d *schema.ResourceData, meta interface{}) } // Fetch and save tags - arn, err := buildRDSARN(d.Id(), meta.(*AWSClient).accountid, meta.(*AWSClient).region) + arn, err := buildRDSARN(d.Id(), meta.(*AWSClient).partition, meta.(*AWSClient).accountid, meta.(*AWSClient).region) if err != nil { log.Printf("[DEBUG] Error building ARN for RDS Cluster Instance (%s), not setting Tags", *db.DBInstanceIdentifier) } else { @@ -322,7 +322,7 @@ func resourceAwsRDSClusterInstanceUpdate(d *schema.ResourceData, meta interface{ } - if arn, err := buildRDSARN(d.Id(), meta.(*AWSClient).accountid, meta.(*AWSClient).region); err == nil { + if arn, err := buildRDSARN(d.Id(), meta.(*AWSClient).partition, meta.(*AWSClient).accountid, meta.(*AWSClient).region); err == nil { if err := setTagsRDS(conn, d, arn); err != nil { return err }