Rework some things into packages (#489)

This commit is contained in:
Nate Brown
2021-11-03 20:54:04 -05:00
committed by GitHub
parent 1f75fb3c73
commit bcabcfdaca
73 changed files with 2526 additions and 2374 deletions

View File

@ -4,7 +4,6 @@ import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net"
@ -12,22 +11,14 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
)
const (
fwProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
fwProtoTCP = 6
fwProtoUDP = 17
fwProtoICMP = 1
fwPortAny = 0 // Special value for matching `port: any`
fwPortFragment = -1 // Special value for matching `port: fragment`
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
)
const tcpACK = 0x10
@ -63,7 +54,7 @@ type Firewall struct {
DefaultTimeout time.Duration //linux: 600s
// Used to ensure we don't emit local packets for ips we don't own
localIps *CIDRTree
localIps *cidr.Tree4
rules string
rulesVersion uint16
@ -85,7 +76,7 @@ type firewallMetrics struct {
type FirewallConntrack struct {
sync.Mutex
Conns map[FirewallPacket]*conn
Conns map[firewall.Packet]*conn
TimerWheel *TimerWheel
}
@ -116,55 +107,13 @@ type FirewallRule struct {
Any bool
Hosts map[string]struct{}
Groups [][]string
CIDR *CIDRTree
CIDR *cidr.Tree4
}
// Even though ports are uint16, int32 maps are faster for lookup
// Plus we can use `-1` for fragment rules
type firewallPort map[int32]*FirewallCA
type FirewallPacket struct {
LocalIP uint32
RemoteIP uint32
LocalPort uint16
RemotePort uint16
Protocol uint8
Fragment bool
}
func (fp *FirewallPacket) Copy() *FirewallPacket {
return &FirewallPacket{
LocalIP: fp.LocalIP,
RemoteIP: fp.RemoteIP,
LocalPort: fp.LocalPort,
RemotePort: fp.RemotePort,
Protocol: fp.Protocol,
Fragment: fp.Fragment,
}
}
func (fp FirewallPacket) MarshalJSON() ([]byte, error) {
var proto string
switch fp.Protocol {
case fwProtoTCP:
proto = "tcp"
case fwProtoICMP:
proto = "icmp"
case fwProtoUDP:
proto = "udp"
default:
proto = fmt.Sprintf("unknown %v", fp.Protocol)
}
return json.Marshal(m{
"LocalIP": int2ip(fp.LocalIP).String(),
"RemoteIP": int2ip(fp.RemoteIP).String(),
"LocalPort": fp.LocalPort,
"RemotePort": fp.RemotePort,
"Protocol": proto,
"Fragment": fp.Fragment,
})
}
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
//TODO: error on 0 duration
@ -184,7 +133,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
max = defaultTimeout
}
localIps := NewCIDRTree()
localIps := cidr.NewTree4()
for _, ip := range c.Details.Ips {
localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
}
@ -195,7 +144,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
return &Firewall{
Conntrack: &FirewallConntrack{
Conns: make(map[FirewallPacket]*conn),
Conns: make(map[firewall.Packet]*conn),
TimerWheel: NewTimerWheel(min, max),
},
InRules: newFirewallTable(),
@ -220,7 +169,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
}
}
func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *config.C) (*Firewall, error) {
fw := NewFirewall(
l,
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
@ -278,13 +227,13 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
}
switch proto {
case fwProtoTCP:
case firewall.ProtoTCP:
fp = ft.TCP
case fwProtoUDP:
case firewall.ProtoUDP:
fp = ft.UDP
case fwProtoICMP:
case firewall.ProtoICMP:
fp = ft.ICMP
case fwProtoAny:
case firewall.ProtoAny:
fp = ft.AnyProto
default:
return fmt.Errorf("unknown protocol %v", proto)
@ -299,7 +248,7 @@ func (f *Firewall) GetRuleHash() string {
return hex.EncodeToString(sum[:])
}
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, fw FirewallInterface) error {
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
var table string
if inbound {
table = "firewall.inbound"
@ -307,7 +256,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config,
table = "firewall.outbound"
}
r := config.Get(table)
r := c.Get(table)
if r == nil {
return nil
}
@ -362,13 +311,13 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config,
var proto uint8
switch r.Proto {
case "any":
proto = fwProtoAny
proto = firewall.ProtoAny
case "tcp":
proto = fwProtoTCP
proto = firewall.ProtoTCP
case "udp":
proto = fwProtoUDP
proto = firewall.ProtoUDP
case "icmp":
proto = fwProtoICMP
proto = firewall.ProtoICMP
default:
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
}
@ -396,7 +345,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) error {
func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error {
// Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(packet, fp, incoming, h, caPool, localCache) {
return nil
@ -410,7 +359,7 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *Host
}
} else {
// Simple case: Certificate has one IP and no subnets
if fp.RemoteIP != h.hostId {
if fp.RemoteIP != h.vpnIp {
f.metrics(incoming).droppedRemoteIP.Inc(1)
return ErrInvalidRemoteIP
}
@ -462,7 +411,7 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
}
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) bool {
func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
if localCache != nil {
if _, ok := localCache[fp]; ok {
return true
@ -520,14 +469,14 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
}
switch fp.Protocol {
case fwProtoTCP:
case firewall.ProtoTCP:
c.Expires = time.Now().Add(f.TCPTimeout)
if incoming {
f.checkTCPRTT(c, packet)
} else {
setTCPRTTTracking(c, packet)
}
case fwProtoUDP:
case firewall.ProtoUDP:
c.Expires = time.Now().Add(f.UDPTimeout)
default:
c.Expires = time.Now().Add(f.DefaultTimeout)
@ -542,17 +491,17 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
return true
}
func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
func (f *Firewall) addConn(packet []byte, fp firewall.Packet, incoming bool) {
var timeout time.Duration
c := &conn{}
switch fp.Protocol {
case fwProtoTCP:
case firewall.ProtoTCP:
timeout = f.TCPTimeout
if !incoming {
setTCPRTTTracking(c, packet)
}
case fwProtoUDP:
case firewall.ProtoUDP:
timeout = f.UDPTimeout
default:
timeout = f.DefaultTimeout
@ -575,7 +524,7 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
// Caller must own the connMutex lock!
func (f *Firewall) evict(p FirewallPacket) {
func (f *Firewall) evict(p firewall.Packet) {
//TODO: report a stat if the tcp rtt tracking was never resolved?
// Are we still tracking this conn?
conntrack := f.Conntrack
@ -596,21 +545,21 @@ func (f *Firewall) evict(p FirewallPacket) {
delete(conntrack.Conns, p)
}
func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
if ft.AnyProto.match(p, incoming, c, caPool) {
return true
}
switch p.Protocol {
case fwProtoTCP:
case firewall.ProtoTCP:
if ft.TCP.match(p, incoming, c, caPool) {
return true
}
case fwProtoUDP:
case firewall.ProtoUDP:
if ft.UDP.match(p, incoming, c, caPool) {
return true
}
case fwProtoICMP:
case firewall.ProtoICMP:
if ft.ICMP.match(p, incoming, c, caPool) {
return true
}
@ -640,7 +589,7 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
return nil
}
func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
// We don't have any allowed ports, bail
if fp == nil {
return false
@ -649,7 +598,7 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
var port int32
if p.Fragment {
port = fwPortFragment
port = firewall.PortFragment
} else if incoming {
port = int32(p.LocalPort)
} else {
@ -660,7 +609,7 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
return true
}
return fp[fwPortAny].match(p, c, caPool)
return fp[firewall.PortAny].match(p, c, caPool)
}
func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
@ -668,7 +617,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
return &FirewallRule{
Hosts: make(map[string]struct{}),
Groups: make([][]string, 0),
CIDR: NewCIDRTree(),
CIDR: cidr.NewTree4(),
}
}
@ -703,7 +652,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
return nil
}
func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
if fc == nil {
return false
}
@ -736,7 +685,7 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) err
// If it's any we need to wipe out any pre-existing rules to save on memory
fr.Groups = make([][]string, 0)
fr.Hosts = make(map[string]struct{})
fr.CIDR = NewCIDRTree()
fr.CIDR = cidr.NewTree4()
} else {
if len(groups) > 0 {
fr.Groups = append(fr.Groups, groups)
@ -776,7 +725,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
return false
}
func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate) bool {
func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
if fr == nil {
return false
}
@ -885,12 +834,12 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
func parsePort(s string) (startPort, endPort int32, err error) {
if s == "any" {
startPort = fwPortAny
endPort = fwPortAny
startPort = firewall.PortAny
endPort = firewall.PortAny
} else if s == "fragment" {
startPort = fwPortFragment
endPort = fwPortFragment
startPort = firewall.PortFragment
endPort = firewall.PortFragment
} else if strings.Contains(s, `-`) {
sPorts := strings.SplitN(s, `-`, 2)
@ -914,8 +863,8 @@ func parsePort(s string) (startPort, endPort int32, err error) {
startPort = int32(rStartPort)
endPort = int32(rEndPort)
if startPort == fwPortAny {
endPort = fwPortAny
if startPort == firewall.PortAny {
endPort = firewall.PortAny
}
} else {
@ -968,54 +917,3 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
c.Seq = 0
return true
}
// ConntrackCache is used as a local routine cache to know if a given flow
// has been seen in the conntrack table.
type ConntrackCache map[FirewallPacket]struct{}
type ConntrackCacheTicker struct {
cacheV uint64
cacheTick uint64
cache ConntrackCache
}
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
if d == 0 {
return nil
}
c := &ConntrackCacheTicker{
cache: ConntrackCache{},
}
go c.tick(d)
return c
}
func (c *ConntrackCacheTicker) tick(d time.Duration) {
for {
time.Sleep(d)
atomic.AddUint64(&c.cacheTick, 1)
}
}
// Get checks if the cache ticker has moved to the next version before returning
// the map. If it has moved, we reset the map.
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
if c == nil {
return nil
}
if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
c.cacheV = tick
if ll := len(c.cache); ll > 0 {
if l.Level == logrus.DebugLevel {
l.WithField("len", ll).Debug("resetting conntrack cache")
}
c.cache = make(ConntrackCache, ll)
}
}
return c.cache
}