Merge pull request #5881 from hashicorp/b-aws-sg-r-protocols

provider/aws: Convert protocols to standard format for Security Groups
This commit is contained in:
Clint 2016-03-28 13:23:52 -05:00
commit 163173df7a
4 changed files with 247 additions and 10 deletions

View File

@ -5,6 +5,7 @@ import (
"fmt"
"log"
"sort"
"strconv"
"strings"
"time"
@ -93,8 +94,9 @@ func resourceAwsSecurityGroup() *schema.Resource {
},
"protocol": &schema.Schema{
Type: schema.TypeString,
Required: true,
Type: schema.TypeString,
Required: true,
StateFunc: protocolStateFunc,
},
"cidr_blocks": &schema.Schema{
@ -137,8 +139,9 @@ func resourceAwsSecurityGroup() *schema.Resource {
},
"protocol": &schema.Schema{
Type: schema.TypeString,
Required: true,
Type: schema.TypeString,
Required: true,
StateFunc: protocolStateFunc,
},
"cidr_blocks": &schema.Schema{
@ -373,7 +376,8 @@ func resourceAwsSecurityGroupRuleHash(v interface{}) int {
m := v.(map[string]interface{})
buf.WriteString(fmt.Sprintf("%d-", m["from_port"].(int)))
buf.WriteString(fmt.Sprintf("%d-", m["to_port"].(int)))
buf.WriteString(fmt.Sprintf("%s-", m["protocol"].(string)))
p := protocolForValue(m["protocol"].(string))
buf.WriteString(fmt.Sprintf("%s-", p))
buf.WriteString(fmt.Sprintf("%t-", m["self"].(bool)))
// We need to make sure to sort the strings below so that we always
@ -824,3 +828,51 @@ func idHash(rType, protocol string, toPort, fromPort int64, self bool) string {
return fmt.Sprintf("rule-%d", hashcode.String(buf.String()))
}
// protocolStateFunc ensures we only store a string in any protocol field
func protocolStateFunc(v interface{}) string {
switch v.(type) {
case string:
p := protocolForValue(v.(string))
return p
default:
log.Printf("[WARN] Non String value given for Protocol: %#v", v)
return ""
}
}
// protocolForValue converts a valid Internet Protocol number into it's name
// representation. If a name is given, it validates that it's a proper protocol
// name. Names/numbers are as defined at
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
func protocolForValue(v string) string {
// special case -1
protocol := strings.ToLower(v)
if protocol == "-1" || protocol == "all" {
return "-1"
}
// if it's a name like tcp, return that
if _, ok := protocolIntegers()[protocol]; ok {
return protocol
}
// convert to int, look for that value
p, err := strconv.Atoi(protocol)
if err != nil {
// we were unable to convert to int, suggesting a string name, but it wasn't
// found above
log.Printf("[WARN] Unable to determine valid protocol: %s", err)
return protocol
}
for k, v := range protocolIntegers() {
if p == v {
// guard against protocolIntegers sometime in the future not having lower
// case ids in the map
return strings.ToLower(k)
}
}
// fall through
log.Printf("[WARN] Unable to determine valid protocol: no matching protocols found")
return protocol
}

View File

@ -44,9 +44,10 @@ func resourceAwsSecurityGroupRule() *schema.Resource {
},
"protocol": &schema.Schema{
Type: schema.TypeString,
Required: true,
ForceNew: true,
Type: schema.TypeString,
Required: true,
ForceNew: true,
StateFunc: protocolStateFunc,
},
"cidr_blocks": &schema.Schema{
@ -411,7 +412,8 @@ func expandIPPerm(d *schema.ResourceData, sg *ec2.SecurityGroup) (*ec2.IpPermiss
perm.FromPort = aws.Int64(int64(d.Get("from_port").(int)))
perm.ToPort = aws.Int64(int64(d.Get("to_port").(int)))
perm.IpProtocol = aws.String(d.Get("protocol").(string))
protocol := protocolForValue(d.Get("protocol").(string))
perm.IpProtocol = aws.String(protocol)
// build a group map that behaves like a set
groups := make(map[string]bool)

View File

@ -144,6 +144,43 @@ func TestAccAWSSecurityGroupRule_Ingress_VPC(t *testing.T) {
})
}
func TestAccAWSSecurityGroupRule_Ingress_Protocol(t *testing.T) {
var group ec2.SecurityGroup
testRuleCount := func(*terraform.State) error {
if len(group.IpPermissions) != 1 {
return fmt.Errorf("Wrong Security Group rule count, expected %d, got %d",
1, len(group.IpPermissions))
}
rule := group.IpPermissions[0]
if *rule.FromPort != int64(80) {
return fmt.Errorf("Wrong Security Group port setting, expected %d, got %d",
80, int(*rule.FromPort))
}
return nil
}
resource.Test(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Providers: testAccProviders,
CheckDestroy: testAccCheckAWSSecurityGroupRuleDestroy,
Steps: []resource.TestStep{
resource.TestStep{
Config: testAccAWSSecurityGroupRuleIngress_protocolConfig,
Check: resource.ComposeTestCheckFunc(
testAccCheckAWSSecurityGroupRuleExists("aws_security_group.web", &group),
testAccCheckAWSSecurityGroupRuleAttributes("aws_security_group_rule.ingress_1", &group, nil, "ingress"),
resource.TestCheckResourceAttr(
"aws_security_group_rule.ingress_1", "from_port", "80"),
testRuleCount,
),
},
},
})
}
func TestAccAWSSecurityGroupRule_Ingress_Classic(t *testing.T) {
var group ec2.SecurityGroup
@ -545,6 +582,35 @@ resource "aws_security_group_rule" "ingress_1" {
}
`
const testAccAWSSecurityGroupRuleIngress_protocolConfig = `
resource "aws_vpc" "tftest" {
cidr_block = "10.0.0.0/16"
tags {
Name = "tf-testing"
}
}
resource "aws_security_group" "web" {
vpc_id = "${aws_vpc.tftest.id}"
tags {
Name = "tf-acc-test"
}
}
resource "aws_security_group_rule" "ingress_1" {
type = "ingress"
protocol = "6"
from_port = 80
to_port = 8000
cidr_blocks = ["10.0.0.0/8"]
security_group_id = "${aws_security_group.web.id}"
}
`
const testAccAWSSecurityGroupRuleIssue5310 = `
provider "aws" {
region = "us-east-1"

View File

@ -15,6 +15,123 @@ import (
"github.com/hashicorp/terraform/terraform"
)
func TestProtocolStateFunc(t *testing.T) {
cases := []struct {
input interface{}
expected string
}{
{
input: "tcp",
expected: "tcp",
},
{
input: 6,
expected: "",
},
{
input: "17",
expected: "udp",
},
{
input: "all",
expected: "-1",
},
{
input: "-1",
expected: "-1",
},
{
input: -1,
expected: "",
},
{
input: "1",
expected: "icmp",
},
{
input: "icmp",
expected: "icmp",
},
{
input: 1,
expected: "",
},
}
for _, c := range cases {
result := protocolStateFunc(c.input)
if result != c.expected {
t.Errorf("Error matching protocol, expected (%s), got (%s)", c.expected, result)
}
}
}
func TestProtocolForValue(t *testing.T) {
cases := []struct {
input string
expected string
}{
{
input: "tcp",
expected: "tcp",
},
{
input: "6",
expected: "tcp",
},
{
input: "udp",
expected: "udp",
},
{
input: "17",
expected: "udp",
},
{
input: "all",
expected: "-1",
},
{
input: "-1",
expected: "-1",
},
{
input: "tCp",
expected: "tcp",
},
{
input: "6",
expected: "tcp",
},
{
input: "UDp",
expected: "udp",
},
{
input: "17",
expected: "udp",
},
{
input: "ALL",
expected: "-1",
},
{
input: "icMp",
expected: "icmp",
},
{
input: "1",
expected: "icmp",
},
}
for _, c := range cases {
result := protocolForValue(c.input)
if result != c.expected {
t.Errorf("Error matching protocol, expected (%s), got (%s)", c.expected, result)
}
}
}
func TestResourceAwsSecurityGroupIPPermGather(t *testing.T) {
raw := []*ec2.IpPermission{
&ec2.IpPermission{
@ -846,7 +963,7 @@ resource "aws_security_group" "web" {
description = "Used in the terraform acceptance tests"
ingress {
protocol = "tcp"
protocol = "6"
from_port = 80
to_port = 8000
cidr_blocks = ["10.0.0.0/8"]