diff --git a/firewall.go b/firewall.go index 763a66d..1c5ec9b 100644 --- a/firewall.go +++ b/firewall.go @@ -83,19 +83,23 @@ func newFirewallTable() *FirewallTable { } } +type FirewallCA struct { + Any *FirewallRule + CANames map[string]*FirewallRule + CAShas map[string]*FirewallRule +} + type FirewallRule struct { - // Any makes Hosts, Groups, and CIDR irrelevant. CAName and CASha still need to be checked - Any bool - Hosts map[string]struct{} - Groups [][]string - CIDR *CIDRTree - CANames map[string]struct{} - CAShas map[string]struct{} + // Any makes Hosts, Groups, and CIDR irrelevant + Any bool + Hosts map[string]struct{} + Groups [][]string + CIDR *CIDRTree } // Even though ports are uint16, int32 maps are faster for lookup // Plus we can use `-1` for fragment rules -type firewallPort map[int32]*FirewallRule +type firewallPort map[int32]*FirewallCA type FirewallPacket struct { LocalIP uint32 @@ -182,9 +186,9 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, error) { fw := NewFirewall( - c.GetDuration("firewall.conntrack.tcp_timeout", time.Duration(time.Minute*12)), - c.GetDuration("firewall.conntrack.udp_timeout", time.Duration(time.Minute*3)), - c.GetDuration("firewall.conntrack.default_timeout", time.Duration(time.Minute*10)), + c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), + c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), + c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), nc, //TODO: max_connections ) @@ -499,12 +503,9 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, for i := startPort; i <= endPort; i++ { if _, ok := fp[i]; !ok { - fp[i] = &FirewallRule{ - Groups: make([][]string, 0), - Hosts: make(map[string]struct{}), - CIDR: NewCIDRTree(), - CANames: make(map[string]struct{}), - CAShas: make(map[string]struct{}), + fp[i] = &FirewallCA{ + CANames: make(map[string]*FirewallRule), + CAShas: make(map[string]*FirewallRule), } } @@ -539,15 +540,83 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert return fp[fwPortAny].match(p, c, caPool) } -func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, caName string, caSha string) error { - if caName != "" { - fr.CANames[caName] = struct{}{} +func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error { + // If there is an any rule then there is no need to establish specific ca rules + if fc.Any != nil { + return fc.Any.addRule(groups, host, ip) + } + + fr := func() *FirewallRule { + return &FirewallRule{ + Hosts: make(map[string]struct{}), + Groups: make([][]string, 0), + CIDR: NewCIDRTree(), + } + } + + any := false + if caSha == "" && caName == "" { + any = true + } + + if any { + if fc.Any == nil { + fc.Any = fr() + } + + // If it's any we need to wipe out any pre-existing rules to save on memory + fc.CAShas = make(map[string]*FirewallRule) + fc.CANames = make(map[string]*FirewallRule) + return fc.Any.addRule(groups, host, ip) } if caSha != "" { - fr.CAShas[caSha] = struct{}{} + if _, ok := fc.CAShas[caSha]; !ok { + fc.CAShas[caSha] = fr() + } + err := fc.CAShas[caSha].addRule(groups, host, ip) + if err != nil { + return err + } } + if caName != "" { + if _, ok := fc.CANames[caName]; !ok { + fc.CANames[caName] = fr() + } + err := fc.CANames[caName].addRule(groups, host, ip) + if err != nil { + return err + } + } + + return nil +} + +func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { + if fc == nil { + return false + } + + if fc.Any != nil { + return fc.Any.match(p, c) + } + + if t, ok := fc.CAShas[c.Details.Issuer]; ok { + if t.match(p, c) { + return true + } + } + + s, err := caPool.GetCAForCert(c) + if err != nil { + return false + } + + return fc.CANames[s.Details.Name].match(p, c) +} + +func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error { if fr.Any { return nil } @@ -593,28 +662,11 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool return false } -func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { +func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate) bool { if fr == nil { return false } - // CASha and CAName always need to be checked - if len(fr.CAShas) > 0 { - if _, ok := fr.CAShas[c.Details.Issuer]; !ok { - return false - } - } - - if len(fr.CANames) > 0 { - s, err := caPool.GetCAForCert(c) - if err != nil { - return false - } - if _, ok := fr.CANames[s.Details.Name]; !ok { - return false - } - } - // Shortcut path for if groups, hosts, or cidr contained an `any` if fr.Any { return true @@ -773,7 +825,7 @@ func setTCPRTTTracking(c *conn, p []byte) { ihl := int(p[0]&0x0f) << 2 // Don't track FIN packets - if uint8(p[ihl+13])&tcpFIN != 0 { + if p[ihl+13]&tcpFIN != 0 { return } @@ -787,7 +839,7 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool { } ihl := int(p[0]&0x0f) << 2 - if uint8(p[ihl+13])&tcpACK == 0 { + if p[ihl+13]&tcpACK == 0 { return false } diff --git a/firewall_test.go b/firewall_test.go index 371bb91..b897b44 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "math" "net" "testing" @@ -61,37 +62,37 @@ func TestFirewall_AddRule(t *testing.T) { assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", "")) // Make sure an empty rule creates structure but doesn't allow anything to flow //TODO: ideally an empty rule would return an error - assert.False(t, fw.InRules.TCP[1].Any) - assert.Empty(t, fw.InRules.TCP[1].Groups) - assert.Empty(t, fw.InRules.TCP[1].Hosts) - assert.Nil(t, fw.InRules.TCP[1].CIDR.root.left) - assert.Nil(t, fw.InRules.TCP[1].CIDR.root.right) - assert.Nil(t, fw.InRules.TCP[1].CIDR.root.value) + assert.False(t, fw.InRules.TCP[1].Any.Any) + assert.Empty(t, fw.InRules.TCP[1].Any.Groups) + assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) + assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left) + assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right) + assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value) fw = NewFirewall(time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "")) - assert.False(t, fw.InRules.UDP[1].Any) - assert.Contains(t, fw.InRules.UDP[1].Groups[0], "g1") - assert.Empty(t, fw.InRules.UDP[1].Hosts) - assert.Nil(t, fw.InRules.UDP[1].CIDR.root.left) - assert.Nil(t, fw.InRules.UDP[1].CIDR.root.right) - assert.Nil(t, fw.InRules.UDP[1].CIDR.root.value) + assert.False(t, fw.InRules.UDP[1].Any.Any) + assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1") + assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) + assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.left) + assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right) + assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value) fw = NewFirewall(time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", "")) - assert.False(t, fw.InRules.ICMP[1].Any) - assert.Empty(t, fw.InRules.ICMP[1].Groups) - assert.Contains(t, fw.InRules.ICMP[1].Hosts, "h1") - assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.left) - assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.right) - assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.value) + assert.False(t, fw.InRules.ICMP[1].Any.Any) + assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) + assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") + assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.left) + assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right) + assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value) fw = NewFirewall(time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", "")) - assert.False(t, fw.OutRules.AnyProto[1].Any) - assert.Empty(t, fw.OutRules.AnyProto[1].Groups) - assert.Empty(t, fw.OutRules.AnyProto[1].Hosts) - assert.NotNil(t, fw.OutRules.AnyProto[1].CIDR.Match(ip2int(ti.IP))) + assert.False(t, fw.OutRules.AnyProto[1].Any.Any) + assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups) + assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts) + assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP))) fw = NewFirewall(time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", "")) @@ -104,28 +105,30 @@ func TestFirewall_AddRule(t *testing.T) { // Set any and clear fields fw = NewFirewall(time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", "")) - assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Groups[0]) - assert.Contains(t, fw.OutRules.AnyProto[0].Hosts, "h1") - assert.NotNil(t, fw.OutRules.AnyProto[0].CIDR.Match(ip2int(ti.IP))) + assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0]) + assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1") + assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(ip2int(ti.IP))) // run twice just to make sure + //TODO: these ANY rules should clear the CA firewall portion assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", "")) - assert.True(t, fw.OutRules.AnyProto[0].Any) - assert.Empty(t, fw.OutRules.AnyProto[0].Groups) - assert.Empty(t, fw.OutRules.AnyProto[0].Hosts) - assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.left) - assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.right) - assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.value) + assert.True(t, fw.OutRules.AnyProto[0].Any.Any) + assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups) + assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts) + assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.left) + assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right) + assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value) + fmt.Printf("%+v\n", fw.OutRules.AnyProto[0]) fw = NewFirewall(time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", "")) - assert.True(t, fw.OutRules.AnyProto[0].Any) + assert.True(t, fw.OutRules.AnyProto[0].Any.Any) fw = NewFirewall(time.Second, time.Minute, time.Hour, c) _, anyIp, _ := net.ParseCIDR("0.0.0.0/0") assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", "")) - assert.True(t, fw.OutRules.AnyProto[0].Any) + assert.True(t, fw.OutRules.AnyProto[0].Any.Any) // Test error conditions fw = NewFirewall(time.Second, time.Minute, time.Hour, c) @@ -209,11 +212,11 @@ func BenchmarkFirewallTable_match(b *testing.B) { } _, n, _ := net.ParseCIDR("172.1.1.1/32") - ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "") - ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "") - ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "") - ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "") - ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "") + _ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "") + _ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "") + _ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "") + _ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "") + _ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "") cp := cert.NewCAPool() b.Run("fail on proto", func(b *testing.B) { @@ -281,7 +284,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { } }) - ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "") + _ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "") b.Run("pass on ip with any port", func(b *testing.B) { ip := ip2int(net.IPv4(172, 1, 1, 1))