diff --git a/builtin/providers/aws/auth_helpers.go b/builtin/providers/aws/auth_helpers.go index 97087e0ae..6e48679ba 100644 --- a/builtin/providers/aws/auth_helpers.go +++ b/builtin/providers/aws/auth_helpers.go @@ -11,12 +11,14 @@ import ( "github.com/aws/aws-sdk-go/aws/awserr" awsCredentials "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/sts" "github.com/hashicorp/errwrap" "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/go-multierror" ) func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) { @@ -92,7 +94,9 @@ func parseAccountIdFromArn(arn string) (string, error) { // This function is responsible for reading credentials from the // environment in the case that they're not explicitly specified // in the Terraform configuration. -func GetCredentials(c *Config) *awsCredentials.Credentials { +func GetCredentials(c *Config) (*awsCredentials.Credentials, error) { + var errs []error + // build a chain provider, lazy-evaulated by aws-sdk providers := []awsCredentials.Provider{ &awsCredentials.StaticProvider{Value: awsCredentials.Value{ @@ -137,7 +141,40 @@ func GetCredentials(c *Config) *awsCredentials.Credentials { } } - return awsCredentials.NewChainCredentials(providers) + if c.RoleArn != "" { + log.Printf("[INFO] attempting to assume role %s", c.RoleArn) + + creds := awsCredentials.NewChainCredentials(providers) + cp, err := creds.Get() + if err != nil { + if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == "NoCredentialProviders" { + errs = append(errs, fmt.Errorf(`No valid credential sources found for AWS Provider. + Please see https://terraform.io/docs/providers/aws/index.html for more information on + providing credentials for the AWS Provider`)) + } else { + errs = append(errs, fmt.Errorf("Error loading credentials for AWS Provider: %s", err)) + } + return nil, &multierror.Error{Errors: errs} + } + + log.Printf("[INFO] AWS Auth provider used: %q", cp.ProviderName) + + awsConfig := &aws.Config{ + Credentials: creds, + Region: aws.String(c.Region), + MaxRetries: aws.Int(c.MaxRetries), + HTTPClient: cleanhttp.DefaultClient(), + S3ForcePathStyle: aws.Bool(c.S3ForcePathStyle), + } + + stsclient := sts.New(session.New(awsConfig)) + providers = []awsCredentials.Provider{&stscreds.AssumeRoleProvider{ + Client: stsclient, + RoleARN: c.RoleArn, + }} + } + + return awsCredentials.NewChainCredentials(providers), nil } func setOptionalEndpoint(cfg *aws.Config) string { diff --git a/builtin/providers/aws/auth_helpers_test.go b/builtin/providers/aws/auth_helpers_test.go index 1248259b7..8c5f60c53 100644 --- a/builtin/providers/aws/auth_helpers_test.go +++ b/builtin/providers/aws/auth_helpers_test.go @@ -218,8 +218,13 @@ func TestAWSGetCredentials_shouldError(t *testing.T) { defer resetEnv() cfg := Config{} - c := GetCredentials(&cfg) - _, err := c.Get() + c, err := GetCredentials(&cfg) + if awsErr, ok := err.(awserr.Error); ok { + if awsErr.Code() != "NoCredentialProviders" { + t.Fatalf("Expected NoCredentialProviders error") + } + } + _, err = c.Get() if awsErr, ok := err.(awserr.Error); ok { if awsErr.Code() != "NoCredentialProviders" { t.Fatalf("Expected NoCredentialProviders error") @@ -251,10 +256,13 @@ func TestAWSGetCredentials_shouldBeStatic(t *testing.T) { Token: c.Token, } - creds := GetCredentials(&cfg) + creds, err := GetCredentials(&cfg) if creds == nil { t.Fatalf("Expected a static creds provider to be returned") } + if err != nil { + t.Fatalf("Error gettings creds: %s", err) + } v, err := creds.Get() if err != nil { t.Fatalf("Error gettings creds: %s", err) @@ -286,11 +294,13 @@ func TestAWSGetCredentials_shouldIAM(t *testing.T) { // An empty config, no key supplied cfg := Config{} - creds := GetCredentials(&cfg) + creds, err := GetCredentials(&cfg) if creds == nil { t.Fatalf("Expected a static creds provider to be returned") } - + if err != nil { + t.Fatalf("Error gettings creds: %s", err) + } v, err := creds.Get() if err != nil { t.Fatalf("Error gettings creds: %s", err) @@ -335,10 +345,13 @@ func TestAWSGetCredentials_shouldIgnoreIAM(t *testing.T) { Token: c.Token, } - creds := GetCredentials(&cfg) + creds, err := GetCredentials(&cfg) if creds == nil { t.Fatalf("Expected a static creds provider to be returned") } + if err != nil { + t.Fatalf("Error gettings creds: %s", err) + } v, err := creds.Get() if err != nil { t.Fatalf("Error gettings creds: %s", err) @@ -362,7 +375,10 @@ func TestAWSGetCredentials_shouldErrorWithInvalidEndpoint(t *testing.T) { ts := invalidAwsEnv(t) defer ts() - creds := GetCredentials(&Config{}) + creds, err := GetCredentials(&Config{}) + if err != nil { + t.Fatalf("Error gettings creds: %s", err) + } v, err := creds.Get() if err == nil { t.Fatal("Expected error returned when getting creds w/ invalid EC2 endpoint") @@ -380,7 +396,10 @@ func TestAWSGetCredentials_shouldIgnoreInvalidEndpoint(t *testing.T) { ts := invalidAwsEnv(t) defer ts() - creds := GetCredentials(&Config{AccessKey: "accessKey", SecretKey: "secretKey"}) + creds, err := GetCredentials(&Config{AccessKey: "accessKey", SecretKey: "secretKey"}) + if err != nil { + t.Fatalf("Error gettings creds: %s", err) + } v, err := creds.Get() if err != nil { t.Fatalf("Getting static credentials w/ invalid EC2 endpoint failed: %s", err) @@ -406,10 +425,13 @@ func TestAWSGetCredentials_shouldCatchEC2RoleProvider(t *testing.T) { ts := awsEnv(t) defer ts() - creds := GetCredentials(&Config{}) + creds, err := GetCredentials(&Config{}) if creds == nil { t.Fatalf("Expected an EC2Role creds provider to be returned") } + if err != nil { + t.Fatalf("Error gettings creds: %s", err) + } v, err := creds.Get() if err != nil { t.Fatalf("Expected no error when getting creds: %s", err) @@ -452,10 +474,13 @@ func TestAWSGetCredentials_shouldBeShared(t *testing.T) { t.Fatalf("Error resetting env var AWS_SHARED_CREDENTIALS_FILE: %s", err) } - creds := GetCredentials(&Config{Profile: "myprofile", CredsFilename: file.Name()}) + creds, err := GetCredentials(&Config{Profile: "myprofile", CredsFilename: file.Name()}) if creds == nil { t.Fatalf("Expected a provider chain to be returned") } + if err != nil { + t.Fatalf("Error gettings creds: %s", err) + } v, err := creds.Get() if err != nil { t.Fatalf("Error gettings creds: %s", err) @@ -479,10 +504,13 @@ func TestAWSGetCredentials_shouldBeENV(t *testing.T) { defer resetEnv() cfg := Config{} - creds := GetCredentials(&cfg) + creds, err := GetCredentials(&cfg) if creds == nil { t.Fatalf("Expected a static creds provider to be returned") } + if err != nil { + t.Fatalf("Error gettings creds: %s", err) + } v, err := creds.Get() if err != nil { t.Fatalf("Error gettings creds: %s", err) diff --git a/builtin/providers/aws/config.go b/builtin/providers/aws/config.go index e67fef103..ddd8e4977 100644 --- a/builtin/providers/aws/config.go +++ b/builtin/providers/aws/config.go @@ -66,6 +66,7 @@ type Config struct { Profile string Token string Region string + RoleArn string MaxRetries int AllowedAccountIds []interface{} @@ -150,7 +151,10 @@ func (c *Config) Client() (interface{}, error) { client.region = c.Region log.Println("[INFO] Building AWS auth structure") - creds := GetCredentials(c) + creds, err := GetCredentials(c) + if err != nil { + return nil, &multierror.Error{Errors: errs} + } // Call Get to check for credential provider. If nothing found, we'll get an // error, and we can present it nicely to the user cp, err := creds.Get() diff --git a/builtin/providers/aws/provider.go b/builtin/providers/aws/provider.go index 9d1580d2b..4ae3d9942 100644 --- a/builtin/providers/aws/provider.go +++ b/builtin/providers/aws/provider.go @@ -64,6 +64,13 @@ func Provider() terraform.ResourceProvider { InputDefault: "us-east-1", }, + "role_arn": &schema.Schema{ + Type: schema.TypeString, + Optional: true, + Default: "", + Description: descriptions["role_arn"], + }, + "max_retries": &schema.Schema{ Type: schema.TypeInt, Optional: true, @@ -353,6 +360,8 @@ func init() { "profile": "The profile for API operations. If not set, the default profile\n" + "created with `aws configure` will be used.", + "role_arn": "The role to be assumed using the supplied access_key and secret_key", + "shared_credentials_file": "The path to the shared credentials file. If not set\n" + "this defaults to ~/.aws/credentials.", @@ -404,6 +413,7 @@ func providerConfigure(d *schema.ResourceData) (interface{}, error) { CredsFilename: d.Get("shared_credentials_file").(string), Token: d.Get("token").(string), Region: d.Get("region").(string), + RoleArn: d.Get("role_arn").(string), MaxRetries: d.Get("max_retries").(int), DynamoDBEndpoint: d.Get("dynamodb_endpoint").(string), KinesisEndpoint: d.Get("kinesis_endpoint").(string), diff --git a/state/remote/s3.go b/state/remote/s3.go index 026e50a11..eb810eceb 100644 --- a/state/remote/s3.go +++ b/state/remote/s3.go @@ -60,7 +60,7 @@ func s3Factory(conf map[string]string) (Client, error) { kmsKeyID := conf["kms_key_id"] var errs []error - creds := terraformAws.GetCredentials(&terraformAws.Config{ + creds, err := terraformAws.GetCredentials(&terraformAws.Config{ AccessKey: conf["access_key"], SecretKey: conf["secret_key"], Token: conf["token"], @@ -69,7 +69,7 @@ func s3Factory(conf map[string]string) (Client, error) { }) // Call Get to check for credential provider. If nothing found, we'll get an // error, and we can present it nicely to the user - _, err := creds.Get() + _, err = creds.Get() if err != nil { if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == "NoCredentialProviders" { errs = append(errs, fmt.Errorf(`No valid credential sources found for AWS S3 remote. diff --git a/website/source/docs/providers/aws/index.html.markdown b/website/source/docs/providers/aws/index.html.markdown index 853347735..31748fe4c 100644 --- a/website/source/docs/providers/aws/index.html.markdown +++ b/website/source/docs/providers/aws/index.html.markdown @@ -111,6 +111,19 @@ You can provide custom metadata API endpoint via `AWS_METADATA_ENDPOINT` variabl which expects the endpoint URL including the version and defaults to `http://169.254.169.254:80/latest`. +###Assume role + +If provided with a role arn, terraform will attempt to assume this role +using the supplied credentials. + +Usage: + +``` +provider "aws" { + role_arn = "arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME" +} +``` + ## Argument Reference The following arguments are supported in the `provider` block: