Preserve conntrack table during firewall rules reload (SIGHUP) (#233)
Currently, we drop the conntrack table when firewall rules change during a SIGHUP reload. This means responses to in-flight HTTP requests can be dropped, among other issues. This change carries the existing conntrack table over to the new firewall (holding the conntrack mutex while the firewall is swapped, to be safe). It also records which firewall rules version each conntrack entry was admitted under, so that the entry can be re-verified against the current rules after the new firewall has been loaded.
This commit is contained in:
parent
9b06748506
commit
f3a6d8d990
104
firewall.go
104
firewall.go
|
@ -15,6 +15,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
)
|
||||
|
||||
|
@ -37,13 +38,19 @@ type FirewallInterface interface {
|
|||
|
||||
// conn is one conntrack entry: the state kept for a single tracked flow.
//
// Field order is deliberate. The two time.Time values come first, then the
// uint32, so that the bool and uint16 that follow fit in the space that would
// otherwise be padding.
type conn struct {
	Expires time.Time // Time when this conntrack entry will expire
	Sent    time.Time // If tcp rtt tracking is enabled this will be when Seq was last set
	Seq     uint32    // If tcp rtt tracking is enabled this will be the seq we are looking for an ack

	// record why the original connection passed the firewall, so we can re-validate
	// after ruleset changes. Note, rulesVersion is a uint16 so that these two
	// fields pack for free after the uint32 above
	incoming     bool   // direction of the packet that created this entry
	rulesVersion uint16 // Firewall.rulesVersion this entry was last validated against
}
|
||||
|
||||
// TODO: need conntrack max tracked connections handling
|
||||
type Firewall struct {
|
||||
Conns map[FirewallPacket]*conn
|
||||
Conntrack *FirewallConntrack
|
||||
|
||||
InRules *FirewallTable
|
||||
OutRules *FirewallTable
|
||||
|
@ -54,18 +61,23 @@ type Firewall struct {
|
|||
UDPTimeout time.Duration //linux: 180s max
|
||||
DefaultTimeout time.Duration //linux: 600s
|
||||
|
||||
TimerWheel *TimerWheel
|
||||
|
||||
// Used to ensure we don't emit local packets for ips we don't own
|
||||
localIps *CIDRTree
|
||||
|
||||
connMutex sync.Mutex
|
||||
rules string
|
||||
rulesVersion uint16
|
||||
|
||||
trackTCPRTT bool
|
||||
metricTCPRTT metrics.Histogram
|
||||
}
|
||||
|
||||
type FirewallConntrack struct {
|
||||
sync.Mutex
|
||||
|
||||
Conns map[FirewallPacket]*conn
|
||||
TimerWheel *TimerWheel
|
||||
}
|
||||
|
||||
type FirewallTable struct {
|
||||
TCP firewallPort
|
||||
UDP firewallPort
|
||||
|
@ -171,10 +183,12 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
|
|||
}
|
||||
|
||||
return &Firewall{
|
||||
Conntrack: &FirewallConntrack{
|
||||
Conns: make(map[FirewallPacket]*conn),
|
||||
TimerWheel: NewTimerWheel(min, max),
|
||||
},
|
||||
InRules: newFirewallTable(),
|
||||
OutRules: newFirewallTable(),
|
||||
TimerWheel: NewTimerWheel(min, max),
|
||||
TCPTimeout: tcpTimeout,
|
||||
UDPTimeout: UDPTimeout,
|
||||
DefaultTimeout: defaultTimeout,
|
||||
|
@ -354,7 +368,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
|||
// returns nil if the packet should not be dropped.
|
||||
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) error {
|
||||
// Check if we spoke to this tuple, if we did then allow this packet
|
||||
if f.inConns(packet, fp, incoming) {
|
||||
if f.inConns(packet, fp, incoming, h, caPool) {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -398,26 +412,66 @@ func (f *Firewall) Destroy() {
|
|||
}
|
||||
|
||||
func (f *Firewall) EmitStats() {
|
||||
conntrackCount := len(f.Conns)
|
||||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
conntrackCount := len(conntrack.Conns)
|
||||
conntrack.Unlock()
|
||||
metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount))
|
||||
metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
|
||||
}
|
||||
|
||||
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool {
|
||||
f.connMutex.Lock()
|
||||
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool {
|
||||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
|
||||
// Purge every time we test
|
||||
ep, has := f.TimerWheel.Purge()
|
||||
ep, has := conntrack.TimerWheel.Purge()
|
||||
if has {
|
||||
f.evict(ep)
|
||||
}
|
||||
|
||||
c, ok := f.Conns[fp]
|
||||
c, ok := conntrack.Conns[fp]
|
||||
|
||||
if !ok {
|
||||
f.connMutex.Unlock()
|
||||
conntrack.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
if c.rulesVersion != f.rulesVersion {
|
||||
// This conntrack entry was for an older rule set, validate
|
||||
// it still passes with the current rule set
|
||||
table := f.OutRules
|
||||
if c.incoming {
|
||||
table = f.InRules
|
||||
}
|
||||
|
||||
// We now know which firewall table to check against
|
||||
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
h.logger().
|
||||
WithField("fwPacket", fp).
|
||||
WithField("incoming", c.incoming).
|
||||
WithField("rulesVersion", f.rulesVersion).
|
||||
WithField("oldRulesVersion", c.rulesVersion).
|
||||
Debugln("dropping old conntrack entry, does not match new ruleset")
|
||||
}
|
||||
delete(conntrack.Conns, fp)
|
||||
conntrack.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
h.logger().
|
||||
WithField("fwPacket", fp).
|
||||
WithField("incoming", c.incoming).
|
||||
WithField("rulesVersion", f.rulesVersion).
|
||||
WithField("oldRulesVersion", c.rulesVersion).
|
||||
Debugln("keeping old conntrack entry, does match new ruleset")
|
||||
}
|
||||
|
||||
c.rulesVersion = f.rulesVersion
|
||||
}
|
||||
|
||||
switch fp.Protocol {
|
||||
case fwProtoTCP:
|
||||
c.Expires = time.Now().Add(f.TCPTimeout)
|
||||
|
@ -432,7 +486,7 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool
|
|||
c.Expires = time.Now().Add(f.DefaultTimeout)
|
||||
}
|
||||
|
||||
f.connMutex.Unlock()
|
||||
conntrack.Unlock()
|
||||
|
||||
return true
|
||||
}
|
||||
|
@ -453,14 +507,19 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
|
|||
timeout = f.DefaultTimeout
|
||||
}
|
||||
|
||||
f.connMutex.Lock()
|
||||
if _, ok := f.Conns[fp]; !ok {
|
||||
f.TimerWheel.Add(fp, timeout)
|
||||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
if _, ok := conntrack.Conns[fp]; !ok {
|
||||
conntrack.TimerWheel.Add(fp, timeout)
|
||||
}
|
||||
|
||||
// Record which rulesVersion allowed this connection, so we can retest after
|
||||
// firewall reload
|
||||
c.incoming = incoming
|
||||
c.rulesVersion = f.rulesVersion
|
||||
c.Expires = time.Now().Add(timeout)
|
||||
f.Conns[fp] = c
|
||||
f.connMutex.Unlock()
|
||||
conntrack.Conns[fp] = c
|
||||
conntrack.Unlock()
|
||||
}
|
||||
|
||||
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
|
||||
|
@ -468,7 +527,8 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
|
|||
func (f *Firewall) evict(p FirewallPacket) {
|
||||
//TODO: report a stat if the tcp rtt tracking was never resolved?
|
||||
// Are we still tracking this conn?
|
||||
t, ok := f.Conns[p]
|
||||
conntrack := f.Conntrack
|
||||
t, ok := conntrack.Conns[p]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
@ -477,12 +537,12 @@ func (f *Firewall) evict(p FirewallPacket) {
|
|||
|
||||
// Timeout is in the future, re-add the timer
|
||||
if newT > 0 {
|
||||
f.TimerWheel.Add(p, newT)
|
||||
conntrack.TimerWheel.Add(p, newT)
|
||||
return
|
||||
}
|
||||
|
||||
// This conn is done
|
||||
delete(f.Conns, p)
|
||||
delete(conntrack.Conns, p)
|
||||
}
|
||||
|
||||
func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
||||
|
|
106
firewall_test.go
106
firewall_test.go
|
@ -17,37 +17,39 @@ import (
|
|||
func TestNewFirewall(t *testing.T) {
|
||||
c := &cert.NebulaCertificate{}
|
||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
assert.NotNil(t, fw.Conns)
|
||||
conntrack := fw.Conntrack
|
||||
assert.NotNil(t, conntrack)
|
||||
assert.NotNil(t, conntrack.Conns)
|
||||
assert.NotNil(t, conntrack.TimerWheel)
|
||||
assert.NotNil(t, fw.InRules)
|
||||
assert.NotNil(t, fw.OutRules)
|
||||
assert.NotNil(t, fw.TimerWheel)
|
||||
assert.Equal(t, time.Second, fw.TCPTimeout)
|
||||
assert.Equal(t, time.Minute, fw.UDPTimeout)
|
||||
assert.Equal(t, time.Hour, fw.DefaultTimeout)
|
||||
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
}
|
||||
|
||||
func TestFirewall_AddRule(t *testing.T) {
|
||||
|
@ -461,6 +463,74 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||
assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp), ErrNoMatchingRule)
|
||||
}
|
||||
|
||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
ob := &bytes.Buffer{}
|
||||
out := l.Out
|
||||
l.SetOutput(ob)
|
||||
defer l.SetOutput(out)
|
||||
|
||||
p := FirewallPacket{
|
||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||
10,
|
||||
90,
|
||||
fwProtoUDP,
|
||||
false,
|
||||
}
|
||||
|
||||
ipNet := net.IPNet{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
}
|
||||
|
||||
c := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "host1",
|
||||
Ips: []*net.IPNet{&ipNet},
|
||||
Groups: []string{"default-group"},
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||
Issuer: "signer-shasum",
|
||||
},
|
||||
}
|
||||
h := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c,
|
||||
},
|
||||
hostId: ip2int(ipNet.IP),
|
||||
}
|
||||
h.CreateRemoteCIDR(&c)
|
||||
|
||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||
// Allow outbound because conntrack
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
|
||||
|
||||
oldFw := fw
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
|
||||
fw.Conntrack = oldFw.Conntrack
|
||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
||||
// Allow outbound because conntrack and new rules allow port 10
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
|
||||
|
||||
oldFw = fw
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
|
||||
fw.Conntrack = oldFw.Conntrack
|
||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
||||
// Drop outbound because conntrack doesn't match new ruleset
|
||||
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
|
||||
}
|
||||
|
||||
func BenchmarkLookup(b *testing.B) {
|
||||
ml := func(m map[string]struct{}, a [][]string) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
|
@ -861,7 +931,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
|
|||
}
|
||||
|
||||
func resetConntrack(fw *Firewall) {
|
||||
fw.connMutex.Lock()
|
||||
fw.Conns = map[FirewallPacket]*conn{}
|
||||
fw.connMutex.Unlock()
|
||||
fw.Conntrack.Lock()
|
||||
fw.Conntrack.Conns = map[FirewallPacket]*conn{}
|
||||
fw.Conntrack.Unlock()
|
||||
}
|
||||
|
|
17
interface.go
17
interface.go
|
@ -219,11 +219,28 @@ func (f *Interface) reloadFirewall(c *Config) {
|
|||
}
|
||||
|
||||
oldFw := f.firewall
|
||||
conntrack := oldFw.Conntrack
|
||||
conntrack.Lock()
|
||||
defer conntrack.Unlock()
|
||||
|
||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
// If rulesVersion is back to zero, we have wrapped all the way around. Be
|
||||
// safe and just reset conntrack in this case.
|
||||
if fw.rulesVersion == 0 {
|
||||
l.WithField("firewallHash", fw.GetRuleHash()).
|
||||
WithField("oldFirewallHash", oldFw.GetRuleHash()).
|
||||
WithField("rulesVersion", fw.rulesVersion).
|
||||
Warn("firewall rulesVersion has overflowed, resetting conntrack")
|
||||
} else {
|
||||
fw.Conntrack = conntrack
|
||||
}
|
||||
|
||||
f.firewall = fw
|
||||
|
||||
oldFw.Destroy()
|
||||
l.WithField("firewallHash", fw.GetRuleHash()).
|
||||
WithField("oldFirewallHash", oldFw.GetRuleHash()).
|
||||
WithField("rulesVersion", fw.rulesVersion).
|
||||
Info("New firewall has been installed")
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue