Keep conntrack entries across firewall reloads and re-validate them lazily.

Summary of what this patch does:

  * conn gains two fields, incoming and rulesVersion, recording which
    direction and which ruleset version admitted the connection.
  * Firewall.Conns, Firewall.TimerWheel and Firewall.connMutex are replaced
    by a single Firewall.Conntrack pointer to the new FirewallConntrack
    struct (embedded sync.Mutex, Conns map, TimerWheel).
  * inConns now takes the HostInfo and CA pool. When an entry's rulesVersion
    differs from the firewall's, the entry is matched again against InRules
    (incoming) or OutRules (outgoing); it is deleted and the lookup fails if
    it no longer matches, otherwise its rulesVersion is brought up to date.
  * reloadFirewall takes the old conntrack lock, sets the new firewall's
    rulesVersion to old+1 and hands the old Conntrack to the new firewall,
    except when the uint16 version wraps to zero, in which case the new
    firewall keeps its own empty Conntrack.
  * EmitStats reads the conntrack size under the lock and also publishes
    the firewall.rules.version gauge.

Review fix relative to the submitted patch (firewall_test.go, TestNewFirewall):
the local "conntrack" variable is captured from the first firewall only, so
the assertions that follow each later "fw = NewFirewall(...)" were checking
the first instance's timer wheel again. They now read fw.Conntrack.TimerWheel
so every timeout permutation is actually verified. Hunk line counts are
unchanged; the post-image blob hash on the firewall_test.go index line was not
recomputed.

diff --git a/firewall.go b/firewall.go
index fd25098..91638e1 100644
--- a/firewall.go
+++ b/firewall.go
@@ -15,6 +15,7 @@ import (
 	"time"
 
 	"github.com/rcrowley/go-metrics"
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 )
 
