diff --git a/firewall.go b/firewall.go index 42919fc..f09a701 100644 --- a/firewall.go +++ b/firewall.go @@ -12,6 +12,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/rcrowley/go-metrics" @@ -372,9 +373,9 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table") // Drop returns an error if the packet should be dropped, explaining why. It // returns nil if the packet should not be dropped. -func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) error { +func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) error { // Check if we spoke to this tuple, if we did then allow this packet - if f.inConns(packet, fp, incoming, h, caPool) { + if f.inConns(packet, fp, incoming, h, caPool, localCache) { return nil } @@ -426,7 +427,12 @@ func (f *Firewall) EmitStats() { metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion)) } -func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool { +func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) bool { + if localCache != nil { + if _, ok := localCache[fp]; ok { + return true + } + } conntrack := f.Conntrack conntrack.Lock() @@ -494,6 +500,10 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H conntrack.Unlock() + if localCache != nil { + localCache[fp] = struct{}{} + } + return true } @@ -923,3 +933,54 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool { c.Seq = 0 return true } + +// ConntrackCache is used as a local routine cache to know if a given flow +// has been seen in the conntrack table. +type ConntrackCache map[FirewallPacket]struct{} + +type ConntrackCacheTicker struct { + cacheV uint64 + cacheTick uint64 + + cache ConntrackCache +} + +func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker { + if d == 0 { + return nil + } + + c := &ConntrackCacheTicker{ + cache: ConntrackCache{}, + } + + go c.tick(d) + + return c +} + +func (c *ConntrackCacheTicker) tick(d time.Duration) { + for { + time.Sleep(d) + atomic.AddUint64(&c.cacheTick, 1) + } +} + +// Get checks if the cache ticker has moved to the next version before returning +// the map. If it has moved, we reset the map. +func (c *ConntrackCacheTicker) Get() ConntrackCache { + if c == nil { + return nil + } + if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV { + c.cacheV = tick + if ll := len(c.cache); ll > 0 { + if l.GetLevel() == logrus.DebugLevel { + l.WithField("len", ll).Debug("resetting conntrack cache") + } + c.cache = make(ConntrackCache, ll) + } + } + + return c.cache +} diff --git a/firewall_test.go b/firewall_test.go index 8068c8a..3995e8d 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -182,44 +182,44 @@ func TestFirewall_Drop(t *testing.T) { cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule) + assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) // Allow outbound because conntrack - assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil)) // test remote mismatch oldRemote := p.RemoteIP p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10)) - assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrInvalidRemoteIP) + assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteIP = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad")) - assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp), ErrNoMatchingRule) + assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum")) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", "")) - assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp), ErrNoMatchingRule) + assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", "")) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) } func BenchmarkFirewallTable_match(b *testing.B) { @@ -370,10 +370,10 @@ func TestFirewall_Drop2(t *testing.T) { cp := cert.NewCAPool() // h1/c1 lacks the proper groups - assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp), ErrNoMatchingRule) + assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp, nil), ErrNoMatchingRule) // c has the proper groups resetConntrack(fw) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) } func TestFirewall_Drop3(t *testing.T) { @@ -454,13 +454,13 @@ func TestFirewall_Drop3(t *testing.T) { cp := cert.NewCAPool() // c1 should pass because host match - assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp, nil)) // c2 should pass because ca sha match resetConntrack(fw) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp, nil)) // c3 should fail because no match resetConntrack(fw) - assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp), ErrNoMatchingRule) + assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp, nil), ErrNoMatchingRule) } func TestFirewall_DropConntrackReload(t *testing.T) { @@ -505,12 +505,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) { cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule) + assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) // Allow outbound because conntrack - assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil)) oldFw := fw fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) @@ -519,7 +519,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { fw.rulesVersion = oldFw.rulesVersion + 1 // Allow outbound because conntrack and new rules allow port 10 - assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil)) oldFw = fw fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) @@ -528,7 +528,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { fw.rulesVersion = oldFw.rulesVersion + 1 // Drop outbound because conntrack doesn't match new ruleset - assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule) + assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule) } func BenchmarkLookup(b *testing.B) { diff --git a/inside.go b/inside.go index 6192a1c..302b22b 100644 --- a/inside.go +++ b/inside.go @@ -7,7 +7,7 @@ import ( "github.com/sirupsen/logrus" ) -func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int) { +func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) { err := newPacket(packet, false, fwPacket) if err != nil { l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) @@ -52,7 +52,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, ci.queueLock.Unlock() } - dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs) + dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs, localCache) if dropReason == nil { mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q) if f.lightHouse != nil && mc%5000 == 0 { @@ -129,7 +129,7 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType, } // check if packet is in outbound fw rules - dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs) + dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs, nil) if dropReason != nil { if l.Level >= logrus.DebugLevel { l.WithField("fwPacket", fp). diff --git a/interface.go b/interface.go index 825ba97..d17f6a8 100644 --- a/interface.go +++ b/interface.go @@ -40,6 +40,8 @@ type InterfaceConfig struct { routines int MessageMetrics *MessageMetrics version string + + ConntrackCacheTimeout time.Duration } type Interface struct { @@ -61,6 +63,8 @@ type Interface struct { routines int version string + conntrackCacheTimeout time.Duration + writers []*udpConn readers []io.ReadWriteCloser @@ -102,6 +106,8 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) { writers: make([]*udpConn, c.routines), readers: make([]io.ReadWriteCloser, c.routines), + conntrackCacheTimeout: c.ConntrackCacheTimeout, + metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)), messageMetrics: c.MessageMetrics, } @@ -173,6 +179,8 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { fwPacket := &FirewallPacket{} nb := make([]byte, 12, 12) + conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout) + for { n, err := reader.Read(packet) if err != nil { @@ -181,7 +189,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { os.Exit(2) } - f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i) + f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get()) } } diff --git a/main.go b/main.go index 0800ffc..2f81fac 100644 --- a/main.go +++ b/main.go @@ -117,6 +117,18 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L } } + // EXPERIMENTAL + // Intentionally not documented yet while we do more testing and determine + // a good default value. + conntrackCacheTimeout := config.GetDuration("firewall.conntrack.routine_cache_timeout", 0) + if routines > 1 && !config.IsSet("firewall.conntrack.routine_cache_timeout") { + // Use a different default if we are running with multiple routines + conntrackCacheTimeout = 1 * time.Second + } + if conntrackCacheTimeout > 0 { + l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache") + } + var tun Inside if !configTest { config.CatchHUP() @@ -359,6 +371,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L routines: routines, MessageMetrics: messageMetrics, version: buildVersion, + + ConntrackCacheTimeout: conntrackCacheTimeout, } switch ifConfig.Cipher { diff --git a/outside.go b/outside.go index e0f9aaa..75f4eba 100644 --- a/outside.go +++ b/outside.go @@ -17,7 +17,7 @@ const ( minFwPacketLen = 4 ) -func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, lhh *LightHouseHandler, nb []byte, q int) { +func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, lhh *LightHouseHandler, nb []byte, q int, localCache ConntrackCache) { err := header.Parse(packet) if err != nil { // TODO: best if we return this and let caller log @@ -45,7 +45,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, return } - f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb, q) + f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb, q, localCache) // Fallthrough to the bottom to record incoming traffic @@ -257,7 +257,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] return out, nil } -func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte, q int) { +func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte, q int, localCache ConntrackCache) { var err error out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb) @@ -281,7 +281,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return } - dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs) + dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs, localCache) if dropReason != nil { if l.Level >= logrus.DebugLevel { hostinfo.logger().WithField("fwPacket", fwPacket). diff --git a/udp_generic.go b/udp_generic.go index 5a1d204..2de6e29 100644 --- a/udp_generic.go +++ b/udp_generic.go @@ -115,6 +115,8 @@ func (u *udpConn) ListenOut(f *Interface, q int) { lhh := f.lightHouse.NewRequestHandler() + conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout) + for { // Just read one packet at a time n, rua, err := u.ReadFromUDP(buffer) @@ -124,7 +126,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) { } udpAddr.UDPAddr = *rua - f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q) + f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get()) } } diff --git a/udp_linux.go b/udp_linux.go index 69eee31..dbdad2c 100644 --- a/udp_linux.go +++ b/udp_linux.go @@ -174,6 +174,8 @@ func (u *udpConn) ListenOut(f *Interface, q int) { read = u.ReadSingle } + conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout) + for { n, err := read(msgs) if err != nil { @@ -186,7 +188,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) { udpAddr.IP = binary.BigEndian.Uint32(names[i][4:8]) udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) - f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q) + f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get()) } } }