Merge pull request #5881 from hashicorp/b-aws-sg-r-protocols
provider/aws: Convert protocols to standard format for Security Groups
This commit is contained in:
commit
163173df7a
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -93,8 +94,9 @@ func resourceAwsSecurityGroup() *schema.Resource {
|
||||||
},
|
},
|
||||||
|
|
||||||
"protocol": &schema.Schema{
|
"protocol": &schema.Schema{
|
||||||
Type: schema.TypeString,
|
Type: schema.TypeString,
|
||||||
Required: true,
|
Required: true,
|
||||||
|
StateFunc: protocolStateFunc,
|
||||||
},
|
},
|
||||||
|
|
||||||
"cidr_blocks": &schema.Schema{
|
"cidr_blocks": &schema.Schema{
|
||||||
|
@ -137,8 +139,9 @@ func resourceAwsSecurityGroup() *schema.Resource {
|
||||||
},
|
},
|
||||||
|
|
||||||
"protocol": &schema.Schema{
|
"protocol": &schema.Schema{
|
||||||
Type: schema.TypeString,
|
Type: schema.TypeString,
|
||||||
Required: true,
|
Required: true,
|
||||||
|
StateFunc: protocolStateFunc,
|
||||||
},
|
},
|
||||||
|
|
||||||
"cidr_blocks": &schema.Schema{
|
"cidr_blocks": &schema.Schema{
|
||||||
|
@ -373,7 +376,8 @@ func resourceAwsSecurityGroupRuleHash(v interface{}) int {
|
||||||
m := v.(map[string]interface{})
|
m := v.(map[string]interface{})
|
||||||
buf.WriteString(fmt.Sprintf("%d-", m["from_port"].(int)))
|
buf.WriteString(fmt.Sprintf("%d-", m["from_port"].(int)))
|
||||||
buf.WriteString(fmt.Sprintf("%d-", m["to_port"].(int)))
|
buf.WriteString(fmt.Sprintf("%d-", m["to_port"].(int)))
|
||||||
buf.WriteString(fmt.Sprintf("%s-", m["protocol"].(string)))
|
p := protocolForValue(m["protocol"].(string))
|
||||||
|
buf.WriteString(fmt.Sprintf("%s-", p))
|
||||||
buf.WriteString(fmt.Sprintf("%t-", m["self"].(bool)))
|
buf.WriteString(fmt.Sprintf("%t-", m["self"].(bool)))
|
||||||
|
|
||||||
// We need to make sure to sort the strings below so that we always
|
// We need to make sure to sort the strings below so that we always
|
||||||
|
@ -824,3 +828,51 @@ func idHash(rType, protocol string, toPort, fromPort int64, self bool) string {
|
||||||
|
|
||||||
return fmt.Sprintf("rule-%d", hashcode.String(buf.String()))
|
return fmt.Sprintf("rule-%d", hashcode.String(buf.String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// protocolStateFunc ensures we only store a string in any protocol field
|
||||||
|
func protocolStateFunc(v interface{}) string {
|
||||||
|
switch v.(type) {
|
||||||
|
case string:
|
||||||
|
p := protocolForValue(v.(string))
|
||||||
|
return p
|
||||||
|
default:
|
||||||
|
log.Printf("[WARN] Non String value given for Protocol: %#v", v)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// protocolForValue converts a valid Internet Protocol number into it's name
|
||||||
|
// representation. If a name is given, it validates that it's a proper protocol
|
||||||
|
// name. Names/numbers are as defined at
|
||||||
|
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
|
||||||
|
func protocolForValue(v string) string {
|
||||||
|
// special case -1
|
||||||
|
protocol := strings.ToLower(v)
|
||||||
|
if protocol == "-1" || protocol == "all" {
|
||||||
|
return "-1"
|
||||||
|
}
|
||||||
|
// if it's a name like tcp, return that
|
||||||
|
if _, ok := protocolIntegers()[protocol]; ok {
|
||||||
|
return protocol
|
||||||
|
}
|
||||||
|
// convert to int, look for that value
|
||||||
|
p, err := strconv.Atoi(protocol)
|
||||||
|
if err != nil {
|
||||||
|
// we were unable to convert to int, suggesting a string name, but it wasn't
|
||||||
|
// found above
|
||||||
|
log.Printf("[WARN] Unable to determine valid protocol: %s", err)
|
||||||
|
return protocol
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range protocolIntegers() {
|
||||||
|
if p == v {
|
||||||
|
// guard against protocolIntegers sometime in the future not having lower
|
||||||
|
// case ids in the map
|
||||||
|
return strings.ToLower(k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fall through
|
||||||
|
log.Printf("[WARN] Unable to determine valid protocol: no matching protocols found")
|
||||||
|
return protocol
|
||||||
|
}
|
||||||
|
|
|
@ -44,9 +44,10 @@ func resourceAwsSecurityGroupRule() *schema.Resource {
|
||||||
},
|
},
|
||||||
|
|
||||||
"protocol": &schema.Schema{
|
"protocol": &schema.Schema{
|
||||||
Type: schema.TypeString,
|
Type: schema.TypeString,
|
||||||
Required: true,
|
Required: true,
|
||||||
ForceNew: true,
|
ForceNew: true,
|
||||||
|
StateFunc: protocolStateFunc,
|
||||||
},
|
},
|
||||||
|
|
||||||
"cidr_blocks": &schema.Schema{
|
"cidr_blocks": &schema.Schema{
|
||||||
|
@ -411,7 +412,8 @@ func expandIPPerm(d *schema.ResourceData, sg *ec2.SecurityGroup) (*ec2.IpPermiss
|
||||||
|
|
||||||
perm.FromPort = aws.Int64(int64(d.Get("from_port").(int)))
|
perm.FromPort = aws.Int64(int64(d.Get("from_port").(int)))
|
||||||
perm.ToPort = aws.Int64(int64(d.Get("to_port").(int)))
|
perm.ToPort = aws.Int64(int64(d.Get("to_port").(int)))
|
||||||
perm.IpProtocol = aws.String(d.Get("protocol").(string))
|
protocol := protocolForValue(d.Get("protocol").(string))
|
||||||
|
perm.IpProtocol = aws.String(protocol)
|
||||||
|
|
||||||
// build a group map that behaves like a set
|
// build a group map that behaves like a set
|
||||||
groups := make(map[string]bool)
|
groups := make(map[string]bool)
|
||||||
|
|
|
@ -144,6 +144,43 @@ func TestAccAWSSecurityGroupRule_Ingress_VPC(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccAWSSecurityGroupRule_Ingress_Protocol(t *testing.T) {
|
||||||
|
var group ec2.SecurityGroup
|
||||||
|
|
||||||
|
testRuleCount := func(*terraform.State) error {
|
||||||
|
if len(group.IpPermissions) != 1 {
|
||||||
|
return fmt.Errorf("Wrong Security Group rule count, expected %d, got %d",
|
||||||
|
1, len(group.IpPermissions))
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := group.IpPermissions[0]
|
||||||
|
if *rule.FromPort != int64(80) {
|
||||||
|
return fmt.Errorf("Wrong Security Group port setting, expected %d, got %d",
|
||||||
|
80, int(*rule.FromPort))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
resource.Test(t, resource.TestCase{
|
||||||
|
PreCheck: func() { testAccPreCheck(t) },
|
||||||
|
Providers: testAccProviders,
|
||||||
|
CheckDestroy: testAccCheckAWSSecurityGroupRuleDestroy,
|
||||||
|
Steps: []resource.TestStep{
|
||||||
|
resource.TestStep{
|
||||||
|
Config: testAccAWSSecurityGroupRuleIngress_protocolConfig,
|
||||||
|
Check: resource.ComposeTestCheckFunc(
|
||||||
|
testAccCheckAWSSecurityGroupRuleExists("aws_security_group.web", &group),
|
||||||
|
testAccCheckAWSSecurityGroupRuleAttributes("aws_security_group_rule.ingress_1", &group, nil, "ingress"),
|
||||||
|
resource.TestCheckResourceAttr(
|
||||||
|
"aws_security_group_rule.ingress_1", "from_port", "80"),
|
||||||
|
testRuleCount,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccAWSSecurityGroupRule_Ingress_Classic(t *testing.T) {
|
func TestAccAWSSecurityGroupRule_Ingress_Classic(t *testing.T) {
|
||||||
var group ec2.SecurityGroup
|
var group ec2.SecurityGroup
|
||||||
|
|
||||||
|
@ -545,6 +582,35 @@ resource "aws_security_group_rule" "ingress_1" {
|
||||||
}
|
}
|
||||||
`
|
`
|
||||||
|
|
||||||
|
const testAccAWSSecurityGroupRuleIngress_protocolConfig = `
|
||||||
|
resource "aws_vpc" "tftest" {
|
||||||
|
cidr_block = "10.0.0.0/16"
|
||||||
|
|
||||||
|
tags {
|
||||||
|
Name = "tf-testing"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resource "aws_security_group" "web" {
|
||||||
|
vpc_id = "${aws_vpc.tftest.id}"
|
||||||
|
|
||||||
|
tags {
|
||||||
|
Name = "tf-acc-test"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resource "aws_security_group_rule" "ingress_1" {
|
||||||
|
type = "ingress"
|
||||||
|
protocol = "6"
|
||||||
|
from_port = 80
|
||||||
|
to_port = 8000
|
||||||
|
cidr_blocks = ["10.0.0.0/8"]
|
||||||
|
|
||||||
|
security_group_id = "${aws_security_group.web.id}"
|
||||||
|
}
|
||||||
|
|
||||||
|
`
|
||||||
|
|
||||||
const testAccAWSSecurityGroupRuleIssue5310 = `
|
const testAccAWSSecurityGroupRuleIssue5310 = `
|
||||||
provider "aws" {
|
provider "aws" {
|
||||||
region = "us-east-1"
|
region = "us-east-1"
|
||||||
|
|
|
@ -15,6 +15,123 @@ import (
|
||||||
"github.com/hashicorp/terraform/terraform"
|
"github.com/hashicorp/terraform/terraform"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestProtocolStateFunc(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
input interface{}
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
input: "tcp",
|
||||||
|
expected: "tcp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: 6,
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "17",
|
||||||
|
expected: "udp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "all",
|
||||||
|
expected: "-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "-1",
|
||||||
|
expected: "-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: -1,
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "1",
|
||||||
|
expected: "icmp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "icmp",
|
||||||
|
expected: "icmp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: 1,
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
result := protocolStateFunc(c.input)
|
||||||
|
if result != c.expected {
|
||||||
|
t.Errorf("Error matching protocol, expected (%s), got (%s)", c.expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProtocolForValue(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
input: "tcp",
|
||||||
|
expected: "tcp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "6",
|
||||||
|
expected: "tcp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "udp",
|
||||||
|
expected: "udp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "17",
|
||||||
|
expected: "udp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "all",
|
||||||
|
expected: "-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "-1",
|
||||||
|
expected: "-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "tCp",
|
||||||
|
expected: "tcp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "6",
|
||||||
|
expected: "tcp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "UDp",
|
||||||
|
expected: "udp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "17",
|
||||||
|
expected: "udp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "ALL",
|
||||||
|
expected: "-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "icMp",
|
||||||
|
expected: "icmp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "1",
|
||||||
|
expected: "icmp",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
result := protocolForValue(c.input)
|
||||||
|
if result != c.expected {
|
||||||
|
t.Errorf("Error matching protocol, expected (%s), got (%s)", c.expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestResourceAwsSecurityGroupIPPermGather(t *testing.T) {
|
func TestResourceAwsSecurityGroupIPPermGather(t *testing.T) {
|
||||||
raw := []*ec2.IpPermission{
|
raw := []*ec2.IpPermission{
|
||||||
&ec2.IpPermission{
|
&ec2.IpPermission{
|
||||||
|
@ -846,7 +963,7 @@ resource "aws_security_group" "web" {
|
||||||
description = "Used in the terraform acceptance tests"
|
description = "Used in the terraform acceptance tests"
|
||||||
|
|
||||||
ingress {
|
ingress {
|
||||||
protocol = "tcp"
|
protocol = "6"
|
||||||
from_port = 80
|
from_port = 80
|
||||||
to_port = 8000
|
to_port = 8000
|
||||||
cidr_blocks = ["10.0.0.0/8"]
|
cidr_blocks = ["10.0.0.0/8"]
|
||||||
|
|
Loading…
Reference in New Issue