Rework some things into packages (#489)
This commit is contained in:
192
firewall.go
192
firewall.go
@ -4,7 +4,6 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@ -12,22 +11,14 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
)
|
||||
|
||||
const (
|
||||
fwProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
|
||||
fwProtoTCP = 6
|
||||
fwProtoUDP = 17
|
||||
fwProtoICMP = 1
|
||||
|
||||
fwPortAny = 0 // Special value for matching `port: any`
|
||||
fwPortFragment = -1 // Special value for matching `port: fragment`
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
)
|
||||
|
||||
const tcpACK = 0x10
|
||||
@ -63,7 +54,7 @@ type Firewall struct {
|
||||
DefaultTimeout time.Duration //linux: 600s
|
||||
|
||||
// Used to ensure we don't emit local packets for ips we don't own
|
||||
localIps *CIDRTree
|
||||
localIps *cidr.Tree4
|
||||
|
||||
rules string
|
||||
rulesVersion uint16
|
||||
@ -85,7 +76,7 @@ type firewallMetrics struct {
|
||||
type FirewallConntrack struct {
|
||||
sync.Mutex
|
||||
|
||||
Conns map[FirewallPacket]*conn
|
||||
Conns map[firewall.Packet]*conn
|
||||
TimerWheel *TimerWheel
|
||||
}
|
||||
|
||||
@ -116,55 +107,13 @@ type FirewallRule struct {
|
||||
Any bool
|
||||
Hosts map[string]struct{}
|
||||
Groups [][]string
|
||||
CIDR *CIDRTree
|
||||
CIDR *cidr.Tree4
|
||||
}
|
||||
|
||||
// Even though ports are uint16, int32 maps are faster for lookup
|
||||
// Plus we can use `-1` for fragment rules
|
||||
type firewallPort map[int32]*FirewallCA
|
||||
|
||||
type FirewallPacket struct {
|
||||
LocalIP uint32
|
||||
RemoteIP uint32
|
||||
LocalPort uint16
|
||||
RemotePort uint16
|
||||
Protocol uint8
|
||||
Fragment bool
|
||||
}
|
||||
|
||||
func (fp *FirewallPacket) Copy() *FirewallPacket {
|
||||
return &FirewallPacket{
|
||||
LocalIP: fp.LocalIP,
|
||||
RemoteIP: fp.RemoteIP,
|
||||
LocalPort: fp.LocalPort,
|
||||
RemotePort: fp.RemotePort,
|
||||
Protocol: fp.Protocol,
|
||||
Fragment: fp.Fragment,
|
||||
}
|
||||
}
|
||||
|
||||
func (fp FirewallPacket) MarshalJSON() ([]byte, error) {
|
||||
var proto string
|
||||
switch fp.Protocol {
|
||||
case fwProtoTCP:
|
||||
proto = "tcp"
|
||||
case fwProtoICMP:
|
||||
proto = "icmp"
|
||||
case fwProtoUDP:
|
||||
proto = "udp"
|
||||
default:
|
||||
proto = fmt.Sprintf("unknown %v", fp.Protocol)
|
||||
}
|
||||
return json.Marshal(m{
|
||||
"LocalIP": int2ip(fp.LocalIP).String(),
|
||||
"RemoteIP": int2ip(fp.RemoteIP).String(),
|
||||
"LocalPort": fp.LocalPort,
|
||||
"RemotePort": fp.RemotePort,
|
||||
"Protocol": proto,
|
||||
"Fragment": fp.Fragment,
|
||||
})
|
||||
}
|
||||
|
||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
|
||||
//TODO: error on 0 duration
|
||||
@ -184,7 +133,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
||||
max = defaultTimeout
|
||||
}
|
||||
|
||||
localIps := NewCIDRTree()
|
||||
localIps := cidr.NewTree4()
|
||||
for _, ip := range c.Details.Ips {
|
||||
localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
|
||||
}
|
||||
@ -195,7 +144,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
||||
|
||||
return &Firewall{
|
||||
Conntrack: &FirewallConntrack{
|
||||
Conns: make(map[FirewallPacket]*conn),
|
||||
Conns: make(map[firewall.Packet]*conn),
|
||||
TimerWheel: NewTimerWheel(min, max),
|
||||
},
|
||||
InRules: newFirewallTable(),
|
||||
@ -220,7 +169,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
||||
}
|
||||
}
|
||||
|
||||
func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
|
||||
func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *config.C) (*Firewall, error) {
|
||||
fw := NewFirewall(
|
||||
l,
|
||||
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
|
||||
@ -278,13 +227,13 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
||||
}
|
||||
|
||||
switch proto {
|
||||
case fwProtoTCP:
|
||||
case firewall.ProtoTCP:
|
||||
fp = ft.TCP
|
||||
case fwProtoUDP:
|
||||
case firewall.ProtoUDP:
|
||||
fp = ft.UDP
|
||||
case fwProtoICMP:
|
||||
case firewall.ProtoICMP:
|
||||
fp = ft.ICMP
|
||||
case fwProtoAny:
|
||||
case firewall.ProtoAny:
|
||||
fp = ft.AnyProto
|
||||
default:
|
||||
return fmt.Errorf("unknown protocol %v", proto)
|
||||
@ -299,7 +248,7 @@ func (f *Firewall) GetRuleHash() string {
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, fw FirewallInterface) error {
|
||||
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
|
||||
var table string
|
||||
if inbound {
|
||||
table = "firewall.inbound"
|
||||
@ -307,7 +256,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config,
|
||||
table = "firewall.outbound"
|
||||
}
|
||||
|
||||
r := config.Get(table)
|
||||
r := c.Get(table)
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
@ -362,13 +311,13 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config,
|
||||
var proto uint8
|
||||
switch r.Proto {
|
||||
case "any":
|
||||
proto = fwProtoAny
|
||||
proto = firewall.ProtoAny
|
||||
case "tcp":
|
||||
proto = fwProtoTCP
|
||||
proto = firewall.ProtoTCP
|
||||
case "udp":
|
||||
proto = fwProtoUDP
|
||||
proto = firewall.ProtoUDP
|
||||
case "icmp":
|
||||
proto = fwProtoICMP
|
||||
proto = firewall.ProtoICMP
|
||||
default:
|
||||
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
|
||||
}
|
||||
@ -396,7 +345,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
||||
|
||||
// Drop returns an error if the packet should be dropped, explaining why. It
|
||||
// returns nil if the packet should not be dropped.
|
||||
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) error {
|
||||
func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error {
|
||||
// Check if we spoke to this tuple, if we did then allow this packet
|
||||
if f.inConns(packet, fp, incoming, h, caPool, localCache) {
|
||||
return nil
|
||||
@ -410,7 +359,7 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *Host
|
||||
}
|
||||
} else {
|
||||
// Simple case: Certificate has one IP and no subnets
|
||||
if fp.RemoteIP != h.hostId {
|
||||
if fp.RemoteIP != h.vpnIp {
|
||||
f.metrics(incoming).droppedRemoteIP.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
@ -462,7 +411,7 @@ func (f *Firewall) EmitStats() {
|
||||
metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
|
||||
}
|
||||
|
||||
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) bool {
|
||||
func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
|
||||
if localCache != nil {
|
||||
if _, ok := localCache[fp]; ok {
|
||||
return true
|
||||
@ -520,14 +469,14 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
|
||||
}
|
||||
|
||||
switch fp.Protocol {
|
||||
case fwProtoTCP:
|
||||
case firewall.ProtoTCP:
|
||||
c.Expires = time.Now().Add(f.TCPTimeout)
|
||||
if incoming {
|
||||
f.checkTCPRTT(c, packet)
|
||||
} else {
|
||||
setTCPRTTTracking(c, packet)
|
||||
}
|
||||
case fwProtoUDP:
|
||||
case firewall.ProtoUDP:
|
||||
c.Expires = time.Now().Add(f.UDPTimeout)
|
||||
default:
|
||||
c.Expires = time.Now().Add(f.DefaultTimeout)
|
||||
@ -542,17 +491,17 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
|
||||
func (f *Firewall) addConn(packet []byte, fp firewall.Packet, incoming bool) {
|
||||
var timeout time.Duration
|
||||
c := &conn{}
|
||||
|
||||
switch fp.Protocol {
|
||||
case fwProtoTCP:
|
||||
case firewall.ProtoTCP:
|
||||
timeout = f.TCPTimeout
|
||||
if !incoming {
|
||||
setTCPRTTTracking(c, packet)
|
||||
}
|
||||
case fwProtoUDP:
|
||||
case firewall.ProtoUDP:
|
||||
timeout = f.UDPTimeout
|
||||
default:
|
||||
timeout = f.DefaultTimeout
|
||||
@ -575,7 +524,7 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
|
||||
|
||||
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
|
||||
// Caller must own the connMutex lock!
|
||||
func (f *Firewall) evict(p FirewallPacket) {
|
||||
func (f *Firewall) evict(p firewall.Packet) {
|
||||
//TODO: report a stat if the tcp rtt tracking was never resolved?
|
||||
// Are we still tracking this conn?
|
||||
conntrack := f.Conntrack
|
||||
@ -596,21 +545,21 @@ func (f *Firewall) evict(p FirewallPacket) {
|
||||
delete(conntrack.Conns, p)
|
||||
}
|
||||
|
||||
func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
||||
func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
||||
if ft.AnyProto.match(p, incoming, c, caPool) {
|
||||
return true
|
||||
}
|
||||
|
||||
switch p.Protocol {
|
||||
case fwProtoTCP:
|
||||
case firewall.ProtoTCP:
|
||||
if ft.TCP.match(p, incoming, c, caPool) {
|
||||
return true
|
||||
}
|
||||
case fwProtoUDP:
|
||||
case firewall.ProtoUDP:
|
||||
if ft.UDP.match(p, incoming, c, caPool) {
|
||||
return true
|
||||
}
|
||||
case fwProtoICMP:
|
||||
case firewall.ProtoICMP:
|
||||
if ft.ICMP.match(p, incoming, c, caPool) {
|
||||
return true
|
||||
}
|
||||
@ -640,7 +589,7 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
||||
func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
||||
// We don't have any allowed ports, bail
|
||||
if fp == nil {
|
||||
return false
|
||||
@ -649,7 +598,7 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
|
||||
var port int32
|
||||
|
||||
if p.Fragment {
|
||||
port = fwPortFragment
|
||||
port = firewall.PortFragment
|
||||
} else if incoming {
|
||||
port = int32(p.LocalPort)
|
||||
} else {
|
||||
@ -660,7 +609,7 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
|
||||
return true
|
||||
}
|
||||
|
||||
return fp[fwPortAny].match(p, c, caPool)
|
||||
return fp[firewall.PortAny].match(p, c, caPool)
|
||||
}
|
||||
|
||||
func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
|
||||
@ -668,7 +617,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
|
||||
return &FirewallRule{
|
||||
Hosts: make(map[string]struct{}),
|
||||
Groups: make([][]string, 0),
|
||||
CIDR: NewCIDRTree(),
|
||||
CIDR: cidr.NewTree4(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -703,7 +652,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
||||
func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
||||
if fc == nil {
|
||||
return false
|
||||
}
|
||||
@ -736,7 +685,7 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) err
|
||||
// If it's any we need to wipe out any pre-existing rules to save on memory
|
||||
fr.Groups = make([][]string, 0)
|
||||
fr.Hosts = make(map[string]struct{})
|
||||
fr.CIDR = NewCIDRTree()
|
||||
fr.CIDR = cidr.NewTree4()
|
||||
} else {
|
||||
if len(groups) > 0 {
|
||||
fr.Groups = append(fr.Groups, groups)
|
||||
@ -776,7 +725,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
|
||||
return false
|
||||
}
|
||||
|
||||
func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate) bool {
|
||||
func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
|
||||
if fr == nil {
|
||||
return false
|
||||
}
|
||||
@ -885,12 +834,12 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
|
||||
|
||||
func parsePort(s string) (startPort, endPort int32, err error) {
|
||||
if s == "any" {
|
||||
startPort = fwPortAny
|
||||
endPort = fwPortAny
|
||||
startPort = firewall.PortAny
|
||||
endPort = firewall.PortAny
|
||||
|
||||
} else if s == "fragment" {
|
||||
startPort = fwPortFragment
|
||||
endPort = fwPortFragment
|
||||
startPort = firewall.PortFragment
|
||||
endPort = firewall.PortFragment
|
||||
|
||||
} else if strings.Contains(s, `-`) {
|
||||
sPorts := strings.SplitN(s, `-`, 2)
|
||||
@ -914,8 +863,8 @@ func parsePort(s string) (startPort, endPort int32, err error) {
|
||||
startPort = int32(rStartPort)
|
||||
endPort = int32(rEndPort)
|
||||
|
||||
if startPort == fwPortAny {
|
||||
endPort = fwPortAny
|
||||
if startPort == firewall.PortAny {
|
||||
endPort = firewall.PortAny
|
||||
}
|
||||
|
||||
} else {
|
||||
@ -968,54 +917,3 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
|
||||
c.Seq = 0
|
||||
return true
|
||||
}
|
||||
|
||||
// ConntrackCache is used as a local routine cache to know if a given flow
|
||||
// has been seen in the conntrack table.
|
||||
type ConntrackCache map[FirewallPacket]struct{}
|
||||
|
||||
type ConntrackCacheTicker struct {
|
||||
cacheV uint64
|
||||
cacheTick uint64
|
||||
|
||||
cache ConntrackCache
|
||||
}
|
||||
|
||||
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
|
||||
if d == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := &ConntrackCacheTicker{
|
||||
cache: ConntrackCache{},
|
||||
}
|
||||
|
||||
go c.tick(d)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *ConntrackCacheTicker) tick(d time.Duration) {
|
||||
for {
|
||||
time.Sleep(d)
|
||||
atomic.AddUint64(&c.cacheTick, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// Get checks if the cache ticker has moved to the next version before returning
|
||||
// the map. If it has moved, we reset the map.
|
||||
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
|
||||
c.cacheV = tick
|
||||
if ll := len(c.cache); ll > 0 {
|
||||
if l.Level == logrus.DebugLevel {
|
||||
l.WithField("len", ll).Debug("resetting conntrack cache")
|
||||
}
|
||||
c.cache = make(ConntrackCache, ll)
|
||||
}
|
||||
}
|
||||
|
||||
return c.cache
|
||||
}
|
||||
|
Reference in New Issue
Block a user