Consider security groups with source security groups when hashing

Previously they would conflict you had multiple security group rules
with the same ingress or egress ports but different source security
groups because only the CIDR blocks were considered (which are empty
when using source security groups).

Updated to include migrations (from clint@ctshryock.com)

Signed-off-by: Clint Shryock <clint@ctshryock.com>
This commit is contained in:
Jesse Szwedko 2015-06-11 14:57:43 -07:00 committed by Clint Shryock
parent 8814384dea
commit 7e0a340baf
3 changed files with 212 additions and 14 deletions

View File

@ -21,6 +21,9 @@ func resourceAwsSecurityGroupRule() *schema.Resource {
Read: resourceAwsSecurityGroupRuleRead,
Delete: resourceAwsSecurityGroupRuleDelete,
SchemaVersion: 1,
MigrateState: resourceAwsSecurityGroupRuleMigrateState,
Schema: map[string]*schema.Schema{
"type": &schema.Schema{
Type: schema.TypeString,
@ -48,10 +51,11 @@ func resourceAwsSecurityGroupRule() *schema.Resource {
},
"cidr_blocks": &schema.Schema{
Type: schema.TypeList,
Optional: true,
ForceNew: true,
Elem: &schema.Schema{Type: schema.TypeString},
Type: schema.TypeList,
Optional: true,
ForceNew: true,
Elem: &schema.Schema{Type: schema.TypeString},
ConflictsWith: []string{"source_security_group_id"},
},
"security_group_id": &schema.Schema{
@ -61,10 +65,11 @@ func resourceAwsSecurityGroupRule() *schema.Resource {
},
"source_security_group_id": &schema.Schema{
Type: schema.TypeString,
Optional: true,
ForceNew: true,
Computed: true,
Type: schema.TypeString,
Optional: true,
ForceNew: true,
Computed: true,
ConflictsWith: []string{"cidr_blocks"},
},
"self": &schema.Schema{
@ -263,15 +268,32 @@ func findResourceSecurityGroup(conn *ec2.EC2, id string) (*ec2.SecurityGroup, er
return resp.SecurityGroups[0], nil
}
// byUserIDAndGroup implements sort.Interface for []*ec2.UserIDGroupPairs based on
// UserID and then the GroupID or GroupName field (only one should be set).
type byUserIDAndGroup []*ec2.UserIDGroupPair
func (b byUserIDAndGroup) Len() int { return len(b) }
func (b byUserIDAndGroup) Swap(i, j int) { b[i], b[j] = b[j], b[i] }
func (b byUserIDAndGroup) Less(i, j int) bool {
if b[i].GroupID != nil && b[j].GroupID != nil {
return *b[i].GroupID < *b[j].GroupID
}
if b[i].GroupName != nil && b[j].GroupName != nil {
return *b[i].GroupName < *b[j].GroupName
}
panic("mismatched security group rules, may be a terraform bug")
}
func ipPermissionIDHash(ruleType string, ip *ec2.IPPermission) string {
var buf bytes.Buffer
// for egress rules, an TCP rule of -1 is automatically added, in which case
// the to and from ports will be nil. We don't record this rule locally.
if ip.IPProtocol != nil && *ip.IPProtocol != "-1" {
if ip.FromPort != nil && *ip.FromPort > 0 {
buf.WriteString(fmt.Sprintf("%d-", *ip.FromPort))
buf.WriteString(fmt.Sprintf("%d-", *ip.ToPort))
buf.WriteString(fmt.Sprintf("%s-", *ip.IPProtocol))
}
if ip.ToPort != nil && *ip.ToPort > 0 {
buf.WriteString(fmt.Sprintf("%d-", *ip.ToPort))
}
buf.WriteString(fmt.Sprintf("%s-", *ip.IPProtocol))
buf.WriteString(fmt.Sprintf("%s-", ruleType))
// We need to make sure to sort the strings below so that we always
@ -288,6 +310,22 @@ func ipPermissionIDHash(ruleType string, ip *ec2.IPPermission) string {
}
}
if len(ip.UserIDGroupPairs) > 0 {
sort.Sort(byUserIDAndGroup(ip.UserIDGroupPairs))
for _, pair := range ip.UserIDGroupPairs {
if pair.GroupID != nil {
buf.WriteString(fmt.Sprintf("%s-", *pair.GroupID))
} else {
buf.WriteString("-")
}
if pair.GroupName != nil {
buf.WriteString(fmt.Sprintf("%s-", *pair.GroupName))
} else {
buf.WriteString("-")
}
}
}
return fmt.Sprintf("sg-%d", hashcode.String(buf.String()))
}

View File

@ -0,0 +1,118 @@
package aws
import (
"fmt"
"log"
"strconv"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/hashicorp/terraform/terraform"
)
func resourceAwsSecurityGroupRuleMigrateState(
v int, is *terraform.InstanceState, meta interface{}) (*terraform.InstanceState, error) {
switch v {
case 0:
log.Println("[INFO] Found AWS Security Group State v0; migrating to v1")
return migrateSGRuleStateV0toV1(is, meta)
default:
return is, fmt.Errorf("Unexpected schema version: %d", v)
}
return is, nil
}
func migrateSGRuleStateV0toV1(is *terraform.InstanceState, meta interface{}) (*terraform.InstanceState, error) {
if is.Empty() {
log.Println("[DEBUG] Empty InstanceState; nothing to migrate.")
return is, nil
}
conn := meta.(*AWSClient).ec2conn
sg_id := is.Attributes["security_group_id"]
sg, err := findResourceSecurityGroup(conn, sg_id)
if err != nil {
return nil, fmt.Errorf("[WARN] Error finding security group for Security Group migration")
}
perm, err := migrateExpandIPPerm(is.Attributes, sg)
if err != nil {
return nil, fmt.Errorf("[WARN] Error making new IP Permission in Security Group migration")
}
log.Printf("[DEBUG] Attributes before migration: %#v", is.Attributes)
newID := ipPermissionIDHash(is.Attributes["type"], perm)
is.Attributes["id"] = newID
is.ID = newID
log.Printf("[DEBUG] Attributes after migration: %#v, new id: %s", is.Attributes, newID)
return is, nil
}
func migrateExpandIPPerm(attrs map[string]string, sg *ec2.SecurityGroup) (*ec2.IPPermission, error) {
var perm ec2.IPPermission
tp, err := strconv.Atoi(attrs["to_port"])
if err != nil {
return nil, fmt.Errorf("Error converting to_port in Security Group migration")
}
fp, err := strconv.Atoi(attrs["from_port"])
if err != nil {
return nil, fmt.Errorf("Error converting from_port in Security Group migration")
}
perm.ToPort = aws.Long(int64(tp))
perm.FromPort = aws.Long(int64(fp))
perm.IPProtocol = aws.String(attrs["protocol"])
groups := make(map[string]bool)
if attrs["self"] == "true" {
if sg.VPCID != nil && *sg.VPCID != "" {
groups[*sg.GroupID] = true
} else {
groups[*sg.GroupName] = true
}
}
if attrs["source_security_group_id"] != "" {
groups[attrs["source_security_group_id"]] = true
}
if len(groups) > 0 {
perm.UserIDGroupPairs = make([]*ec2.UserIDGroupPair, len(groups))
// build string list of group name/ids
var gl []string
for k, _ := range groups {
gl = append(gl, k)
}
for i, name := range gl {
perm.UserIDGroupPairs[i] = &ec2.UserIDGroupPair{
GroupID: aws.String(name),
}
if sg.VPCID == nil || *sg.VPCID == "" {
perm.UserIDGroupPairs[i].GroupID = nil
perm.UserIDGroupPairs[i].GroupName = aws.String(name)
perm.UserIDGroupPairs[i].UserID = nil
}
}
}
var cb []string
for k, v := range attrs {
if k != "cidr_blocks.#" && strings.HasPrefix(k, "cidr_blocks") {
cb = append(cb, v)
}
}
if len(cb) > 0 {
perm.IPRanges = make([]*ec2.IPRange, len(cb))
for i, v := range cb {
perm.IPRanges[i] = &ec2.IPRange{CIDRIP: aws.String(v)}
}
}
return &perm, nil
}

View File

@ -44,6 +44,46 @@ func TestIpPermissionIDHash(t *testing.T) {
},
}
vpc_security_group_source := &ec2.IPPermission{
IPProtocol: aws.String("tcp"),
FromPort: aws.Long(int64(80)),
ToPort: aws.Long(int64(8000)),
UserIDGroupPairs: []*ec2.UserIDGroupPair{
&ec2.UserIDGroupPair{
UserID: aws.String("987654321"),
GroupID: aws.String("sg-12345678"),
},
&ec2.UserIDGroupPair{
UserID: aws.String("123456789"),
GroupID: aws.String("sg-987654321"),
},
&ec2.UserIDGroupPair{
UserID: aws.String("123456789"),
GroupID: aws.String("sg-12345678"),
},
},
}
security_group_source := &ec2.IPPermission{
IPProtocol: aws.String("tcp"),
FromPort: aws.Long(int64(80)),
ToPort: aws.Long(int64(8000)),
UserIDGroupPairs: []*ec2.UserIDGroupPair{
&ec2.UserIDGroupPair{
UserID: aws.String("987654321"),
GroupName: aws.String("my-security-group"),
},
&ec2.UserIDGroupPair{
UserID: aws.String("123456789"),
GroupName: aws.String("my-security-group"),
},
&ec2.UserIDGroupPair{
UserID: aws.String("123456789"),
GroupName: aws.String("my-other-security-group"),
},
},
}
// hardcoded hashes, to detect future change
cases := []struct {
Input *ec2.IPPermission
@ -53,12 +93,14 @@ func TestIpPermissionIDHash(t *testing.T) {
{simple, "ingress", "sg-82613597"},
{egress, "egress", "sg-363054720"},
{egress_all, "egress", "sg-857124156"},
{vpc_security_group_source, "egress", "sg-1900174468"},
{security_group_source, "egress", "sg-1409919872"},
}
for _, tc := range cases {
actual := ipPermissionIDHash(tc.Type, tc.Input)
if actual != tc.Output {
t.Fatalf("input: %s - %#v\noutput: %s", tc.Type, tc.Input, actual)
t.Errorf("input: %s - %#v\noutput: %s", tc.Type, tc.Input, actual)
}
}
}