Merge pull request #5881 from hashicorp/b-aws-sg-r-protocols
provider/aws: Convert protocols to standard format for Security Groups
This commit is contained in:
commit
163173df7a
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -93,8 +94,9 @@ func resourceAwsSecurityGroup() *schema.Resource {
|
|||
},
|
||||
|
||||
"protocol": &schema.Schema{
|
||||
Type: schema.TypeString,
|
||||
Required: true,
|
||||
Type: schema.TypeString,
|
||||
Required: true,
|
||||
StateFunc: protocolStateFunc,
|
||||
},
|
||||
|
||||
"cidr_blocks": &schema.Schema{
|
||||
|
@ -137,8 +139,9 @@ func resourceAwsSecurityGroup() *schema.Resource {
|
|||
},
|
||||
|
||||
"protocol": &schema.Schema{
|
||||
Type: schema.TypeString,
|
||||
Required: true,
|
||||
Type: schema.TypeString,
|
||||
Required: true,
|
||||
StateFunc: protocolStateFunc,
|
||||
},
|
||||
|
||||
"cidr_blocks": &schema.Schema{
|
||||
|
@ -373,7 +376,8 @@ func resourceAwsSecurityGroupRuleHash(v interface{}) int {
|
|||
m := v.(map[string]interface{})
|
||||
buf.WriteString(fmt.Sprintf("%d-", m["from_port"].(int)))
|
||||
buf.WriteString(fmt.Sprintf("%d-", m["to_port"].(int)))
|
||||
buf.WriteString(fmt.Sprintf("%s-", m["protocol"].(string)))
|
||||
p := protocolForValue(m["protocol"].(string))
|
||||
buf.WriteString(fmt.Sprintf("%s-", p))
|
||||
buf.WriteString(fmt.Sprintf("%t-", m["self"].(bool)))
|
||||
|
||||
// We need to make sure to sort the strings below so that we always
|
||||
|
@ -824,3 +828,51 @@ func idHash(rType, protocol string, toPort, fromPort int64, self bool) string {
|
|||
|
||||
return fmt.Sprintf("rule-%d", hashcode.String(buf.String()))
|
||||
}
|
||||
|
||||
// protocolStateFunc ensures we only store a string in any protocol field
|
||||
func protocolStateFunc(v interface{}) string {
|
||||
switch v.(type) {
|
||||
case string:
|
||||
p := protocolForValue(v.(string))
|
||||
return p
|
||||
default:
|
||||
log.Printf("[WARN] Non String value given for Protocol: %#v", v)
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// protocolForValue converts a valid Internet Protocol number into it's name
|
||||
// representation. If a name is given, it validates that it's a proper protocol
|
||||
// name. Names/numbers are as defined at
|
||||
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
|
||||
func protocolForValue(v string) string {
|
||||
// special case -1
|
||||
protocol := strings.ToLower(v)
|
||||
if protocol == "-1" || protocol == "all" {
|
||||
return "-1"
|
||||
}
|
||||
// if it's a name like tcp, return that
|
||||
if _, ok := protocolIntegers()[protocol]; ok {
|
||||
return protocol
|
||||
}
|
||||
// convert to int, look for that value
|
||||
p, err := strconv.Atoi(protocol)
|
||||
if err != nil {
|
||||
// we were unable to convert to int, suggesting a string name, but it wasn't
|
||||
// found above
|
||||
log.Printf("[WARN] Unable to determine valid protocol: %s", err)
|
||||
return protocol
|
||||
}
|
||||
|
||||
for k, v := range protocolIntegers() {
|
||||
if p == v {
|
||||
// guard against protocolIntegers sometime in the future not having lower
|
||||
// case ids in the map
|
||||
return strings.ToLower(k)
|
||||
}
|
||||
}
|
||||
|
||||
// fall through
|
||||
log.Printf("[WARN] Unable to determine valid protocol: no matching protocols found")
|
||||
return protocol
|
||||
}
|
||||
|
|
|
@ -44,9 +44,10 @@ func resourceAwsSecurityGroupRule() *schema.Resource {
|
|||
},
|
||||
|
||||
"protocol": &schema.Schema{
|
||||
Type: schema.TypeString,
|
||||
Required: true,
|
||||
ForceNew: true,
|
||||
Type: schema.TypeString,
|
||||
Required: true,
|
||||
ForceNew: true,
|
||||
StateFunc: protocolStateFunc,
|
||||
},
|
||||
|
||||
"cidr_blocks": &schema.Schema{
|
||||
|
@ -411,7 +412,8 @@ func expandIPPerm(d *schema.ResourceData, sg *ec2.SecurityGroup) (*ec2.IpPermiss
|
|||
|
||||
perm.FromPort = aws.Int64(int64(d.Get("from_port").(int)))
|
||||
perm.ToPort = aws.Int64(int64(d.Get("to_port").(int)))
|
||||
perm.IpProtocol = aws.String(d.Get("protocol").(string))
|
||||
protocol := protocolForValue(d.Get("protocol").(string))
|
||||
perm.IpProtocol = aws.String(protocol)
|
||||
|
||||
// build a group map that behaves like a set
|
||||
groups := make(map[string]bool)
|
||||
|
|
|
@ -144,6 +144,43 @@ func TestAccAWSSecurityGroupRule_Ingress_VPC(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestAccAWSSecurityGroupRule_Ingress_Protocol(t *testing.T) {
|
||||
var group ec2.SecurityGroup
|
||||
|
||||
testRuleCount := func(*terraform.State) error {
|
||||
if len(group.IpPermissions) != 1 {
|
||||
return fmt.Errorf("Wrong Security Group rule count, expected %d, got %d",
|
||||
1, len(group.IpPermissions))
|
||||
}
|
||||
|
||||
rule := group.IpPermissions[0]
|
||||
if *rule.FromPort != int64(80) {
|
||||
return fmt.Errorf("Wrong Security Group port setting, expected %d, got %d",
|
||||
80, int(*rule.FromPort))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
resource.Test(t, resource.TestCase{
|
||||
PreCheck: func() { testAccPreCheck(t) },
|
||||
Providers: testAccProviders,
|
||||
CheckDestroy: testAccCheckAWSSecurityGroupRuleDestroy,
|
||||
Steps: []resource.TestStep{
|
||||
resource.TestStep{
|
||||
Config: testAccAWSSecurityGroupRuleIngress_protocolConfig,
|
||||
Check: resource.ComposeTestCheckFunc(
|
||||
testAccCheckAWSSecurityGroupRuleExists("aws_security_group.web", &group),
|
||||
testAccCheckAWSSecurityGroupRuleAttributes("aws_security_group_rule.ingress_1", &group, nil, "ingress"),
|
||||
resource.TestCheckResourceAttr(
|
||||
"aws_security_group_rule.ingress_1", "from_port", "80"),
|
||||
testRuleCount,
|
||||
),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccAWSSecurityGroupRule_Ingress_Classic(t *testing.T) {
|
||||
var group ec2.SecurityGroup
|
||||
|
||||
|
@ -545,6 +582,35 @@ resource "aws_security_group_rule" "ingress_1" {
|
|||
}
|
||||
`
|
||||
|
||||
const testAccAWSSecurityGroupRuleIngress_protocolConfig = `
|
||||
resource "aws_vpc" "tftest" {
|
||||
cidr_block = "10.0.0.0/16"
|
||||
|
||||
tags {
|
||||
Name = "tf-testing"
|
||||
}
|
||||
}
|
||||
|
||||
resource "aws_security_group" "web" {
|
||||
vpc_id = "${aws_vpc.tftest.id}"
|
||||
|
||||
tags {
|
||||
Name = "tf-acc-test"
|
||||
}
|
||||
}
|
||||
|
||||
resource "aws_security_group_rule" "ingress_1" {
|
||||
type = "ingress"
|
||||
protocol = "6"
|
||||
from_port = 80
|
||||
to_port = 8000
|
||||
cidr_blocks = ["10.0.0.0/8"]
|
||||
|
||||
security_group_id = "${aws_security_group.web.id}"
|
||||
}
|
||||
|
||||
`
|
||||
|
||||
const testAccAWSSecurityGroupRuleIssue5310 = `
|
||||
provider "aws" {
|
||||
region = "us-east-1"
|
||||
|
|
|
@ -15,6 +15,123 @@ import (
|
|||
"github.com/hashicorp/terraform/terraform"
|
||||
)
|
||||
|
||||
func TestProtocolStateFunc(t *testing.T) {
|
||||
cases := []struct {
|
||||
input interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
input: "tcp",
|
||||
expected: "tcp",
|
||||
},
|
||||
{
|
||||
input: 6,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
input: "17",
|
||||
expected: "udp",
|
||||
},
|
||||
{
|
||||
input: "all",
|
||||
expected: "-1",
|
||||
},
|
||||
{
|
||||
input: "-1",
|
||||
expected: "-1",
|
||||
},
|
||||
{
|
||||
input: -1,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
input: "1",
|
||||
expected: "icmp",
|
||||
},
|
||||
{
|
||||
input: "icmp",
|
||||
expected: "icmp",
|
||||
},
|
||||
{
|
||||
input: 1,
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
result := protocolStateFunc(c.input)
|
||||
if result != c.expected {
|
||||
t.Errorf("Error matching protocol, expected (%s), got (%s)", c.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtocolForValue(t *testing.T) {
|
||||
cases := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
input: "tcp",
|
||||
expected: "tcp",
|
||||
},
|
||||
{
|
||||
input: "6",
|
||||
expected: "tcp",
|
||||
},
|
||||
{
|
||||
input: "udp",
|
||||
expected: "udp",
|
||||
},
|
||||
{
|
||||
input: "17",
|
||||
expected: "udp",
|
||||
},
|
||||
{
|
||||
input: "all",
|
||||
expected: "-1",
|
||||
},
|
||||
{
|
||||
input: "-1",
|
||||
expected: "-1",
|
||||
},
|
||||
{
|
||||
input: "tCp",
|
||||
expected: "tcp",
|
||||
},
|
||||
{
|
||||
input: "6",
|
||||
expected: "tcp",
|
||||
},
|
||||
{
|
||||
input: "UDp",
|
||||
expected: "udp",
|
||||
},
|
||||
{
|
||||
input: "17",
|
||||
expected: "udp",
|
||||
},
|
||||
{
|
||||
input: "ALL",
|
||||
expected: "-1",
|
||||
},
|
||||
{
|
||||
input: "icMp",
|
||||
expected: "icmp",
|
||||
},
|
||||
{
|
||||
input: "1",
|
||||
expected: "icmp",
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
result := protocolForValue(c.input)
|
||||
if result != c.expected {
|
||||
t.Errorf("Error matching protocol, expected (%s), got (%s)", c.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResourceAwsSecurityGroupIPPermGather(t *testing.T) {
|
||||
raw := []*ec2.IpPermission{
|
||||
&ec2.IpPermission{
|
||||
|
@ -846,7 +963,7 @@ resource "aws_security_group" "web" {
|
|||
description = "Used in the terraform acceptance tests"
|
||||
|
||||
ingress {
|
||||
protocol = "tcp"
|
||||
protocol = "6"
|
||||
from_port = 80
|
||||
to_port = 8000
|
||||
cidr_blocks = ["10.0.0.0/8"]
|
||||
|
|
Loading…
Reference in New Issue