diff --git a/builtin/providers/aws/resource_aws_security_group.go b/builtin/providers/aws/resource_aws_security_group.go index ead5b472f..d364c1fb7 100644 --- a/builtin/providers/aws/resource_aws_security_group.go +++ b/builtin/providers/aws/resource_aws_security_group.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "sort" + "strconv" "strings" "time" @@ -93,8 +94,9 @@ func resourceAwsSecurityGroup() *schema.Resource { }, "protocol": &schema.Schema{ - Type: schema.TypeString, - Required: true, + Type: schema.TypeString, + Required: true, + StateFunc: protocolStateFunc, }, "cidr_blocks": &schema.Schema{ @@ -137,8 +139,9 @@ func resourceAwsSecurityGroup() *schema.Resource { }, "protocol": &schema.Schema{ - Type: schema.TypeString, - Required: true, + Type: schema.TypeString, + Required: true, + StateFunc: protocolStateFunc, }, "cidr_blocks": &schema.Schema{ @@ -373,7 +376,8 @@ func resourceAwsSecurityGroupRuleHash(v interface{}) int { m := v.(map[string]interface{}) buf.WriteString(fmt.Sprintf("%d-", m["from_port"].(int))) buf.WriteString(fmt.Sprintf("%d-", m["to_port"].(int))) - buf.WriteString(fmt.Sprintf("%s-", m["protocol"].(string))) + p := protocolForValue(m["protocol"].(string)) + buf.WriteString(fmt.Sprintf("%s-", p)) buf.WriteString(fmt.Sprintf("%t-", m["self"].(bool))) // We need to make sure to sort the strings below so that we always @@ -824,3 +828,51 @@ func idHash(rType, protocol string, toPort, fromPort int64, self bool) string { return fmt.Sprintf("rule-%d", hashcode.String(buf.String())) } + +// protocolStateFunc ensures we only store a string in any protocol field +func protocolStateFunc(v interface{}) string { + switch v.(type) { + case string: + p := protocolForValue(v.(string)) + return p + default: + log.Printf("[WARN] Non String value given for Protocol: %#v", v) + return "" + } +} + +// protocolForValue converts a valid Internet Protocol number into it's name +// representation. If a name is given, it validates that it's a proper protocol +// name. Names/numbers are as defined at +// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml +func protocolForValue(v string) string { + // special case -1 + protocol := strings.ToLower(v) + if protocol == "-1" || protocol == "all" { + return "-1" + } + // if it's a name like tcp, return that + if _, ok := protocolIntegers()[protocol]; ok { + return protocol + } + // convert to int, look for that value + p, err := strconv.Atoi(protocol) + if err != nil { + // we were unable to convert to int, suggesting a string name, but it wasn't + // found above + log.Printf("[WARN] Unable to determine valid protocol: %s", err) + return protocol + } + + for k, v := range protocolIntegers() { + if p == v { + // guard against protocolIntegers sometime in the future not having lower + // case ids in the map + return strings.ToLower(k) + } + } + + // fall through + log.Printf("[WARN] Unable to determine valid protocol: no matching protocols found") + return protocol +} diff --git a/builtin/providers/aws/resource_aws_security_group_rule.go b/builtin/providers/aws/resource_aws_security_group_rule.go index 005163fba..715a0a5cd 100644 --- a/builtin/providers/aws/resource_aws_security_group_rule.go +++ b/builtin/providers/aws/resource_aws_security_group_rule.go @@ -44,9 +44,10 @@ func resourceAwsSecurityGroupRule() *schema.Resource { }, "protocol": &schema.Schema{ - Type: schema.TypeString, - Required: true, - ForceNew: true, + Type: schema.TypeString, + Required: true, + ForceNew: true, + StateFunc: protocolStateFunc, }, "cidr_blocks": &schema.Schema{ @@ -411,7 +412,8 @@ func expandIPPerm(d *schema.ResourceData, sg *ec2.SecurityGroup) (*ec2.IpPermiss perm.FromPort = aws.Int64(int64(d.Get("from_port").(int))) perm.ToPort = aws.Int64(int64(d.Get("to_port").(int))) - perm.IpProtocol = aws.String(d.Get("protocol").(string)) + protocol := protocolForValue(d.Get("protocol").(string)) + perm.IpProtocol = aws.String(protocol) // build a group map that behaves like a set groups := make(map[string]bool) diff --git a/builtin/providers/aws/resource_aws_security_group_rule_test.go b/builtin/providers/aws/resource_aws_security_group_rule_test.go index 9f459e4ca..0e46d02fc 100644 --- a/builtin/providers/aws/resource_aws_security_group_rule_test.go +++ b/builtin/providers/aws/resource_aws_security_group_rule_test.go @@ -144,6 +144,43 @@ func TestAccAWSSecurityGroupRule_Ingress_VPC(t *testing.T) { }) } +func TestAccAWSSecurityGroupRule_Ingress_Protocol(t *testing.T) { + var group ec2.SecurityGroup + + testRuleCount := func(*terraform.State) error { + if len(group.IpPermissions) != 1 { + return fmt.Errorf("Wrong Security Group rule count, expected %d, got %d", + 1, len(group.IpPermissions)) + } + + rule := group.IpPermissions[0] + if *rule.FromPort != int64(80) { + return fmt.Errorf("Wrong Security Group port setting, expected %d, got %d", + 80, int(*rule.FromPort)) + } + + return nil + } + + resource.Test(t, resource.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + Providers: testAccProviders, + CheckDestroy: testAccCheckAWSSecurityGroupRuleDestroy, + Steps: []resource.TestStep{ + resource.TestStep{ + Config: testAccAWSSecurityGroupRuleIngress_protocolConfig, + Check: resource.ComposeTestCheckFunc( + testAccCheckAWSSecurityGroupRuleExists("aws_security_group.web", &group), + testAccCheckAWSSecurityGroupRuleAttributes("aws_security_group_rule.ingress_1", &group, nil, "ingress"), + resource.TestCheckResourceAttr( + "aws_security_group_rule.ingress_1", "from_port", "80"), + testRuleCount, + ), + }, + }, + }) +} + func TestAccAWSSecurityGroupRule_Ingress_Classic(t *testing.T) { var group ec2.SecurityGroup @@ -545,6 +582,35 @@ resource "aws_security_group_rule" "ingress_1" { } ` +const testAccAWSSecurityGroupRuleIngress_protocolConfig = ` +resource "aws_vpc" "tftest" { + cidr_block = "10.0.0.0/16" + + tags { + Name = "tf-testing" + } +} + +resource "aws_security_group" "web" { + vpc_id = "${aws_vpc.tftest.id}" + + tags { + Name = "tf-acc-test" + } +} + +resource "aws_security_group_rule" "ingress_1" { + type = "ingress" + protocol = "6" + from_port = 80 + to_port = 8000 + cidr_blocks = ["10.0.0.0/8"] + + security_group_id = "${aws_security_group.web.id}" +} + +` + const testAccAWSSecurityGroupRuleIssue5310 = ` provider "aws" { region = "us-east-1" diff --git a/builtin/providers/aws/resource_aws_security_group_test.go b/builtin/providers/aws/resource_aws_security_group_test.go index af89df87a..2b5f97140 100644 --- a/builtin/providers/aws/resource_aws_security_group_test.go +++ b/builtin/providers/aws/resource_aws_security_group_test.go @@ -15,6 +15,123 @@ import ( "github.com/hashicorp/terraform/terraform" ) +func TestProtocolStateFunc(t *testing.T) { + cases := []struct { + input interface{} + expected string + }{ + { + input: "tcp", + expected: "tcp", + }, + { + input: 6, + expected: "", + }, + { + input: "17", + expected: "udp", + }, + { + input: "all", + expected: "-1", + }, + { + input: "-1", + expected: "-1", + }, + { + input: -1, + expected: "", + }, + { + input: "1", + expected: "icmp", + }, + { + input: "icmp", + expected: "icmp", + }, + { + input: 1, + expected: "", + }, + } + for _, c := range cases { + result := protocolStateFunc(c.input) + if result != c.expected { + t.Errorf("Error matching protocol, expected (%s), got (%s)", c.expected, result) + } + } +} + +func TestProtocolForValue(t *testing.T) { + cases := []struct { + input string + expected string + }{ + { + input: "tcp", + expected: "tcp", + }, + { + input: "6", + expected: "tcp", + }, + { + input: "udp", + expected: "udp", + }, + { + input: "17", + expected: "udp", + }, + { + input: "all", + expected: "-1", + }, + { + input: "-1", + expected: "-1", + }, + { + input: "tCp", + expected: "tcp", + }, + { + input: "6", + expected: "tcp", + }, + { + input: "UDp", + expected: "udp", + }, + { + input: "17", + expected: "udp", + }, + { + input: "ALL", + expected: "-1", + }, + { + input: "icMp", + expected: "icmp", + }, + { + input: "1", + expected: "icmp", + }, + } + + for _, c := range cases { + result := protocolForValue(c.input) + if result != c.expected { + t.Errorf("Error matching protocol, expected (%s), got (%s)", c.expected, result) + } + } +} + func TestResourceAwsSecurityGroupIPPermGather(t *testing.T) { raw := []*ec2.IpPermission{ &ec2.IpPermission{ @@ -846,7 +963,7 @@ resource "aws_security_group" "web" { description = "Used in the terraform acceptance tests" ingress { - protocol = "tcp" + protocol = "6" from_port = 80 to_port = 8000 cidr_blocks = ["10.0.0.0/8"]