Added accountid to AWSClient and set it early in the initialization phase. We use iam.GetUser(nil) scattered around to get the account id, but this isn't the most reliable method. GetAccountId now uses one more method (sts:GetCallerIdentity) to get the account id, this works with federated users.

This commit is contained in:
Kraig Amador 2016-04-27 18:49:42 -07:00
parent e04e87361f
commit a23bcf2ec9
3 changed files with 98 additions and 31 deletions

View File

@ -14,10 +14,11 @@ import (
"github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-cleanhttp"
) )
func GetAccountId(iamconn *iam.IAM, authProviderName string) (string, error) { func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) {
// If we have creds from instance profile, we can use metadata API // If we have creds from instance profile, we can use metadata API
if authProviderName == ec2rolecreds.ProviderName { if authProviderName == ec2rolecreds.ProviderName {
log.Println("[DEBUG] Trying to get account ID via AWS Metadata API") log.Println("[DEBUG] Trying to get account ID via AWS Metadata API")
@ -42,16 +43,24 @@ func GetAccountId(iamconn *iam.IAM, authProviderName string) (string, error) {
return parseAccountIdFromArn(*outUser.User.Arn) return parseAccountIdFromArn(*outUser.User.Arn)
} }
// Then try IAM ListRoles
awsErr, ok := err.(awserr.Error) awsErr, ok := err.(awserr.Error)
// AccessDenied and ValidationError can be raised // AccessDenied and ValidationError can be raised
// if credentials belong to federated profile, so we ignore these // if credentials belong to federated profile, so we ignore these
if !ok || (awsErr.Code() != "AccessDenied" && awsErr.Code() != "ValidationError") { if !ok || (awsErr.Code() != "AccessDenied" && awsErr.Code() != "ValidationError") {
return "", fmt.Errorf("Failed getting account ID via 'iam:GetUser': %s", err) return "", fmt.Errorf("Failed getting account ID via 'iam:GetUser': %s", err)
} }
log.Printf("[DEBUG] Getting account ID via iam:GetUser failed: %s", err) log.Printf("[DEBUG] Getting account ID via iam:GetUser failed: %s", err)
log.Println("[DEBUG] Trying to get account ID via iam:ListRoles instead")
// Then try STS GetCallerIdentity
log.Println("[DEBUG] Trying to get account ID via sts:GetCallerIdentity")
outCallerIdentity, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err == nil {
return *outCallerIdentity.Account, nil
}
log.Printf("[DEBUG] Getting account ID via sts:GetCallerIdentity failed: %s", err)
// Then try IAM ListRoles
log.Println("[DEBUG] Trying to get account ID via iam:ListRoles")
outRoles, err := iamconn.ListRoles(&iam.ListRolesInput{ outRoles, err := iamconn.ListRoles(&iam.ListRolesInput{
MaxItems: aws.Int64(int64(1)), MaxItems: aws.Int64(int64(1)),
}) })

View File

@ -18,6 +18,7 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
) )
func TestAWSGetAccountId_shouldBeValid_fromEC2Role(t *testing.T) { func TestAWSGetAccountId_shouldBeValid_fromEC2Role(t *testing.T) {
@ -28,10 +29,10 @@ func TestAWSGetAccountId_shouldBeValid_fromEC2Role(t *testing.T) {
defer awsTs() defer awsTs()
iamEndpoints := []*iamEndpoint{} iamEndpoints := []*iamEndpoint{}
ts, iamConn := getMockedAwsIamApi(iamEndpoints) ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts() defer ts()
id, err := GetAccountId(iamConn, ec2rolecreds.ProviderName) id, err := GetAccountId(iamConn, stsConn, ec2rolecreds.ProviderName)
if err != nil { if err != nil {
t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err) t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err)
} }
@ -55,10 +56,10 @@ func TestAWSGetAccountId_shouldBeValid_EC2RoleHasPriority(t *testing.T) {
Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"}, Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"},
}, },
} }
ts, iamConn := getMockedAwsIamApi(iamEndpoints) ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts() defer ts()
id, err := GetAccountId(iamConn, ec2rolecreds.ProviderName) id, err := GetAccountId(iamConn, stsConn, ec2rolecreds.ProviderName)
if err != nil { if err != nil {
t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err) t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err)
} }
@ -76,10 +77,36 @@ func TestAWSGetAccountId_shouldBeValid_fromIamUser(t *testing.T) {
Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"}, Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"},
}, },
} }
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts() defer ts()
id, err := GetAccountId(iamConn, "") id, err := GetAccountId(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via GetUser failed: %s", err)
}
expectedAccountId := "123456789012"
if id != expectedAccountId {
t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id)
}
}
func TestAWSGetAccountId_shouldBeValid_fromGetCallerIdentity(t *testing.T) {
iamEndpoints := []*iamEndpoint{
&iamEndpoint{
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"},
},
&iamEndpoint{
Request: &iamRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"},
Response: &iamResponse{200, stsResponse_GetCallerIdentity_valid, "text/xml"},
},
}
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()
id, err := GetAccountId(iamConn, stsConn, "")
if err != nil { if err != nil {
t.Fatalf("Getting account ID via GetUser failed: %s", err) t.Fatalf("Getting account ID via GetUser failed: %s", err)
} }
@ -96,15 +123,19 @@ func TestAWSGetAccountId_shouldBeValid_fromIamListRoles(t *testing.T) {
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"}, Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"},
}, },
&iamEndpoint{
Request: &iamRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"},
Response: &iamResponse{403, stsResponse_GetCallerIdentity_unauthorized, "text/xml"},
},
&iamEndpoint{ &iamEndpoint{
Request: &iamRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"}, Request: &iamRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"},
Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"}, Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"},
}, },
} }
ts, iamConn := getMockedAwsIamApi(iamEndpoints) ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts() defer ts()
id, err := GetAccountId(iamConn, "") id, err := GetAccountId(iamConn, stsConn, "")
if err != nil { if err != nil {
t.Fatalf("Getting account ID via ListRoles failed: %s", err) t.Fatalf("Getting account ID via ListRoles failed: %s", err)
} }
@ -126,10 +157,10 @@ func TestAWSGetAccountId_shouldBeValid_federatedRole(t *testing.T) {
Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"}, Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"},
}, },
} }
ts, iamConn := getMockedAwsIamApi(iamEndpoints) ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts() defer ts()
id, err := GetAccountId(iamConn, "") id, err := GetAccountId(iamConn, stsConn, "")
if err != nil { if err != nil {
t.Fatalf("Getting account ID via ListRoles failed: %s", err) t.Fatalf("Getting account ID via ListRoles failed: %s", err)
} }
@ -151,10 +182,10 @@ func TestAWSGetAccountId_shouldError_unauthorizedFromIam(t *testing.T) {
Response: &iamResponse{403, iamResponse_ListRoles_unauthorized, "text/xml"}, Response: &iamResponse{403, iamResponse_ListRoles_unauthorized, "text/xml"},
}, },
} }
ts, iamConn := getMockedAwsIamApi(iamEndpoints) ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts() defer ts()
id, err := GetAccountId(iamConn, "") id, err := GetAccountId(iamConn, stsConn, "")
if err == nil { if err == nil {
t.Fatal("Expected error when getting account ID") t.Fatal("Expected error when getting account ID")
} }
@ -586,15 +617,15 @@ func invalidAwsEnv(t *testing.T) func() {
return ts.Close return ts.Close
} }
// getMockedAwsIamApi establishes a httptest server to simulate behaviour // getMockedAwsIamStsApi establishes a httptest server to simulate behaviour
// of a real AWS' IAM server // of a real AWS' IAM & STS server
func getMockedAwsIamApi(endpoints []*iamEndpoint) (func(), *iam.IAM) { func getMockedAwsIamStsApi(endpoints []*iamEndpoint) (func(), *iam.IAM, *sts.STS) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
buf.ReadFrom(r.Body) buf.ReadFrom(r.Body)
requestBody := buf.String() requestBody := buf.String()
log.Printf("[DEBUG] Received IAM API %q request to %q: %s", log.Printf("[DEBUG] Received API %q request to %q: %s",
r.Method, r.RequestURI, requestBody) r.Method, r.RequestURI, requestBody)
for _, e := range endpoints { for _, e := range endpoints {
@ -624,8 +655,8 @@ func getMockedAwsIamApi(endpoints []*iamEndpoint) (func(), *iam.IAM) {
CredentialsChainVerboseErrors: aws.Bool(true), CredentialsChainVerboseErrors: aws.Bool(true),
}) })
iamConn := iam.New(sess) iamConn := iam.New(sess)
stsConn := sts.New(sess)
return ts.Close, iamConn return ts.Close, iamConn, stsConn
} }
func getEnv() *currentEnv { func getEnv() *currentEnv {
@ -718,6 +749,26 @@ const iamResponse_GetUser_unauthorized = `<ErrorResponse xmlns="https://iam.amaz
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId> <RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ErrorResponse>` </ErrorResponse>`
const stsResponse_GetCallerIdentity_valid = `<GetCallerIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<GetCallerIdentityResult>
<Arn>arn:aws:iam::123456789012:user/Alice</Arn>
<UserId>AKIAI44QH8DHBEXAMPLE</UserId>
<Account>123456789012</Account>
</GetCallerIdentityResult>
<ResponseMetadata>
<RequestId>01234567-89ab-cdef-0123-456789abcdef</RequestId>
</ResponseMetadata>
</GetCallerIdentityResponse>`
const stsResponse_GetCallerIdentity_unauthorized = `<ErrorResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<Error>
<Type>Sender</Type>
<Code>AccessDenied</Code>
<Message>User: arn:aws:iam::123456789012:user/Bob is not authorized to perform: sts:GetCallerIdentity</Message>
</Error>
<RequestId>01234567-89ab-cdef-0123-456789abcdef</RequestId>
</ErrorResponse>`
const iamResponse_GetUser_federatedFailure = `<ErrorResponse xmlns="https://iam.amazonaws.com/doc/2010-05-08/"> const iamResponse_GetUser_federatedFailure = `<ErrorResponse xmlns="https://iam.amazonaws.com/doc/2010-05-08/">
<Error> <Error>
<Type>Sender</Type> <Type>Sender</Type>

View File

@ -50,6 +50,7 @@ import (
"github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/sns" "github.com/aws/aws-sdk-go/service/sns"
"github.com/aws/aws-sdk-go/service/sqs" "github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sts"
) )
type Config struct { type Config struct {
@ -92,8 +93,10 @@ type AWSClient struct {
s3conn *s3.S3 s3conn *s3.S3
sqsconn *sqs.SQS sqsconn *sqs.SQS
snsconn *sns.SNS snsconn *sns.SNS
stsconn *sts.STS
redshiftconn *redshift.Redshift redshiftconn *redshift.Redshift
r53conn *route53.Route53 r53conn *route53.Route53
accountid string
region string region string
rdsconn *rds.RDS rdsconn *rds.RDS
iamconn *iam.IAM iamconn *iam.IAM
@ -172,6 +175,9 @@ func (c *Config) Client() (interface{}, error) {
awsIamSess := sess.Copy(&aws.Config{Endpoint: aws.String(c.IamEndpoint)}) awsIamSess := sess.Copy(&aws.Config{Endpoint: aws.String(c.IamEndpoint)})
client.iamconn = iam.New(awsIamSess) client.iamconn = iam.New(awsIamSess)
log.Println("[INFO] Initializing STS connection")
client.stsconn = sts.New(sess)
err = c.ValidateCredentials(client.iamconn) err = c.ValidateCredentials(client.iamconn)
if err != nil { if err != nil {
errs = append(errs, err) errs = append(errs, err)
@ -185,6 +191,11 @@ func (c *Config) Client() (interface{}, error) {
// http://docs.aws.amazon.com/general/latest/gr/sigv4_changes.html // http://docs.aws.amazon.com/general/latest/gr/sigv4_changes.html
usEast1Sess := sess.Copy(&aws.Config{Region: aws.String("us-east-1")}) usEast1Sess := sess.Copy(&aws.Config{Region: aws.String("us-east-1")})
accountId, err := GetAccountId(client.iamconn, client.stsconn, cp.ProviderName)
if err == nil {
client.accountid = accountId
}
log.Println("[INFO] Initializing DynamoDB connection") log.Println("[INFO] Initializing DynamoDB connection")
dynamoSess := sess.Copy(&aws.Config{Endpoint: aws.String(c.DynamoDBEndpoint)}) dynamoSess := sess.Copy(&aws.Config{Endpoint: aws.String(c.DynamoDBEndpoint)})
client.dynamodbconn = dynamodb.New(dynamoSess) client.dynamodbconn = dynamodb.New(dynamoSess)
@ -215,7 +226,7 @@ func (c *Config) Client() (interface{}, error) {
log.Println("[INFO] Initializing Elastic Beanstalk Connection") log.Println("[INFO] Initializing Elastic Beanstalk Connection")
client.elasticbeanstalkconn = elasticbeanstalk.New(sess) client.elasticbeanstalkconn = elasticbeanstalk.New(sess)
authErr := c.ValidateAccountId(client.iamconn, cp.ProviderName) authErr := c.ValidateAccountId(client.accountid)
if authErr != nil { if authErr != nil {
errs = append(errs, authErr) errs = append(errs, authErr)
} }
@ -338,20 +349,16 @@ func (c *Config) ValidateCredentials(iamconn *iam.IAM) error {
// ValidateAccountId returns a context-specific error if the configured account // ValidateAccountId returns a context-specific error if the configured account
// id is explicitly forbidden or not authorised; and nil if it is authorised. // id is explicitly forbidden or not authorised; and nil if it is authorised.
func (c *Config) ValidateAccountId(iamconn *iam.IAM, authProviderName string) error { func (c *Config) ValidateAccountId(accountId string) error {
if c.AllowedAccountIds == nil && c.ForbiddenAccountIds == nil { if c.AllowedAccountIds == nil && c.ForbiddenAccountIds == nil {
return nil return nil
} }
log.Printf("[INFO] Validating account ID") log.Printf("[INFO] Validating account ID")
account_id, err := GetAccountId(iamconn, authProviderName)
if err != nil {
return err
}
if c.ForbiddenAccountIds != nil { if c.ForbiddenAccountIds != nil {
for _, id := range c.ForbiddenAccountIds { for _, id := range c.ForbiddenAccountIds {
if id == account_id { if id == accountId {
return fmt.Errorf("Forbidden account ID (%s)", id) return fmt.Errorf("Forbidden account ID (%s)", id)
} }
} }
@ -359,11 +366,11 @@ func (c *Config) ValidateAccountId(iamconn *iam.IAM, authProviderName string) er
if c.AllowedAccountIds != nil { if c.AllowedAccountIds != nil {
for _, id := range c.AllowedAccountIds { for _, id := range c.AllowedAccountIds {
if id == account_id { if id == accountId {
return nil return nil
} }
} }
return fmt.Errorf("Account ID not allowed (%s)", account_id) return fmt.Errorf("Account ID not allowed (%s)", accountId)
} }
return nil return nil