Added accountid to AWSClient and set it early in the initialization phase. We use iam.GetUser(nil) scattered around to get the account id, but this isn't the most reliable method. GetAccountId now uses one more method (sts:GetCallerIdentity) to get the account id, this works with federated users.

This commit is contained in:
Kraig Amador 2016-04-27 18:49:42 -07:00
parent e04e87361f
commit a23bcf2ec9
3 changed files with 98 additions and 31 deletions

View File

@ -14,10 +14,11 @@ import (
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/go-cleanhttp"
)
func GetAccountId(iamconn *iam.IAM, authProviderName string) (string, error) {
func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) {
// If we have creds from instance profile, we can use metadata API
if authProviderName == ec2rolecreds.ProviderName {
log.Println("[DEBUG] Trying to get account ID via AWS Metadata API")
@ -42,16 +43,24 @@ func GetAccountId(iamconn *iam.IAM, authProviderName string) (string, error) {
return parseAccountIdFromArn(*outUser.User.Arn)
}
// Then try IAM ListRoles
awsErr, ok := err.(awserr.Error)
// AccessDenied and ValidationError can be raised
// if credentials belong to federated profile, so we ignore these
if !ok || (awsErr.Code() != "AccessDenied" && awsErr.Code() != "ValidationError") {
return "", fmt.Errorf("Failed getting account ID via 'iam:GetUser': %s", err)
}
log.Printf("[DEBUG] Getting account ID via iam:GetUser failed: %s", err)
log.Println("[DEBUG] Trying to get account ID via iam:ListRoles instead")
// Then try STS GetCallerIdentity
log.Println("[DEBUG] Trying to get account ID via sts:GetCallerIdentity")
outCallerIdentity, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err == nil {
return *outCallerIdentity.Account, nil
}
log.Printf("[DEBUG] Getting account ID via sts:GetCallerIdentity failed: %s", err)
// Then try IAM ListRoles
log.Println("[DEBUG] Trying to get account ID via iam:ListRoles")
outRoles, err := iamconn.ListRoles(&iam.ListRolesInput{
MaxItems: aws.Int64(int64(1)),
})

View File

@ -18,6 +18,7 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
)
func TestAWSGetAccountId_shouldBeValid_fromEC2Role(t *testing.T) {
@ -28,10 +29,10 @@ func TestAWSGetAccountId_shouldBeValid_fromEC2Role(t *testing.T) {
defer awsTs()
iamEndpoints := []*iamEndpoint{}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()
id, err := GetAccountId(iamConn, ec2rolecreds.ProviderName)
id, err := GetAccountId(iamConn, stsConn, ec2rolecreds.ProviderName)
if err != nil {
t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err)
}
@ -55,10 +56,10 @@ func TestAWSGetAccountId_shouldBeValid_EC2RoleHasPriority(t *testing.T) {
Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()
id, err := GetAccountId(iamConn, ec2rolecreds.ProviderName)
id, err := GetAccountId(iamConn, stsConn, ec2rolecreds.ProviderName)
if err != nil {
t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err)
}
@ -76,10 +77,36 @@ func TestAWSGetAccountId_shouldBeValid_fromIamUser(t *testing.T) {
Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()
id, err := GetAccountId(iamConn, "")
id, err := GetAccountId(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via GetUser failed: %s", err)
}
expectedAccountId := "123456789012"
if id != expectedAccountId {
t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id)
}
}
func TestAWSGetAccountId_shouldBeValid_fromGetCallerIdentity(t *testing.T) {
iamEndpoints := []*iamEndpoint{
&iamEndpoint{
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"},
},
&iamEndpoint{
Request: &iamRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"},
Response: &iamResponse{200, stsResponse_GetCallerIdentity_valid, "text/xml"},
},
}
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()
id, err := GetAccountId(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via GetUser failed: %s", err)
}
@ -96,15 +123,19 @@ func TestAWSGetAccountId_shouldBeValid_fromIamListRoles(t *testing.T) {
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"},
},
&iamEndpoint{
Request: &iamRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"},
Response: &iamResponse{403, stsResponse_GetCallerIdentity_unauthorized, "text/xml"},
},
&iamEndpoint{
Request: &iamRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"},
Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()
id, err := GetAccountId(iamConn, "")
id, err := GetAccountId(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via ListRoles failed: %s", err)
}
@ -126,10 +157,10 @@ func TestAWSGetAccountId_shouldBeValid_federatedRole(t *testing.T) {
Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()
id, err := GetAccountId(iamConn, "")
id, err := GetAccountId(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via ListRoles failed: %s", err)
}
@ -151,10 +182,10 @@ func TestAWSGetAccountId_shouldError_unauthorizedFromIam(t *testing.T) {
Response: &iamResponse{403, iamResponse_ListRoles_unauthorized, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()
id, err := GetAccountId(iamConn, "")
id, err := GetAccountId(iamConn, stsConn, "")
if err == nil {
t.Fatal("Expected error when getting account ID")
}
@ -586,15 +617,15 @@ func invalidAwsEnv(t *testing.T) func() {
return ts.Close
}
// getMockedAwsIamApi establishes a httptest server to simulate behaviour
// of a real AWS' IAM server
func getMockedAwsIamApi(endpoints []*iamEndpoint) (func(), *iam.IAM) {
// getMockedAwsIamStsApi establishes a httptest server to simulate behaviour
// of a real AWS' IAM & STS server
func getMockedAwsIamStsApi(endpoints []*iamEndpoint) (func(), *iam.IAM, *sts.STS) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
buf := new(bytes.Buffer)
buf.ReadFrom(r.Body)
requestBody := buf.String()
log.Printf("[DEBUG] Received IAM API %q request to %q: %s",
log.Printf("[DEBUG] Received API %q request to %q: %s",
r.Method, r.RequestURI, requestBody)
for _, e := range endpoints {
@ -624,8 +655,8 @@ func getMockedAwsIamApi(endpoints []*iamEndpoint) (func(), *iam.IAM) {
CredentialsChainVerboseErrors: aws.Bool(true),
})
iamConn := iam.New(sess)
return ts.Close, iamConn
stsConn := sts.New(sess)
return ts.Close, iamConn, stsConn
}
func getEnv() *currentEnv {
@ -718,6 +749,26 @@ const iamResponse_GetUser_unauthorized = `<ErrorResponse xmlns="https://iam.amaz
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ErrorResponse>`
const stsResponse_GetCallerIdentity_valid = `<GetCallerIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<GetCallerIdentityResult>
<Arn>arn:aws:iam::123456789012:user/Alice</Arn>
<UserId>AKIAI44QH8DHBEXAMPLE</UserId>
<Account>123456789012</Account>
</GetCallerIdentityResult>
<ResponseMetadata>
<RequestId>01234567-89ab-cdef-0123-456789abcdef</RequestId>
</ResponseMetadata>
</GetCallerIdentityResponse>`
const stsResponse_GetCallerIdentity_unauthorized = `<ErrorResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<Error>
<Type>Sender</Type>
<Code>AccessDenied</Code>
<Message>User: arn:aws:iam::123456789012:user/Bob is not authorized to perform: sts:GetCallerIdentity</Message>
</Error>
<RequestId>01234567-89ab-cdef-0123-456789abcdef</RequestId>
</ErrorResponse>`
const iamResponse_GetUser_federatedFailure = `<ErrorResponse xmlns="https://iam.amazonaws.com/doc/2010-05-08/">
<Error>
<Type>Sender</Type>

View File

@ -50,6 +50,7 @@ import (
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/sns"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sts"
)
type Config struct {
@ -92,8 +93,10 @@ type AWSClient struct {
s3conn *s3.S3
sqsconn *sqs.SQS
snsconn *sns.SNS
stsconn *sts.STS
redshiftconn *redshift.Redshift
r53conn *route53.Route53
accountid string
region string
rdsconn *rds.RDS
iamconn *iam.IAM
@ -172,6 +175,9 @@ func (c *Config) Client() (interface{}, error) {
awsIamSess := sess.Copy(&aws.Config{Endpoint: aws.String(c.IamEndpoint)})
client.iamconn = iam.New(awsIamSess)
log.Println("[INFO] Initializing STS connection")
client.stsconn = sts.New(sess)
err = c.ValidateCredentials(client.iamconn)
if err != nil {
errs = append(errs, err)
@ -185,6 +191,11 @@ func (c *Config) Client() (interface{}, error) {
// http://docs.aws.amazon.com/general/latest/gr/sigv4_changes.html
usEast1Sess := sess.Copy(&aws.Config{Region: aws.String("us-east-1")})
accountId, err := GetAccountId(client.iamconn, client.stsconn, cp.ProviderName)
if err == nil {
client.accountid = accountId
}
log.Println("[INFO] Initializing DynamoDB connection")
dynamoSess := sess.Copy(&aws.Config{Endpoint: aws.String(c.DynamoDBEndpoint)})
client.dynamodbconn = dynamodb.New(dynamoSess)
@ -215,7 +226,7 @@ func (c *Config) Client() (interface{}, error) {
log.Println("[INFO] Initializing Elastic Beanstalk Connection")
client.elasticbeanstalkconn = elasticbeanstalk.New(sess)
authErr := c.ValidateAccountId(client.iamconn, cp.ProviderName)
authErr := c.ValidateAccountId(client.accountid)
if authErr != nil {
errs = append(errs, authErr)
}
@ -338,20 +349,16 @@ func (c *Config) ValidateCredentials(iamconn *iam.IAM) error {
// ValidateAccountId returns a context-specific error if the configured account
// id is explicitly forbidden or not authorised; and nil if it is authorised.
func (c *Config) ValidateAccountId(iamconn *iam.IAM, authProviderName string) error {
func (c *Config) ValidateAccountId(accountId string) error {
if c.AllowedAccountIds == nil && c.ForbiddenAccountIds == nil {
return nil
}
log.Printf("[INFO] Validating account ID")
account_id, err := GetAccountId(iamconn, authProviderName)
if err != nil {
return err
}
if c.ForbiddenAccountIds != nil {
for _, id := range c.ForbiddenAccountIds {
if id == account_id {
if id == accountId {
return fmt.Errorf("Forbidden account ID (%s)", id)
}
}
@ -359,11 +366,11 @@ func (c *Config) ValidateAccountId(iamconn *iam.IAM, authProviderName string) er
if c.AllowedAccountIds != nil {
for _, id := range c.AllowedAccountIds {
if id == account_id {
if id == accountId {
return nil
}
}
return fmt.Errorf("Account ID not allowed (%s)", account_id)
return fmt.Errorf("Account ID not allowed (%s)", accountId)
}
return nil