Added accountid to AWSClient and set it early in the initialization phase. We use iam.GetUser(nil) scattered around to get the account id, but this isn't the most reliable method. GetAccountId now uses one more method (sts:GetCallerIdentity) to get the account id, this works with federated users.
This commit is contained in:
parent
e04e87361f
commit
a23bcf2ec9
|
@ -14,10 +14,11 @@ import (
|
|||
"github.com/aws/aws-sdk-go/aws/ec2metadata"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/iam"
|
||||
"github.com/aws/aws-sdk-go/service/sts"
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
)
|
||||
|
||||
func GetAccountId(iamconn *iam.IAM, authProviderName string) (string, error) {
|
||||
func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) {
|
||||
// If we have creds from instance profile, we can use metadata API
|
||||
if authProviderName == ec2rolecreds.ProviderName {
|
||||
log.Println("[DEBUG] Trying to get account ID via AWS Metadata API")
|
||||
|
@ -42,16 +43,24 @@ func GetAccountId(iamconn *iam.IAM, authProviderName string) (string, error) {
|
|||
return parseAccountIdFromArn(*outUser.User.Arn)
|
||||
}
|
||||
|
||||
// Then try IAM ListRoles
|
||||
awsErr, ok := err.(awserr.Error)
|
||||
// AccessDenied and ValidationError can be raised
|
||||
// if credentials belong to federated profile, so we ignore these
|
||||
if !ok || (awsErr.Code() != "AccessDenied" && awsErr.Code() != "ValidationError") {
|
||||
return "", fmt.Errorf("Failed getting account ID via 'iam:GetUser': %s", err)
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG] Getting account ID via iam:GetUser failed: %s", err)
|
||||
log.Println("[DEBUG] Trying to get account ID via iam:ListRoles instead")
|
||||
|
||||
// Then try STS GetCallerIdentity
|
||||
log.Println("[DEBUG] Trying to get account ID via sts:GetCallerIdentity")
|
||||
outCallerIdentity, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{})
|
||||
if err == nil {
|
||||
return *outCallerIdentity.Account, nil
|
||||
}
|
||||
log.Printf("[DEBUG] Getting account ID via sts:GetCallerIdentity failed: %s", err)
|
||||
|
||||
// Then try IAM ListRoles
|
||||
log.Println("[DEBUG] Trying to get account ID via iam:ListRoles")
|
||||
outRoles, err := iamconn.ListRoles(&iam.ListRolesInput{
|
||||
MaxItems: aws.Int64(int64(1)),
|
||||
})
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/iam"
|
||||
"github.com/aws/aws-sdk-go/service/sts"
|
||||
)
|
||||
|
||||
func TestAWSGetAccountId_shouldBeValid_fromEC2Role(t *testing.T) {
|
||||
|
@ -28,10 +29,10 @@ func TestAWSGetAccountId_shouldBeValid_fromEC2Role(t *testing.T) {
|
|||
defer awsTs()
|
||||
|
||||
iamEndpoints := []*iamEndpoint{}
|
||||
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
|
||||
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
|
||||
defer ts()
|
||||
|
||||
id, err := GetAccountId(iamConn, ec2rolecreds.ProviderName)
|
||||
id, err := GetAccountId(iamConn, stsConn, ec2rolecreds.ProviderName)
|
||||
if err != nil {
|
||||
t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err)
|
||||
}
|
||||
|
@ -55,10 +56,10 @@ func TestAWSGetAccountId_shouldBeValid_EC2RoleHasPriority(t *testing.T) {
|
|||
Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"},
|
||||
},
|
||||
}
|
||||
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
|
||||
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
|
||||
defer ts()
|
||||
|
||||
id, err := GetAccountId(iamConn, ec2rolecreds.ProviderName)
|
||||
id, err := GetAccountId(iamConn, stsConn, ec2rolecreds.ProviderName)
|
||||
if err != nil {
|
||||
t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err)
|
||||
}
|
||||
|
@ -76,10 +77,36 @@ func TestAWSGetAccountId_shouldBeValid_fromIamUser(t *testing.T) {
|
|||
Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"},
|
||||
},
|
||||
}
|
||||
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
|
||||
|
||||
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
|
||||
defer ts()
|
||||
|
||||
id, err := GetAccountId(iamConn, "")
|
||||
id, err := GetAccountId(iamConn, stsConn, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Getting account ID via GetUser failed: %s", err)
|
||||
}
|
||||
|
||||
expectedAccountId := "123456789012"
|
||||
if id != expectedAccountId {
|
||||
t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAWSGetAccountId_shouldBeValid_fromGetCallerIdentity(t *testing.T) {
|
||||
iamEndpoints := []*iamEndpoint{
|
||||
&iamEndpoint{
|
||||
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
|
||||
Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"},
|
||||
},
|
||||
&iamEndpoint{
|
||||
Request: &iamRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"},
|
||||
Response: &iamResponse{200, stsResponse_GetCallerIdentity_valid, "text/xml"},
|
||||
},
|
||||
}
|
||||
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
|
||||
defer ts()
|
||||
|
||||
id, err := GetAccountId(iamConn, stsConn, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Getting account ID via GetUser failed: %s", err)
|
||||
}
|
||||
|
@ -96,15 +123,19 @@ func TestAWSGetAccountId_shouldBeValid_fromIamListRoles(t *testing.T) {
|
|||
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
|
||||
Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"},
|
||||
},
|
||||
&iamEndpoint{
|
||||
Request: &iamRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"},
|
||||
Response: &iamResponse{403, stsResponse_GetCallerIdentity_unauthorized, "text/xml"},
|
||||
},
|
||||
&iamEndpoint{
|
||||
Request: &iamRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"},
|
||||
Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"},
|
||||
},
|
||||
}
|
||||
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
|
||||
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
|
||||
defer ts()
|
||||
|
||||
id, err := GetAccountId(iamConn, "")
|
||||
id, err := GetAccountId(iamConn, stsConn, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Getting account ID via ListRoles failed: %s", err)
|
||||
}
|
||||
|
@ -126,10 +157,10 @@ func TestAWSGetAccountId_shouldBeValid_federatedRole(t *testing.T) {
|
|||
Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"},
|
||||
},
|
||||
}
|
||||
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
|
||||
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
|
||||
defer ts()
|
||||
|
||||
id, err := GetAccountId(iamConn, "")
|
||||
id, err := GetAccountId(iamConn, stsConn, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Getting account ID via ListRoles failed: %s", err)
|
||||
}
|
||||
|
@ -151,10 +182,10 @@ func TestAWSGetAccountId_shouldError_unauthorizedFromIam(t *testing.T) {
|
|||
Response: &iamResponse{403, iamResponse_ListRoles_unauthorized, "text/xml"},
|
||||
},
|
||||
}
|
||||
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
|
||||
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
|
||||
defer ts()
|
||||
|
||||
id, err := GetAccountId(iamConn, "")
|
||||
id, err := GetAccountId(iamConn, stsConn, "")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when getting account ID")
|
||||
}
|
||||
|
@ -586,15 +617,15 @@ func invalidAwsEnv(t *testing.T) func() {
|
|||
return ts.Close
|
||||
}
|
||||
|
||||
// getMockedAwsIamApi establishes a httptest server to simulate behaviour
|
||||
// of a real AWS' IAM server
|
||||
func getMockedAwsIamApi(endpoints []*iamEndpoint) (func(), *iam.IAM) {
|
||||
// getMockedAwsIamStsApi establishes a httptest server to simulate behaviour
|
||||
// of a real AWS' IAM & STS server
|
||||
func getMockedAwsIamStsApi(endpoints []*iamEndpoint) (func(), *iam.IAM, *sts.STS) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
buf := new(bytes.Buffer)
|
||||
buf.ReadFrom(r.Body)
|
||||
requestBody := buf.String()
|
||||
|
||||
log.Printf("[DEBUG] Received IAM API %q request to %q: %s",
|
||||
log.Printf("[DEBUG] Received API %q request to %q: %s",
|
||||
r.Method, r.RequestURI, requestBody)
|
||||
|
||||
for _, e := range endpoints {
|
||||
|
@ -624,8 +655,8 @@ func getMockedAwsIamApi(endpoints []*iamEndpoint) (func(), *iam.IAM) {
|
|||
CredentialsChainVerboseErrors: aws.Bool(true),
|
||||
})
|
||||
iamConn := iam.New(sess)
|
||||
|
||||
return ts.Close, iamConn
|
||||
stsConn := sts.New(sess)
|
||||
return ts.Close, iamConn, stsConn
|
||||
}
|
||||
|
||||
func getEnv() *currentEnv {
|
||||
|
@ -718,6 +749,26 @@ const iamResponse_GetUser_unauthorized = `<ErrorResponse xmlns="https://iam.amaz
|
|||
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
|
||||
</ErrorResponse>`
|
||||
|
||||
const stsResponse_GetCallerIdentity_valid = `<GetCallerIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
|
||||
<GetCallerIdentityResult>
|
||||
<Arn>arn:aws:iam::123456789012:user/Alice</Arn>
|
||||
<UserId>AKIAI44QH8DHBEXAMPLE</UserId>
|
||||
<Account>123456789012</Account>
|
||||
</GetCallerIdentityResult>
|
||||
<ResponseMetadata>
|
||||
<RequestId>01234567-89ab-cdef-0123-456789abcdef</RequestId>
|
||||
</ResponseMetadata>
|
||||
</GetCallerIdentityResponse>`
|
||||
|
||||
const stsResponse_GetCallerIdentity_unauthorized = `<ErrorResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
|
||||
<Error>
|
||||
<Type>Sender</Type>
|
||||
<Code>AccessDenied</Code>
|
||||
<Message>User: arn:aws:iam::123456789012:user/Bob is not authorized to perform: sts:GetCallerIdentity</Message>
|
||||
</Error>
|
||||
<RequestId>01234567-89ab-cdef-0123-456789abcdef</RequestId>
|
||||
</ErrorResponse>`
|
||||
|
||||
const iamResponse_GetUser_federatedFailure = `<ErrorResponse xmlns="https://iam.amazonaws.com/doc/2010-05-08/">
|
||||
<Error>
|
||||
<Type>Sender</Type>
|
||||
|
|
|
@ -50,6 +50,7 @@ import (
|
|||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/aws/aws-sdk-go/service/sns"
|
||||
"github.com/aws/aws-sdk-go/service/sqs"
|
||||
"github.com/aws/aws-sdk-go/service/sts"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
|
@ -92,8 +93,10 @@ type AWSClient struct {
|
|||
s3conn *s3.S3
|
||||
sqsconn *sqs.SQS
|
||||
snsconn *sns.SNS
|
||||
stsconn *sts.STS
|
||||
redshiftconn *redshift.Redshift
|
||||
r53conn *route53.Route53
|
||||
accountid string
|
||||
region string
|
||||
rdsconn *rds.RDS
|
||||
iamconn *iam.IAM
|
||||
|
@ -172,6 +175,9 @@ func (c *Config) Client() (interface{}, error) {
|
|||
awsIamSess := sess.Copy(&aws.Config{Endpoint: aws.String(c.IamEndpoint)})
|
||||
client.iamconn = iam.New(awsIamSess)
|
||||
|
||||
log.Println("[INFO] Initializing STS connection")
|
||||
client.stsconn = sts.New(sess)
|
||||
|
||||
err = c.ValidateCredentials(client.iamconn)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
|
@ -185,6 +191,11 @@ func (c *Config) Client() (interface{}, error) {
|
|||
// http://docs.aws.amazon.com/general/latest/gr/sigv4_changes.html
|
||||
usEast1Sess := sess.Copy(&aws.Config{Region: aws.String("us-east-1")})
|
||||
|
||||
accountId, err := GetAccountId(client.iamconn, client.stsconn, cp.ProviderName)
|
||||
if err == nil {
|
||||
client.accountid = accountId
|
||||
}
|
||||
|
||||
log.Println("[INFO] Initializing DynamoDB connection")
|
||||
dynamoSess := sess.Copy(&aws.Config{Endpoint: aws.String(c.DynamoDBEndpoint)})
|
||||
client.dynamodbconn = dynamodb.New(dynamoSess)
|
||||
|
@ -215,7 +226,7 @@ func (c *Config) Client() (interface{}, error) {
|
|||
log.Println("[INFO] Initializing Elastic Beanstalk Connection")
|
||||
client.elasticbeanstalkconn = elasticbeanstalk.New(sess)
|
||||
|
||||
authErr := c.ValidateAccountId(client.iamconn, cp.ProviderName)
|
||||
authErr := c.ValidateAccountId(client.accountid)
|
||||
if authErr != nil {
|
||||
errs = append(errs, authErr)
|
||||
}
|
||||
|
@ -338,20 +349,16 @@ func (c *Config) ValidateCredentials(iamconn *iam.IAM) error {
|
|||
|
||||
// ValidateAccountId returns a context-specific error if the configured account
|
||||
// id is explicitly forbidden or not authorised; and nil if it is authorised.
|
||||
func (c *Config) ValidateAccountId(iamconn *iam.IAM, authProviderName string) error {
|
||||
func (c *Config) ValidateAccountId(accountId string) error {
|
||||
if c.AllowedAccountIds == nil && c.ForbiddenAccountIds == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Printf("[INFO] Validating account ID")
|
||||
account_id, err := GetAccountId(iamconn, authProviderName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.ForbiddenAccountIds != nil {
|
||||
for _, id := range c.ForbiddenAccountIds {
|
||||
if id == account_id {
|
||||
if id == accountId {
|
||||
return fmt.Errorf("Forbidden account ID (%s)", id)
|
||||
}
|
||||
}
|
||||
|
@ -359,11 +366,11 @@ func (c *Config) ValidateAccountId(iamconn *iam.IAM, authProviderName string) er
|
|||
|
||||
if c.AllowedAccountIds != nil {
|
||||
for _, id := range c.AllowedAccountIds {
|
||||
if id == account_id {
|
||||
if id == accountId {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("Account ID not allowed (%s)", account_id)
|
||||
return fmt.Errorf("Account ID not allowed (%s)", accountId)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
Loading…
Reference in New Issue