@@ -37,13 +38,19 @@ type FirewallInterface interface {
 
 type conn struct {
 	Expires time.Time // Time when this conntrack entry will expire
-	Seq     uint32    // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
 	Sent    time.Time // If tcp rtt tracking is enabled this will be when Seq was last set
+	Seq     uint32    // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
+
+	// record why the original connection passed the firewall, so we can re-validate
+	// after ruleset changes. Note, rulesVersion is a uint16 so that these two
+	// fields pack for free after the uint32 above
+	incoming     bool
+	rulesVersion uint16
 }
 
 // TODO: need conntrack max tracked connections handling
 type Firewall struct {
-	Conns map[FirewallPacket]*conn
+	Conntrack *FirewallConntrack
 
 	InRules  *FirewallTable
 	OutRules *FirewallTable
@@ -54,18 +61,23 @@ type Firewall struct {
 	UDPTimeout     time.Duration //linux: 180s max
 	DefaultTimeout time.Duration //linux: 600s
 
-	TimerWheel *TimerWheel
-
 	// Used to ensure we don't emit local packets for ips we don't own
 	localIps *CIDRTree
 
-	connMutex sync.Mutex
-	rules     string
+	rules        string
+	rulesVersion uint16
 
 	trackTCPRTT  bool
 	metricTCPRTT metrics.Histogram
 }
 
+type FirewallConntrack struct {
+	sync.Mutex
+
+	Conns      map[FirewallPacket]*conn
+	TimerWheel *TimerWheel
+}
+
 type FirewallTable struct {
 	TCP      firewallPort
 	UDP      firewallPort
@@ -171,10 +183,12 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
 	}
 
 	return &Firewall{
-		Conns:          make(map[FirewallPacket]*conn),
+		Conntrack: &FirewallConntrack{
+			Conns:      make(map[FirewallPacket]*conn),
+			TimerWheel: NewTimerWheel(min, max),
+		},
 		InRules:        newFirewallTable(),
 		OutRules:       newFirewallTable(),
-		TimerWheel:     NewTimerWheel(min, max),
 		TCPTimeout:     tcpTimeout,
 		UDPTimeout:     UDPTimeout,
 		DefaultTimeout: defaultTimeout,
@@ -354,7 +368,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
 // returns nil if the packet should not be dropped.
 func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) error {
 	// Check if we spoke to this tuple, if we did then allow this packet
-	if f.inConns(packet, fp, incoming) {
+	if f.inConns(packet, fp, incoming, h, caPool) {
 		return nil
 	}
 
@@ -398,26 +412,66 @@ func (f *Firewall) Destroy() {
 }
 
 func (f *Firewall) EmitStats() {
-	conntrackCount := len(f.Conns)
+	conntrack := f.Conntrack
+	conntrack.Lock()
+	conntrackCount := len(conntrack.Conns)
+	conntrack.Unlock()
 	metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount))
+	metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
 }
 
-func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool {
-	f.connMutex.Lock()
+func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool {
+	conntrack := f.Conntrack
+	conntrack.Lock()
 
 	// Purge every time we test
-	ep, has := f.TimerWheel.Purge()
+	ep, has := conntrack.TimerWheel.Purge()
 	if has {
 		f.evict(ep)
 	}
 
-	c, ok := f.Conns[fp]
+	c, ok := conntrack.Conns[fp]
 
 	if !ok {
-		f.connMutex.Unlock()
+		conntrack.Unlock()
 		return false
 	}
 
+	if c.rulesVersion != f.rulesVersion {
+		// This conntrack entry was for an older rule set, validate
+		// it still passes with the current rule set
+		table := f.OutRules
+		if c.incoming {
+			table = f.InRules
+		}
+
+		// We now know which firewall table to check against
+		if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
+			if l.Level >= logrus.DebugLevel {
+				h.logger().
+					WithField("fwPacket", fp).
+					WithField("incoming", c.incoming).
+					WithField("rulesVersion", f.rulesVersion).
+					WithField("oldRulesVersion", c.rulesVersion).
+					Debugln("dropping old conntrack entry, does not match new ruleset")
+			}
+			delete(conntrack.Conns, fp)
+			conntrack.Unlock()
+			return false
+		}
+
+		if l.Level >= logrus.DebugLevel {
+			h.logger().
+				WithField("fwPacket", fp).
+				WithField("incoming", c.incoming).
+				WithField("rulesVersion", f.rulesVersion).
+				WithField("oldRulesVersion", c.rulesVersion).
+				Debugln("keeping old conntrack entry, does match new ruleset")
+		}
+
+		c.rulesVersion = f.rulesVersion
+	}
+
 	switch fp.Protocol {
 	case fwProtoTCP:
 		c.Expires = time.Now().Add(f.TCPTimeout)
@@ -432,7 +486,7 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool
 		c.Expires = time.Now().Add(f.DefaultTimeout)
 	}
 
-	f.connMutex.Unlock()
+	conntrack.Unlock()
 
 	return true
 }
@@ -453,14 +507,19 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
 		timeout = f.DefaultTimeout
 	}
 
-	f.connMutex.Lock()
-	if _, ok := f.Conns[fp]; !ok {
-		f.TimerWheel.Add(fp, timeout)
+	conntrack := f.Conntrack
+	conntrack.Lock()
+	if _, ok := conntrack.Conns[fp]; !ok {
+		conntrack.TimerWheel.Add(fp, timeout)
 	}
 
+	// Record which rulesVersion allowed this connection, so we can retest after
+	// firewall reload
+	c.incoming = incoming
+	c.rulesVersion = f.rulesVersion
 	c.Expires = time.Now().Add(timeout)
-	f.Conns[fp] = c
-	f.connMutex.Unlock()
+	conntrack.Conns[fp] = c
+	conntrack.Unlock()
 }
 
 // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
@@ -468,7 +527,8 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
 func (f *Firewall) evict(p FirewallPacket) {
 	//TODO: report a stat if the tcp rtt tracking was never resolved?
 	// Are we still tracking this conn?
-	t, ok := f.Conns[p]
+	conntrack := f.Conntrack
+	t, ok := conntrack.Conns[p]
 	if !ok {
 		return
 	}
@@ -477,12 +537,12 @@ func (f *Firewall) evict(p FirewallPacket) {
 
 	// Timeout is in the future, re-add the timer
 	if newT > 0 {
-		f.TimerWheel.Add(p, newT)
+		conntrack.TimerWheel.Add(p, newT)
 		return
 	}
 
 	// This conn is done
-	delete(f.Conns, p)
+	delete(conntrack.Conns, p)
 }
 
 func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
diff --git a/firewall_test.go b/firewall_test.go
index d7ca789..8068c8a 100644
--- a/firewall_test.go
+++ b/firewall_test.go
@@ -17,37 +17,39 @@ import (
 func TestNewFirewall(t *testing.T) {
 	c := &cert.NebulaCertificate{}
 	fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
-	assert.NotNil(t, fw.Conns)
+	conntrack := fw.Conntrack
+	assert.NotNil(t, conntrack)
+	assert.NotNil(t, conntrack.Conns)
+	assert.NotNil(t, conntrack.TimerWheel)
 	assert.NotNil(t, fw.InRules)
 	assert.NotNil(t, fw.OutRules)
-	assert.NotNil(t, fw.TimerWheel)
 	assert.Equal(t, time.Second, fw.TCPTimeout)
 	assert.Equal(t, time.Minute, fw.UDPTimeout)
 	assert.Equal(t, time.Hour, fw.DefaultTimeout)
 
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
+	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
 
 	fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
+	assert.Equal(t, time.Hour, fw.Conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, 3601, fw.Conntrack.TimerWheel.wheelLen)
 
 	fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
+	assert.Equal(t, time.Hour, fw.Conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, 3601, fw.Conntrack.TimerWheel.wheelLen)
 
 	fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
+	assert.Equal(t, time.Hour, fw.Conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, 3601, fw.Conntrack.TimerWheel.wheelLen)
 
 	fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
+	assert.Equal(t, time.Hour, fw.Conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, 3601, fw.Conntrack.TimerWheel.wheelLen)
 
 	fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
+	assert.Equal(t, time.Hour, fw.Conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, 3601, fw.Conntrack.TimerWheel.wheelLen)
 }
 
 func TestFirewall_AddRule(t *testing.T) {
@@ -461,6 +463,74 @@ func TestFirewall_Drop3(t *testing.T) {
 	assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp), ErrNoMatchingRule)
 }
 
+func TestFirewall_DropConntrackReload(t *testing.T) {
+	ob := &bytes.Buffer{}
+	out := l.Out
+	l.SetOutput(ob)
+	defer l.SetOutput(out)
+
+	p := FirewallPacket{
+		ip2int(net.IPv4(1, 2, 3, 4)),
+		ip2int(net.IPv4(1, 2, 3, 4)),
+		10,
+		90,
+		fwProtoUDP,
+		false,
+	}
+
+	ipNet := net.IPNet{
+		IP:   net.IPv4(1, 2, 3, 4),
+		Mask: net.IPMask{255, 255, 255, 0},
+	}
+
+	c := cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name:           "host1",
+			Ips:            []*net.IPNet{&ipNet},
+			Groups:         []string{"default-group"},
+			InvertedGroups: map[string]struct{}{"default-group": {}},
+			Issuer:         "signer-shasum",
+		},
+	}
+	h := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &c,
+		},
+		hostId: ip2int(ipNet.IP),
+	}
+	h.CreateRemoteCIDR(&c)
+
+	fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
+	cp := cert.NewCAPool()
+
+	// Drop outbound
+	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
+	// Allow inbound
+	resetConntrack(fw)
+	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
+	// Allow outbound because conntrack
+	assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
+
+	oldFw := fw
+	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
+	fw.Conntrack = oldFw.Conntrack
+	fw.rulesVersion = oldFw.rulesVersion + 1
+
+	// Allow outbound because conntrack and new rules allow port 10
+	assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
+
+	oldFw = fw
+	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
+	fw.Conntrack = oldFw.Conntrack
+	fw.rulesVersion = oldFw.rulesVersion + 1
+
+	// Drop outbound because conntrack doesn't match new ruleset
+	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
+}
+
 func BenchmarkLookup(b *testing.B) {
 	ml := func(m map[string]struct{}, a [][]string) {
 		for n := 0; n < b.N; n++ {
@@ -861,7 +931,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
 }
 
 func resetConntrack(fw *Firewall) {
-	fw.connMutex.Lock()
-	fw.Conns = map[FirewallPacket]*conn{}
-	fw.connMutex.Unlock()
+	fw.Conntrack.Lock()
+	fw.Conntrack.Conns = map[FirewallPacket]*conn{}
+	fw.Conntrack.Unlock()
 }
diff --git a/interface.go b/interface.go
index 5739ea0..95caa12 100644
--- a/interface.go
+++ b/interface.go
@@ -219,11 +219,28 @@ func (f *Interface) reloadFirewall(c *Config) {
 	}
 
 	oldFw := f.firewall
+	conntrack := oldFw.Conntrack
+	conntrack.Lock()
+	defer conntrack.Unlock()
+
+	fw.rulesVersion = oldFw.rulesVersion + 1
+	// If rulesVersion is back to zero, we have wrapped all the way around. Be
+	// safe and just reset conntrack in this case.
+	if fw.rulesVersion == 0 {
+		l.WithField("firewallHash", fw.GetRuleHash()).
+			WithField("oldFirewallHash", oldFw.GetRuleHash()).
+			WithField("rulesVersion", fw.rulesVersion).
+			Warn("firewall rulesVersion has overflowed, resetting conntrack")
+	} else {
+		fw.Conntrack = conntrack
+	}
+
 	f.firewall = fw
 
 	oldFw.Destroy()
 	l.WithField("firewallHash", fw.GetRuleHash()).
 		WithField("oldFirewallHash", oldFw.GetRuleHash()).
+		WithField("rulesVersion", fw.rulesVersion).
 		Info("New firewall has been installed")
 }
 