Merge pull request #113 from slackhq/fw-ca
Fixes the issues with caSha and caName
This commit is contained in:
		| @@ -141,7 +141,7 @@ firewall: | ||||
|  | ||||
|   # The firewall is default deny. There is no way to write a deny rule. | ||||
|   # Rules are comprised of a protocol, port, and one or more of host, group, or CIDR | ||||
|   # Logical evaluation is roughly: port AND proto AND ca_sha AND ca_name AND (host OR group OR groups OR cidr) | ||||
|   # Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) | ||||
|   # - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available). | ||||
|   #   code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any` | ||||
|   #   proto: `any`, `tcp`, `udp`, or `icmp` | ||||
|   | ||||
							
								
								
									
										125
									
								
								firewall.go
									
									
									
									
									
								
							
							
						
						
									
										125
									
								
								firewall.go
									
									
									
									
									
								
							| @@ -83,19 +83,23 @@ func newFirewallTable() *FirewallTable { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type FirewallCA struct { | ||||
| 	Any     *FirewallRule | ||||
| 	CANames map[string]*FirewallRule | ||||
| 	CAShas  map[string]*FirewallRule | ||||
| } | ||||
|  | ||||
| type FirewallRule struct { | ||||
| 	// Any makes Hosts, Groups, and CIDR irrelevant. CAName and CASha still need to be checked | ||||
| 	Any     bool | ||||
| 	Hosts   map[string]struct{} | ||||
| 	Groups  [][]string | ||||
| 	CIDR    *CIDRTree | ||||
| 	CANames map[string]struct{} | ||||
| 	CAShas  map[string]struct{} | ||||
| 	// Any makes Hosts, Groups, and CIDR irrelevant | ||||
| 	Any    bool | ||||
| 	Hosts  map[string]struct{} | ||||
| 	Groups [][]string | ||||
| 	CIDR   *CIDRTree | ||||
| } | ||||
|  | ||||
| // Even though ports are uint16, int32 maps are faster for lookup | ||||
| // Plus we can use `-1` for fragment rules | ||||
| type firewallPort map[int32]*FirewallRule | ||||
| type firewallPort map[int32]*FirewallCA | ||||
|  | ||||
| type FirewallPacket struct { | ||||
| 	LocalIP    uint32 | ||||
| @@ -182,9 +186,9 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N | ||||
|  | ||||
| func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, error) { | ||||
| 	fw := NewFirewall( | ||||
| 		c.GetDuration("firewall.conntrack.tcp_timeout", time.Duration(time.Minute*12)), | ||||
| 		c.GetDuration("firewall.conntrack.udp_timeout", time.Duration(time.Minute*3)), | ||||
| 		c.GetDuration("firewall.conntrack.default_timeout", time.Duration(time.Minute*10)), | ||||
| 		c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), | ||||
| 		c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), | ||||
| 		c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), | ||||
| 		nc, | ||||
| 		//TODO: max_connections | ||||
| 	) | ||||
| @@ -499,12 +503,9 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, | ||||
|  | ||||
| 	for i := startPort; i <= endPort; i++ { | ||||
| 		if _, ok := fp[i]; !ok { | ||||
| 			fp[i] = &FirewallRule{ | ||||
| 				Groups:  make([][]string, 0), | ||||
| 				Hosts:   make(map[string]struct{}), | ||||
| 				CIDR:    NewCIDRTree(), | ||||
| 				CANames: make(map[string]struct{}), | ||||
| 				CAShas:  make(map[string]struct{}), | ||||
| 			fp[i] = &FirewallCA{ | ||||
| 				CANames: make(map[string]*FirewallRule), | ||||
| 				CAShas:  make(map[string]*FirewallRule), | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| @@ -539,15 +540,70 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert | ||||
| 	return fp[fwPortAny].match(p, c, caPool) | ||||
| } | ||||
|  | ||||
| func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, caName string, caSha string) error { | ||||
| 	if caName != "" { | ||||
| 		fr.CANames[caName] = struct{}{} | ||||
| func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error { | ||||
| 	fr := func() *FirewallRule { | ||||
| 		return &FirewallRule{ | ||||
| 			Hosts:  make(map[string]struct{}), | ||||
| 			Groups: make([][]string, 0), | ||||
| 			CIDR:   NewCIDRTree(), | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if caSha == "" && caName == "" { | ||||
| 		if fc.Any == nil { | ||||
| 			fc.Any = fr() | ||||
| 		} | ||||
|  | ||||
| 		return fc.Any.addRule(groups, host, ip) | ||||
| 	} | ||||
|  | ||||
| 	if caSha != "" { | ||||
| 		fr.CAShas[caSha] = struct{}{} | ||||
| 		if _, ok := fc.CAShas[caSha]; !ok { | ||||
| 			fc.CAShas[caSha] = fr() | ||||
| 		} | ||||
| 		err := fc.CAShas[caSha].addRule(groups, host, ip) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if caName != "" { | ||||
| 		if _, ok := fc.CANames[caName]; !ok { | ||||
| 			fc.CANames[caName] = fr() | ||||
| 		} | ||||
| 		err := fc.CANames[caName].addRule(groups, host, ip) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { | ||||
| 	if fc == nil { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	if fc.Any.match(p, c) { | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| 	if t, ok := fc.CAShas[c.Details.Issuer]; ok { | ||||
| 		if t.match(p, c) { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	s, err := caPool.GetCAForCert(c) | ||||
| 	if err != nil { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	return fc.CANames[s.Details.Name].match(p, c) | ||||
| } | ||||
|  | ||||
| func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error { | ||||
| 	if fr.Any { | ||||
| 		return nil | ||||
| 	} | ||||
| @@ -576,6 +632,10 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, caN | ||||
| } | ||||
|  | ||||
| func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool { | ||||
| 	if len(groups) == 0 && host == "" && ip == nil { | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| 	for _, group := range groups { | ||||
| 		if group == "any" { | ||||
| 			return true | ||||
| @@ -593,28 +653,11 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { | ||||
| func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate) bool { | ||||
| 	if fr == nil { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	// CASha and CAName always need to be checked | ||||
| 	if len(fr.CAShas) > 0 { | ||||
| 		if _, ok := fr.CAShas[c.Details.Issuer]; !ok { | ||||
| 			return false | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if len(fr.CANames) > 0 { | ||||
| 		s, err := caPool.GetCAForCert(c) | ||||
| 		if err != nil { | ||||
| 			return false | ||||
| 		} | ||||
| 		if _, ok := fr.CANames[s.Details.Name]; !ok { | ||||
| 			return false | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Shortcut path for if groups, hosts, or cidr contained an `any` | ||||
| 	if fr.Any { | ||||
| 		return true | ||||
| @@ -773,7 +816,7 @@ func setTCPRTTTracking(c *conn, p []byte) { | ||||
| 	ihl := int(p[0]&0x0f) << 2 | ||||
|  | ||||
| 	// Don't track FIN packets | ||||
| 	if uint8(p[ihl+13])&tcpFIN != 0 { | ||||
| 	if p[ihl+13]&tcpFIN != 0 { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| @@ -787,7 +830,7 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool { | ||||
| 	} | ||||
|  | ||||
| 	ihl := int(p[0]&0x0f) << 2 | ||||
| 	if uint8(p[ihl+13])&tcpACK == 0 { | ||||
| 	if p[ihl+13]&tcpACK == 0 { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
|   | ||||
							
								
								
									
										226
									
								
								firewall_test.go
									
									
									
									
									
								
							
							
						
						
									
										226
									
								
								firewall_test.go
									
									
									
									
									
								
							| @@ -51,6 +51,11 @@ func TestNewFirewall(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestFirewall_AddRule(t *testing.T) { | ||||
| 	ob := &bytes.Buffer{} | ||||
| 	out := l.Out | ||||
| 	l.SetOutput(ob) | ||||
| 	defer l.SetOutput(out) | ||||
|  | ||||
| 	c := &cert.NebulaCertificate{} | ||||
| 	fw := NewFirewall(time.Second, time.Minute, time.Hour, c) | ||||
| 	assert.NotNil(t, fw.InRules) | ||||
| @@ -59,39 +64,38 @@ func TestFirewall_AddRule(t *testing.T) { | ||||
| 	_, ti, _ := net.ParseCIDR("1.2.3.4/32") | ||||
|  | ||||
| 	assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", "")) | ||||
| 	// Make sure an empty rule creates structure but doesn't allow anything to flow | ||||
| 	//TODO: ideally an empty rule would return an error | ||||
| 	assert.False(t, fw.InRules.TCP[1].Any) | ||||
| 	assert.Empty(t, fw.InRules.TCP[1].Groups) | ||||
| 	assert.Empty(t, fw.InRules.TCP[1].Hosts) | ||||
| 	assert.Nil(t, fw.InRules.TCP[1].CIDR.root.left) | ||||
| 	assert.Nil(t, fw.InRules.TCP[1].CIDR.root.right) | ||||
| 	assert.Nil(t, fw.InRules.TCP[1].CIDR.root.value) | ||||
| 	// An empty rule is any | ||||
| 	assert.True(t, fw.InRules.TCP[1].Any.Any) | ||||
| 	assert.Empty(t, fw.InRules.TCP[1].Any.Groups) | ||||
| 	assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) | ||||
| 	assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left) | ||||
| 	assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right) | ||||
| 	assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value) | ||||
|  | ||||
| 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c) | ||||
| 	assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "")) | ||||
| 	assert.False(t, fw.InRules.UDP[1].Any) | ||||
| 	assert.Contains(t, fw.InRules.UDP[1].Groups[0], "g1") | ||||
| 	assert.Empty(t, fw.InRules.UDP[1].Hosts) | ||||
| 	assert.Nil(t, fw.InRules.UDP[1].CIDR.root.left) | ||||
| 	assert.Nil(t, fw.InRules.UDP[1].CIDR.root.right) | ||||
| 	assert.Nil(t, fw.InRules.UDP[1].CIDR.root.value) | ||||
| 	assert.False(t, fw.InRules.UDP[1].Any.Any) | ||||
| 	assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1") | ||||
| 	assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) | ||||
| 	assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.left) | ||||
| 	assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right) | ||||
| 	assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value) | ||||
|  | ||||
| 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c) | ||||
| 	assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", "")) | ||||
| 	assert.False(t, fw.InRules.ICMP[1].Any) | ||||
| 	assert.Empty(t, fw.InRules.ICMP[1].Groups) | ||||
| 	assert.Contains(t, fw.InRules.ICMP[1].Hosts, "h1") | ||||
| 	assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.left) | ||||
| 	assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.right) | ||||
| 	assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.value) | ||||
| 	assert.False(t, fw.InRules.ICMP[1].Any.Any) | ||||
| 	assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) | ||||
| 	assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") | ||||
| 	assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.left) | ||||
| 	assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right) | ||||
| 	assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value) | ||||
|  | ||||
| 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c) | ||||
| 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", "")) | ||||
| 	assert.False(t, fw.OutRules.AnyProto[1].Any) | ||||
| 	assert.Empty(t, fw.OutRules.AnyProto[1].Groups) | ||||
| 	assert.Empty(t, fw.OutRules.AnyProto[1].Hosts) | ||||
| 	assert.NotNil(t, fw.OutRules.AnyProto[1].CIDR.Match(ip2int(ti.IP))) | ||||
| 	assert.False(t, fw.OutRules.AnyProto[1].Any.Any) | ||||
| 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups) | ||||
| 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts) | ||||
| 	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP))) | ||||
|  | ||||
| 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c) | ||||
| 	assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", "")) | ||||
| @@ -104,28 +108,29 @@ func TestFirewall_AddRule(t *testing.T) { | ||||
| 	// Set any and clear fields | ||||
| 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c) | ||||
| 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", "")) | ||||
| 	assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Groups[0]) | ||||
| 	assert.Contains(t, fw.OutRules.AnyProto[0].Hosts, "h1") | ||||
| 	assert.NotNil(t, fw.OutRules.AnyProto[0].CIDR.Match(ip2int(ti.IP))) | ||||
| 	assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0]) | ||||
| 	assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1") | ||||
| 	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(ip2int(ti.IP))) | ||||
|  | ||||
| 	// run twice just to make sure | ||||
| 	//TODO: these ANY rules should clear the CA firewall portion | ||||
| 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) | ||||
| 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", "")) | ||||
| 	assert.True(t, fw.OutRules.AnyProto[0].Any) | ||||
| 	assert.Empty(t, fw.OutRules.AnyProto[0].Groups) | ||||
| 	assert.Empty(t, fw.OutRules.AnyProto[0].Hosts) | ||||
| 	assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.left) | ||||
| 	assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.right) | ||||
| 	assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.value) | ||||
| 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any) | ||||
| 	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups) | ||||
| 	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts) | ||||
| 	assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.left) | ||||
| 	assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right) | ||||
| 	assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value) | ||||
|  | ||||
| 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c) | ||||
| 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", "")) | ||||
| 	assert.True(t, fw.OutRules.AnyProto[0].Any) | ||||
| 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any) | ||||
|  | ||||
| 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c) | ||||
| 	_, anyIp, _ := net.ParseCIDR("0.0.0.0/0") | ||||
| 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", "")) | ||||
| 	assert.True(t, fw.OutRules.AnyProto[0].Any) | ||||
| 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any) | ||||
|  | ||||
| 	// Test error conditions | ||||
| 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c) | ||||
| @@ -134,6 +139,11 @@ func TestFirewall_AddRule(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestFirewall_Drop(t *testing.T) { | ||||
| 	ob := &bytes.Buffer{} | ||||
| 	out := l.Out | ||||
| 	l.SetOutput(ob) | ||||
| 	defer l.SetOutput(out) | ||||
|  | ||||
| 	p := FirewallPacket{ | ||||
| 		ip2int(net.IPv4(1, 2, 3, 4)), | ||||
| 		ip2int(net.IPv4(1, 2, 3, 4)), | ||||
| @@ -150,10 +160,11 @@ func TestFirewall_Drop(t *testing.T) { | ||||
|  | ||||
| 	c := cert.NebulaCertificate{ | ||||
| 		Details: cert.NebulaCertificateDetails{ | ||||
| 			Name:   "host1", | ||||
| 			Ips:    []*net.IPNet{&ipNet}, | ||||
| 			Groups: []string{"default-group"}, | ||||
| 			Issuer: "signer-shasum", | ||||
| 			Name:           "host1", | ||||
| 			Ips:            []*net.IPNet{&ipNet}, | ||||
| 			Groups:         []string{"default-group"}, | ||||
| 			InvertedGroups: map[string]struct{}{"default-group": {}}, | ||||
| 			Issuer:         "signer-shasum", | ||||
| 		}, | ||||
| 	} | ||||
| 	h := HostInfo{ | ||||
| @@ -170,6 +181,7 @@ func TestFirewall_Drop(t *testing.T) { | ||||
| 	// Drop outbound | ||||
| 	assert.True(t, fw.Drop([]byte{}, p, false, &h, cp)) | ||||
| 	// Allow inbound | ||||
| 	resetConntrack(fw) | ||||
| 	assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) | ||||
| 	// Allow outbound because conntrack | ||||
| 	assert.False(t, fw.Drop([]byte{}, p, false, &h, cp)) | ||||
| @@ -180,27 +192,31 @@ func TestFirewall_Drop(t *testing.T) { | ||||
| 	assert.True(t, fw.Drop([]byte{}, p, false, &h, cp)) | ||||
| 	p.RemoteIP = oldRemote | ||||
|  | ||||
| 	// test caSha assertions true | ||||
| 	// ensure signer doesn't get in the way of group checks | ||||
| 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) | ||||
| 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum")) | ||||
| 	assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) | ||||
|  | ||||
| 	// test caSha assertions false | ||||
| 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) | ||||
| 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum-nope")) | ||||
| 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum")) | ||||
| 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad")) | ||||
| 	assert.True(t, fw.Drop([]byte{}, p, true, &h, cp)) | ||||
|  | ||||
| 	// test caName true | ||||
| 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} | ||||
| 	// test caSha doesn't drop on match | ||||
| 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) | ||||
| 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-good", "")) | ||||
| 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad")) | ||||
| 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum")) | ||||
| 	assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) | ||||
|  | ||||
| 	// test caName false | ||||
| 	// ensure ca name doesn't get in the way of group checks | ||||
| 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} | ||||
| 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) | ||||
| 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-bad", "")) | ||||
| 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", "")) | ||||
| 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", "")) | ||||
| 	assert.True(t, fw.Drop([]byte{}, p, true, &h, cp)) | ||||
|  | ||||
| 	// test caName doesn't drop on match | ||||
| 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} | ||||
| 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) | ||||
| 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", "")) | ||||
| 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", "")) | ||||
| 	assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) | ||||
| } | ||||
|  | ||||
| func BenchmarkFirewallTable_match(b *testing.B) { | ||||
| @@ -209,11 +225,11 @@ func BenchmarkFirewallTable_match(b *testing.B) { | ||||
| 	} | ||||
|  | ||||
| 	_, n, _ := net.ParseCIDR("172.1.1.1/32") | ||||
| 	ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "") | ||||
| 	ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "") | ||||
| 	ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "") | ||||
| 	ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "") | ||||
| 	ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "") | ||||
| 	_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "") | ||||
| 	_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "") | ||||
| 	_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "") | ||||
| 	_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "") | ||||
| 	_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "") | ||||
| 	cp := cert.NewCAPool() | ||||
|  | ||||
| 	b.Run("fail on proto", func(b *testing.B) { | ||||
| @@ -281,7 +297,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { | ||||
| 		} | ||||
| 	}) | ||||
|  | ||||
| 	ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "") | ||||
| 	_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "") | ||||
|  | ||||
| 	b.Run("pass on ip with any port", func(b *testing.B) { | ||||
| 		ip := ip2int(net.IPv4(172, 1, 1, 1)) | ||||
| @@ -298,6 +314,11 @@ func BenchmarkFirewallTable_match(b *testing.B) { | ||||
| } | ||||
|  | ||||
| func TestFirewall_Drop2(t *testing.T) { | ||||
| 	ob := &bytes.Buffer{} | ||||
| 	out := l.Out | ||||
| 	l.SetOutput(ob) | ||||
| 	defer l.SetOutput(out) | ||||
|  | ||||
| 	p := FirewallPacket{ | ||||
| 		ip2int(net.IPv4(1, 2, 3, 4)), | ||||
| 		ip2int(net.IPv4(1, 2, 3, 4)), | ||||
| @@ -347,9 +368,94 @@ func TestFirewall_Drop2(t *testing.T) { | ||||
| 	// h1/c1 lacks the proper groups | ||||
| 	assert.True(t, fw.Drop([]byte{}, p, true, &h1, cp)) | ||||
| 	// c has the proper groups | ||||
| 	resetConntrack(fw) | ||||
| 	assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) | ||||
| } | ||||
|  | ||||
| func TestFirewall_Drop3(t *testing.T) { | ||||
| 	ob := &bytes.Buffer{} | ||||
| 	out := l.Out | ||||
| 	l.SetOutput(ob) | ||||
| 	defer l.SetOutput(out) | ||||
|  | ||||
| 	p := FirewallPacket{ | ||||
| 		ip2int(net.IPv4(1, 2, 3, 4)), | ||||
| 		ip2int(net.IPv4(1, 2, 3, 4)), | ||||
| 		1, | ||||
| 		1, | ||||
| 		fwProtoUDP, | ||||
| 		false, | ||||
| 	} | ||||
|  | ||||
| 	ipNet := net.IPNet{ | ||||
| 		IP:   net.IPv4(1, 2, 3, 4), | ||||
| 		Mask: net.IPMask{255, 255, 255, 0}, | ||||
| 	} | ||||
|  | ||||
| 	c := cert.NebulaCertificate{ | ||||
| 		Details: cert.NebulaCertificateDetails{ | ||||
| 			Name: "host-owner", | ||||
| 			Ips:  []*net.IPNet{&ipNet}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	c1 := cert.NebulaCertificate{ | ||||
| 		Details: cert.NebulaCertificateDetails{ | ||||
| 			Name:   "host1", | ||||
| 			Ips:    []*net.IPNet{&ipNet}, | ||||
| 			Issuer: "signer-sha-bad", | ||||
| 		}, | ||||
| 	} | ||||
| 	h1 := HostInfo{ | ||||
| 		ConnectionState: &ConnectionState{ | ||||
| 			peerCert: &c1, | ||||
| 		}, | ||||
| 	} | ||||
| 	h1.CreateRemoteCIDR(&c1) | ||||
|  | ||||
| 	c2 := cert.NebulaCertificate{ | ||||
| 		Details: cert.NebulaCertificateDetails{ | ||||
| 			Name:   "host2", | ||||
| 			Ips:    []*net.IPNet{&ipNet}, | ||||
| 			Issuer: "signer-sha", | ||||
| 		}, | ||||
| 	} | ||||
| 	h2 := HostInfo{ | ||||
| 		ConnectionState: &ConnectionState{ | ||||
| 			peerCert: &c2, | ||||
| 		}, | ||||
| 	} | ||||
| 	h2.CreateRemoteCIDR(&c2) | ||||
|  | ||||
| 	c3 := cert.NebulaCertificate{ | ||||
| 		Details: cert.NebulaCertificateDetails{ | ||||
| 			Name:   "host3", | ||||
| 			Ips:    []*net.IPNet{&ipNet}, | ||||
| 			Issuer: "signer-sha-bad", | ||||
| 		}, | ||||
| 	} | ||||
| 	h3 := HostInfo{ | ||||
| 		ConnectionState: &ConnectionState{ | ||||
| 			peerCert: &c3, | ||||
| 		}, | ||||
| 	} | ||||
| 	h3.CreateRemoteCIDR(&c3) | ||||
|  | ||||
| 	fw := NewFirewall(time.Second, time.Minute, time.Hour, &c) | ||||
| 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", "")) | ||||
| 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha")) | ||||
| 	cp := cert.NewCAPool() | ||||
|  | ||||
| 	// c1 should pass because host match | ||||
| 	assert.False(t, fw.Drop([]byte{}, p, true, &h1, cp)) | ||||
| 	// c2 should pass because ca sha match | ||||
| 	resetConntrack(fw) | ||||
| 	assert.False(t, fw.Drop([]byte{}, p, true, &h2, cp)) | ||||
| 	// c3 should fail because no match | ||||
| 	resetConntrack(fw) | ||||
| 	assert.True(t, fw.Drop([]byte{}, p, true, &h3, cp)) | ||||
| } | ||||
|  | ||||
| func BenchmarkLookup(b *testing.B) { | ||||
| 	ml := func(m map[string]struct{}, a [][]string) { | ||||
| 		for n := 0; n < b.N; n++ { | ||||
| @@ -748,3 +854,9 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end | ||||
| 	mf.nextCallReturn = nil | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func resetConntrack(fw *Firewall) { | ||||
| 	fw.connMutex.Lock() | ||||
| 	fw.Conns = map[FirewallPacket]*conn{} | ||||
| 	fw.connMutex.Unlock() | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Nathan Brown
					Nathan Brown