Rework some things into packages (#489)

This commit is contained in:
Nate Brown
2021-11-03 20:54:04 -05:00
committed by GitHub
parent 1f75fb3c73
commit bcabcfdaca
73 changed files with 2526 additions and 2374 deletions

View File

@ -12,6 +12,10 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
)
//const ProbeLen = 100
@ -28,10 +32,10 @@ type HostMap struct {
name string
Indexes map[uint32]*HostInfo
RemoteIndexes map[uint32]*HostInfo
Hosts map[uint32]*HostInfo
Hosts map[iputil.VpnIp]*HostInfo
preferredRanges []*net.IPNet
vpnCIDR *net.IPNet
unsafeRoutes *CIDRTree
unsafeRoutes *cidr.Tree4
metricsEnabled bool
l *logrus.Logger
}
@ -39,7 +43,7 @@ type HostMap struct {
type HostInfo struct {
sync.RWMutex
remote *udpAddr
remote *udp.Addr
remotes *RemoteList
promoteCounter uint32
ConnectionState *ConnectionState
@ -51,9 +55,9 @@ type HostInfo struct {
packetStore []*cachedPacket //todo: this is other handshake manager entry
remoteIndexId uint32
localIndexId uint32
hostId uint32
vpnIp iputil.VpnIp
recvError int
remoteCidr *CIDRTree
remoteCidr *cidr.Tree4
// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
@ -66,17 +70,17 @@ type HostInfo struct {
lastHandshakeTime uint64
lastRoam time.Time
lastRoamRemote *udpAddr
lastRoamRemote *udp.Addr
}
type cachedPacket struct {
messageType NebulaMessageType
messageSubType NebulaMessageSubType
messageType header.MessageType
messageSubType header.MessageSubType
callback packetCallback
packet []byte
}
type packetCallback func(t NebulaMessageType, st NebulaMessageSubType, h *HostInfo, p, nb, out []byte)
type packetCallback func(t header.MessageType, st header.MessageSubType, h *HostInfo, p, nb, out []byte)
type cachedPacketMetrics struct {
sent metrics.Counter
@ -84,7 +88,7 @@ type cachedPacketMetrics struct {
}
func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
h := map[uint32]*HostInfo{}
h := map[iputil.VpnIp]*HostInfo{}
i := map[uint32]*HostInfo{}
r := map[uint32]*HostInfo{}
m := HostMap{
@ -94,7 +98,7 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang
Hosts: h,
preferredRanges: preferredRanges,
vpnCIDR: vpnCIDR,
unsafeRoutes: NewCIDRTree(),
unsafeRoutes: cidr.NewTree4(),
l: l,
}
return &m
@ -113,9 +117,9 @@ func (hm *HostMap) EmitStats(name string) {
metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen))
}
func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) {
func (hm *HostMap) GetIndexByVpnIp(vpnIp iputil.VpnIp) (uint32, error) {
hm.RLock()
if i, ok := hm.Hosts[vpnIP]; ok {
if i, ok := hm.Hosts[vpnIp]; ok {
index := i.localIndexId
hm.RUnlock()
return index, nil
@ -124,43 +128,43 @@ func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) {
return 0, errors.New("vpn IP not found")
}
func (hm *HostMap) Add(ip uint32, hostinfo *HostInfo) {
func (hm *HostMap) Add(ip iputil.VpnIp, hostinfo *HostInfo) {
hm.Lock()
hm.Hosts[ip] = hostinfo
hm.Unlock()
}
func (hm *HostMap) AddVpnIP(vpnIP uint32) *HostInfo {
func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
h := &HostInfo{}
hm.RLock()
if _, ok := hm.Hosts[vpnIP]; !ok {
if _, ok := hm.Hosts[vpnIp]; !ok {
hm.RUnlock()
h = &HostInfo{
promoteCounter: 0,
hostId: vpnIP,
vpnIp: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0),
}
hm.Lock()
hm.Hosts[vpnIP] = h
hm.Hosts[vpnIp] = h
hm.Unlock()
return h
} else {
h = hm.Hosts[vpnIP]
h = hm.Hosts[vpnIp]
hm.RUnlock()
return h
}
}
func (hm *HostMap) DeleteVpnIP(vpnIP uint32) {
func (hm *HostMap) DeleteVpnIp(vpnIp iputil.VpnIp) {
hm.Lock()
delete(hm.Hosts, vpnIP)
delete(hm.Hosts, vpnIp)
if len(hm.Hosts) == 0 {
hm.Hosts = map[uint32]*HostInfo{}
hm.Hosts = map[iputil.VpnIp]*HostInfo{}
}
hm.Unlock()
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}).
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": vpnIp, "mapTotalSize": len(hm.Hosts)}).
Debug("Hostmap vpnIp deleted")
}
}
@ -174,22 +178,22 @@ func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
if hm.l.Level > logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": h.vpnIp}}).
Debug("Hostmap remoteIndex added")
}
}
func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) {
func (hm *HostMap) AddVpnIpHostInfo(vpnIp iputil.VpnIp, h *HostInfo) {
hm.Lock()
h.hostId = vpnIP
hm.Hosts[vpnIP] = h
h.vpnIp = vpnIp
hm.Hosts[vpnIp] = h
hm.Indexes[h.localIndexId] = h
hm.RemoteIndexes[h.remoteIndexId] = h
hm.Unlock()
if hm.l.Level > logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": vpnIp, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "vpnIp": h.vpnIp}}).
Debug("Hostmap vpnIp added")
}
}
@ -204,9 +208,9 @@ func (hm *HostMap) DeleteIndex(index uint32) {
// Check if we have an entry under hostId that matches the same hostinfo
// instance. Clean it up as well if we do.
hostinfo2, ok := hm.Hosts[hostinfo.hostId]
hostinfo2, ok := hm.Hosts[hostinfo.vpnIp]
if ok && hostinfo2 == hostinfo {
delete(hm.Hosts, hostinfo.hostId)
delete(hm.Hosts, hostinfo.vpnIp)
}
}
hm.Unlock()
@ -228,9 +232,9 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) {
// Check if we have an entry under hostId that matches the same hostinfo
// instance. Clean it up as well if we do (they might not match in pendingHostmap)
var hostinfo2 *HostInfo
hostinfo2, ok = hm.Hosts[hostinfo.hostId]
hostinfo2, ok = hm.Hosts[hostinfo.vpnIp]
if ok && hostinfo2 == hostinfo {
delete(hm.Hosts, hostinfo.hostId)
delete(hm.Hosts, hostinfo.vpnIp)
}
}
hm.Unlock()
@ -251,16 +255,16 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
// Check if this same hostId is in the hostmap with a different instance.
// This could happen if we have an entry in the pending hostmap with different
// index values than the one in the main hostmap.
hostinfo2, ok := hm.Hosts[hostinfo.hostId]
hostinfo2, ok := hm.Hosts[hostinfo.vpnIp]
if ok && hostinfo2 != hostinfo {
delete(hm.Hosts, hostinfo2.hostId)
delete(hm.Hosts, hostinfo2.vpnIp)
delete(hm.Indexes, hostinfo2.localIndexId)
delete(hm.RemoteIndexes, hostinfo2.remoteIndexId)
}
delete(hm.Hosts, hostinfo.hostId)
delete(hm.Hosts, hostinfo.vpnIp)
if len(hm.Hosts) == 0 {
hm.Hosts = map[uint32]*HostInfo{}
hm.Hosts = map[iputil.VpnIp]*HostInfo{}
}
delete(hm.Indexes, hostinfo.localIndexId)
if len(hm.Indexes) == 0 {
@ -273,7 +277,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
"vpnIp": IntIp(hostinfo.hostId), "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Hostmap hostInfo deleted")
}
}
@ -301,17 +305,17 @@ func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) {
}
}
func (hm *HostMap) QueryVpnIP(vpnIp uint32) (*HostInfo, error) {
return hm.queryVpnIP(vpnIp, nil)
func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) (*HostInfo, error) {
return hm.queryVpnIp(vpnIp, nil)
}
// PromoteBestQueryVpnIP will attempt to lazily switch to the best remote every
// PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every
// `PromoteEvery` calls to this function for a given host.
func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostInfo, error) {
return hm.queryVpnIP(vpnIp, ifce)
func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) (*HostInfo, error) {
return hm.queryVpnIp(vpnIp, ifce)
}
func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) {
func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*HostInfo, error) {
hm.RLock()
if h, ok := hm.Hosts[vpnIp]; ok {
hm.RUnlock()
@ -327,10 +331,10 @@ func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo,
return nil, errors.New("unable to find host")
}
func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 {
func (hm *HostMap) queryUnsafeRoute(ip iputil.VpnIp) iputil.VpnIp {
r := hm.unsafeRoutes.MostSpecificContains(ip)
if r != nil {
return r.(uint32)
return r.(iputil.VpnIp)
} else {
return 0
}
@ -344,13 +348,13 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
}
hm.Hosts[hostinfo.hostId] = hostinfo
hm.Hosts[hostinfo.vpnIp] = hostinfo
hm.Indexes[hostinfo.localIndexId] = hostinfo
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}).
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
Debug("Hostmap vpnIp added")
}
}
@ -370,7 +374,7 @@ func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList {
}
// Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them
func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) {
func (hm *HostMap) Punchy(ctx context.Context, conn *udp.Conn) {
var metricsTxPunchy metrics.Counter
if hm.metricsEnabled {
metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
@ -406,7 +410,7 @@ func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) {
func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
for _, r := range *routes {
hm.l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via))
hm.unsafeRoutes.AddCIDR(r.route, iputil.Ip2VpnIp(*r.via))
}
}
@ -431,24 +435,24 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
}
}
i.remotes.ForEach(preferredRanges, func(addr *udpAddr, preferred bool) {
i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) {
if addr == nil || !preferred {
return
}
// Try to send a test packet to that host, this should
// cause it to detect a roaming event and switch remotes
ifce.send(test, testRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
ifce.send(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
})
}
// Re query our lighthouses for new remotes occasionally
if c%ReQueryEvery == 0 && ifce.lightHouse != nil {
ifce.lightHouse.QueryServer(i.hostId, ifce)
ifce.lightHouse.QueryServer(i.vpnIp, ifce)
}
}
func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
func (i *HostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
//TODO: return the error so we can log with more context
if len(i.packetStore) < 100 {
tempPacket := make([]byte, len(packet))
@ -510,17 +514,17 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate {
return nil
}
func (i *HostInfo) SetRemote(remote *udpAddr) {
func (i *HostInfo) SetRemote(remote *udp.Addr) {
// We copy here because we likely got this remote from a source that reuses the object
if !i.remote.Equals(remote) {
i.remote = remote.Copy()
i.remotes.LearnRemote(i.hostId, remote.Copy())
i.remotes.LearnRemote(i.vpnIp, remote.Copy())
}
}
// SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
// time on the HostInfo will also be updated.
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udpAddr) bool {
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
currentRemote := i.remote
if currentRemote == nil {
i.SetRemote(newRemote)
@ -572,7 +576,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
return
}
remoteCidr := NewCIDRTree()
remoteCidr := cidr.NewTree4()
for _, ip := range c.Details.Ips {
remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
}
@ -588,8 +592,7 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
return logrus.NewEntry(l)
}
li := l.WithField("vpnIp", IntIp(i.hostId))
li := l.WithField("vpnIp", i.vpnIp)
if connState := i.ConnectionState; connState != nil {
if peerCert := connState.peerCert; peerCert != nil {
li = li.WithField("certName", peerCert.Details.Name)
@ -599,38 +602,6 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
return li
}
//########################
/*
func (hm *HostMap) DebugRemotes(vpnIp uint32) string {
s := "\n"
for _, h := range hm.Hosts {
for _, r := range h.Remotes {
s += fmt.Sprintf("%s : %d ## %v\n", r.addr.IP.String(), r.addr.Port, r.probes)
}
}
return s
}
func (i *HostInfo) HandleReply(addr *net.UDPAddr, counter int) {
for _, r := range i.Remotes {
if r.addr.IP.Equal(addr.IP) && r.addr.Port == addr.Port {
r.ProbeReceived(counter)
}
}
}
func (i *HostInfo) Probes() []*Probe {
p := []*Probe{}
for _, d := range i.Remotes {
p = append(p, &Probe{Addr: d.addr, Counter: d.Probe()})
}
return p
}
*/
// Utility functions
func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {