From a23bcf2ec9231cf9e60388eac6eda179de116bad Mon Sep 17 00:00:00 2001 From: Kraig Amador Date: Wed, 27 Apr 2016 18:49:42 -0700 Subject: [PATCH] Added accountid to AWSClient and set it early in the initialization phase. We use iam.GetUser(nil) scattered around to get the account id, but this isn't the most reliable method. GetAccountId now uses one more method (sts:GetCallerIdentity) to get the account id, this works with federated users. --- builtin/providers/aws/auth_helpers.go | 17 ++++- builtin/providers/aws/auth_helpers_test.go | 87 +++++++++++++++++----- builtin/providers/aws/config.go | 25 ++++--- 3 files changed, 98 insertions(+), 31 deletions(-) diff --git a/builtin/providers/aws/auth_helpers.go b/builtin/providers/aws/auth_helpers.go index 914c7e971..552a4234f 100644 --- a/builtin/providers/aws/auth_helpers.go +++ b/builtin/providers/aws/auth_helpers.go @@ -14,10 +14,11 @@ import ( "github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/service/sts" "github.com/hashicorp/go-cleanhttp" ) -func GetAccountId(iamconn *iam.IAM, authProviderName string) (string, error) { +func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) { // If we have creds from instance profile, we can use metadata API if authProviderName == ec2rolecreds.ProviderName { log.Println("[DEBUG] Trying to get account ID via AWS Metadata API") @@ -42,16 +43,24 @@ func GetAccountId(iamconn *iam.IAM, authProviderName string) (string, error) { return parseAccountIdFromArn(*outUser.User.Arn) } - // Then try IAM ListRoles awsErr, ok := err.(awserr.Error) // AccessDenied and ValidationError can be raised // if credentials belong to federated profile, so we ignore these if !ok || (awsErr.Code() != "AccessDenied" && awsErr.Code() != "ValidationError") { return "", fmt.Errorf("Failed getting account ID via 'iam:GetUser': %s", err) } - log.Printf("[DEBUG] Getting account ID via iam:GetUser failed: %s", err) - log.Println("[DEBUG] Trying to get account ID via iam:ListRoles instead") + + // Then try STS GetCallerIdentity + log.Println("[DEBUG] Trying to get account ID via sts:GetCallerIdentity") + outCallerIdentity, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + if err == nil { + return *outCallerIdentity.Account, nil + } + log.Printf("[DEBUG] Getting account ID via sts:GetCallerIdentity failed: %s", err) + + // Then try IAM ListRoles + log.Println("[DEBUG] Trying to get account ID via iam:ListRoles") outRoles, err := iamconn.ListRoles(&iam.ListRolesInput{ MaxItems: aws.Int64(int64(1)), }) diff --git a/builtin/providers/aws/auth_helpers_test.go b/builtin/providers/aws/auth_helpers_test.go index a5fcf8f16..a9de0fcc6 100644 --- a/builtin/providers/aws/auth_helpers_test.go +++ b/builtin/providers/aws/auth_helpers_test.go @@ -18,6 +18,7 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/service/sts" ) func TestAWSGetAccountId_shouldBeValid_fromEC2Role(t *testing.T) { @@ -28,10 +29,10 @@ func TestAWSGetAccountId_shouldBeValid_fromEC2Role(t *testing.T) { defer awsTs() iamEndpoints := []*iamEndpoint{} - ts, iamConn := getMockedAwsIamApi(iamEndpoints) + ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) defer ts() - id, err := GetAccountId(iamConn, ec2rolecreds.ProviderName) + id, err := GetAccountId(iamConn, stsConn, ec2rolecreds.ProviderName) if err != nil { t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err) } @@ -55,10 +56,10 @@ func TestAWSGetAccountId_shouldBeValid_EC2RoleHasPriority(t *testing.T) { Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"}, }, } - ts, iamConn := getMockedAwsIamApi(iamEndpoints) + ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) defer ts() - id, err := GetAccountId(iamConn, ec2rolecreds.ProviderName) + id, err := GetAccountId(iamConn, stsConn, ec2rolecreds.ProviderName) if err != nil { t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err) } @@ -76,10 +77,36 @@ func TestAWSGetAccountId_shouldBeValid_fromIamUser(t *testing.T) { Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"}, }, } - ts, iamConn := getMockedAwsIamApi(iamEndpoints) + + ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) defer ts() - id, err := GetAccountId(iamConn, "") + id, err := GetAccountId(iamConn, stsConn, "") + if err != nil { + t.Fatalf("Getting account ID via GetUser failed: %s", err) + } + + expectedAccountId := "123456789012" + if id != expectedAccountId { + t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) + } +} + +func TestAWSGetAccountId_shouldBeValid_fromGetCallerIdentity(t *testing.T) { + iamEndpoints := []*iamEndpoint{ + &iamEndpoint{ + Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, + Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"}, + }, + &iamEndpoint{ + Request: &iamRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"}, + Response: &iamResponse{200, stsResponse_GetCallerIdentity_valid, "text/xml"}, + }, + } + ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) + defer ts() + + id, err := GetAccountId(iamConn, stsConn, "") if err != nil { t.Fatalf("Getting account ID via GetUser failed: %s", err) } @@ -96,15 +123,19 @@ func TestAWSGetAccountId_shouldBeValid_fromIamListRoles(t *testing.T) { Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"}, }, + &iamEndpoint{ + Request: &iamRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"}, + Response: &iamResponse{403, stsResponse_GetCallerIdentity_unauthorized, "text/xml"}, + }, &iamEndpoint{ Request: &iamRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"}, Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"}, }, } - ts, iamConn := getMockedAwsIamApi(iamEndpoints) + ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) defer ts() - id, err := GetAccountId(iamConn, "") + id, err := GetAccountId(iamConn, stsConn, "") if err != nil { t.Fatalf("Getting account ID via ListRoles failed: %s", err) } @@ -126,10 +157,10 @@ func TestAWSGetAccountId_shouldBeValid_federatedRole(t *testing.T) { Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"}, }, } - ts, iamConn := getMockedAwsIamApi(iamEndpoints) + ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) defer ts() - id, err := GetAccountId(iamConn, "") + id, err := GetAccountId(iamConn, stsConn, "") if err != nil { t.Fatalf("Getting account ID via ListRoles failed: %s", err) } @@ -151,10 +182,10 @@ func TestAWSGetAccountId_shouldError_unauthorizedFromIam(t *testing.T) { Response: &iamResponse{403, iamResponse_ListRoles_unauthorized, "text/xml"}, }, } - ts, iamConn := getMockedAwsIamApi(iamEndpoints) + ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) defer ts() - id, err := GetAccountId(iamConn, "") + id, err := GetAccountId(iamConn, stsConn, "") if err == nil { t.Fatal("Expected error when getting account ID") } @@ -586,15 +617,15 @@ func invalidAwsEnv(t *testing.T) func() { return ts.Close } -// getMockedAwsIamApi establishes a httptest server to simulate behaviour -// of a real AWS' IAM server -func getMockedAwsIamApi(endpoints []*iamEndpoint) (func(), *iam.IAM) { +// getMockedAwsIamStsApi establishes a httptest server to simulate behaviour +// of a real AWS' IAM & STS server +func getMockedAwsIamStsApi(endpoints []*iamEndpoint) (func(), *iam.IAM, *sts.STS) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { buf := new(bytes.Buffer) buf.ReadFrom(r.Body) requestBody := buf.String() - log.Printf("[DEBUG] Received IAM API %q request to %q: %s", + log.Printf("[DEBUG] Received API %q request to %q: %s", r.Method, r.RequestURI, requestBody) for _, e := range endpoints { @@ -624,8 +655,8 @@ func getMockedAwsIamApi(endpoints []*iamEndpoint) (func(), *iam.IAM) { CredentialsChainVerboseErrors: aws.Bool(true), }) iamConn := iam.New(sess) - - return ts.Close, iamConn + stsConn := sts.New(sess) + return ts.Close, iamConn, stsConn } func getEnv() *currentEnv { @@ -718,6 +749,26 @@ const iamResponse_GetUser_unauthorized = ` + + arn:aws:iam::123456789012:user/Alice + AKIAI44QH8DHBEXAMPLE + 123456789012 + + + 01234567-89ab-cdef-0123-456789abcdef + +` + +const stsResponse_GetCallerIdentity_unauthorized = ` + + Sender + AccessDenied + User: arn:aws:iam::123456789012:user/Bob is not authorized to perform: sts:GetCallerIdentity + + 01234567-89ab-cdef-0123-456789abcdef +` + const iamResponse_GetUser_federatedFailure = ` Sender diff --git a/builtin/providers/aws/config.go b/builtin/providers/aws/config.go index 82a82e016..0db0ec0f9 100644 --- a/builtin/providers/aws/config.go +++ b/builtin/providers/aws/config.go @@ -50,6 +50,7 @@ import ( "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/sns" "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sts" ) type Config struct { @@ -92,8 +93,10 @@ type AWSClient struct { s3conn *s3.S3 sqsconn *sqs.SQS snsconn *sns.SNS + stsconn *sts.STS redshiftconn *redshift.Redshift r53conn *route53.Route53 + accountid string region string rdsconn *rds.RDS iamconn *iam.IAM @@ -172,6 +175,9 @@ func (c *Config) Client() (interface{}, error) { awsIamSess := sess.Copy(&aws.Config{Endpoint: aws.String(c.IamEndpoint)}) client.iamconn = iam.New(awsIamSess) + log.Println("[INFO] Initializing STS connection") + client.stsconn = sts.New(sess) + err = c.ValidateCredentials(client.iamconn) if err != nil { errs = append(errs, err) @@ -185,6 +191,11 @@ func (c *Config) Client() (interface{}, error) { // http://docs.aws.amazon.com/general/latest/gr/sigv4_changes.html usEast1Sess := sess.Copy(&aws.Config{Region: aws.String("us-east-1")}) + accountId, err := GetAccountId(client.iamconn, client.stsconn, cp.ProviderName) + if err == nil { + client.accountid = accountId + } + log.Println("[INFO] Initializing DynamoDB connection") dynamoSess := sess.Copy(&aws.Config{Endpoint: aws.String(c.DynamoDBEndpoint)}) client.dynamodbconn = dynamodb.New(dynamoSess) @@ -215,7 +226,7 @@ func (c *Config) Client() (interface{}, error) { log.Println("[INFO] Initializing Elastic Beanstalk Connection") client.elasticbeanstalkconn = elasticbeanstalk.New(sess) - authErr := c.ValidateAccountId(client.iamconn, cp.ProviderName) + authErr := c.ValidateAccountId(client.accountid) if authErr != nil { errs = append(errs, authErr) } @@ -338,20 +349,16 @@ func (c *Config) ValidateCredentials(iamconn *iam.IAM) error { // ValidateAccountId returns a context-specific error if the configured account // id is explicitly forbidden or not authorised; and nil if it is authorised. -func (c *Config) ValidateAccountId(iamconn *iam.IAM, authProviderName string) error { +func (c *Config) ValidateAccountId(accountId string) error { if c.AllowedAccountIds == nil && c.ForbiddenAccountIds == nil { return nil } log.Printf("[INFO] Validating account ID") - account_id, err := GetAccountId(iamconn, authProviderName) - if err != nil { - return err - } if c.ForbiddenAccountIds != nil { for _, id := range c.ForbiddenAccountIds { - if id == account_id { + if id == accountId { return fmt.Errorf("Forbidden account ID (%s)", id) } } @@ -359,11 +366,11 @@ func (c *Config) ValidateAccountId(iamconn *iam.IAM, authProviderName string) er if c.AllowedAccountIds != nil { for _, id := range c.AllowedAccountIds { - if id == account_id { + if id == accountId { return nil } } - return fmt.Errorf("Account ID not allowed (%s)", account_id) + return fmt.Errorf("Account ID not allowed (%s)", accountId) } return nil