diff --git a/firewall.go b/firewall.go index 1c5ec9b..45373b6 100644 --- a/firewall.go +++ b/firewall.go @@ -541,11 +541,6 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert } func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error { - // If there is an any rule then there is no need to establish specific ca rules - if fc.Any != nil { - return fc.Any.addRule(groups, host, ip) - } - fr := func() *FirewallRule { return &FirewallRule{ Hosts: make(map[string]struct{}), @@ -554,19 +549,11 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam } } - any := false if caSha == "" && caName == "" { - any = true - } - - if any { if fc.Any == nil { fc.Any = fr() } - // If it's any we need to wipe out any pre-existing rules to save on memory - fc.CAShas = make(map[string]*FirewallRule) - fc.CANames = make(map[string]*FirewallRule) return fc.Any.addRule(groups, host, ip) } @@ -598,8 +585,8 @@ func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool return false } - if fc.Any != nil { - return fc.Any.match(p, c) + if fc.Any.match(p, c) { + return true } if t, ok := fc.CAShas[c.Details.Issuer]; ok { @@ -645,6 +632,10 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) err } func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool { + if len(groups) == 0 && host == "" && ip == nil { + return true + } + for _, group := range groups { if group == "any" { return true diff --git a/firewall_test.go b/firewall_test.go index 3c6025f..ceb589d 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -64,9 +64,8 @@ func TestFirewall_AddRule(t *testing.T) { _, ti, _ := net.ParseCIDR("1.2.3.4/32") assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", "")) - // Make sure an empty rule creates structure but doesn't allow anything to flow - //TODO: ideally an empty rule would return an error - assert.False(t, fw.InRules.TCP[1].Any.Any) + // An empty rule is any + assert.True(t, fw.InRules.TCP[1].Any.Any) assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left) @@ -182,6 +181,7 @@ func TestFirewall_Drop(t *testing.T) { // Drop outbound assert.True(t, fw.Drop([]byte{}, p, false, &h, cp)) // Allow inbound + resetConntrack(fw) assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) // Allow outbound because conntrack assert.False(t, fw.Drop([]byte{}, p, false, &h, cp)) @@ -368,9 +368,94 @@ func TestFirewall_Drop2(t *testing.T) { // h1/c1 lacks the proper groups assert.True(t, fw.Drop([]byte{}, p, true, &h1, cp)) // c has the proper groups + resetConntrack(fw) assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) } +func TestFirewall_Drop3(t *testing.T) { + ob := &bytes.Buffer{} + out := l.Out + l.SetOutput(ob) + defer l.SetOutput(out) + + p := FirewallPacket{ + ip2int(net.IPv4(1, 2, 3, 4)), + ip2int(net.IPv4(1, 2, 3, 4)), + 1, + 1, + fwProtoUDP, + false, + } + + ipNet := net.IPNet{ + IP: net.IPv4(1, 2, 3, 4), + Mask: net.IPMask{255, 255, 255, 0}, + } + + c := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "host-owner", + Ips: []*net.IPNet{&ipNet}, + }, + } + + c1 := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "host1", + Ips: []*net.IPNet{&ipNet}, + Issuer: "signer-sha-bad", + }, + } + h1 := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &c1, + }, + } + h1.CreateRemoteCIDR(&c1) + + c2 := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "host2", + Ips: []*net.IPNet{&ipNet}, + Issuer: "signer-sha", + }, + } + h2 := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &c2, + }, + } + h2.CreateRemoteCIDR(&c2) + + c3 := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "host3", + Ips: []*net.IPNet{&ipNet}, + Issuer: "signer-sha-bad", + }, + } + h3 := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &c3, + }, + } + h3.CreateRemoteCIDR(&c3) + + fw := NewFirewall(time.Second, time.Minute, time.Hour, &c) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", "")) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha")) + cp := cert.NewCAPool() + + // c1 should pass because host match + assert.False(t, fw.Drop([]byte{}, p, true, &h1, cp)) + // c2 should pass because ca sha match + resetConntrack(fw) + assert.False(t, fw.Drop([]byte{}, p, true, &h2, cp)) + // c3 should fail because no match + resetConntrack(fw) + assert.True(t, fw.Drop([]byte{}, p, true, &h3, cp)) +} + func BenchmarkLookup(b *testing.B) { ml := func(m map[string]struct{}, a [][]string) { for n := 0; n < b.N; n++ { @@ -769,3 +854,9 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end mf.nextCallReturn = nil return err } + +func resetConntrack(fw *Firewall) { + fw.connMutex.Lock() + fw.Conns = map[FirewallPacket]*conn{} + fw.connMutex.Unlock() +}