Rework some things into packages (#489)
This commit is contained in:
157
hostmap.go
157
hostmap.go
@ -12,6 +12,10 @@ import (
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
||||
//const ProbeLen = 100
|
||||
@ -28,10 +32,10 @@ type HostMap struct {
|
||||
name string
|
||||
Indexes map[uint32]*HostInfo
|
||||
RemoteIndexes map[uint32]*HostInfo
|
||||
Hosts map[uint32]*HostInfo
|
||||
Hosts map[iputil.VpnIp]*HostInfo
|
||||
preferredRanges []*net.IPNet
|
||||
vpnCIDR *net.IPNet
|
||||
unsafeRoutes *CIDRTree
|
||||
unsafeRoutes *cidr.Tree4
|
||||
metricsEnabled bool
|
||||
l *logrus.Logger
|
||||
}
|
||||
@ -39,7 +43,7 @@ type HostMap struct {
|
||||
type HostInfo struct {
|
||||
sync.RWMutex
|
||||
|
||||
remote *udpAddr
|
||||
remote *udp.Addr
|
||||
remotes *RemoteList
|
||||
promoteCounter uint32
|
||||
ConnectionState *ConnectionState
|
||||
@ -51,9 +55,9 @@ type HostInfo struct {
|
||||
packetStore []*cachedPacket //todo: this is other handshake manager entry
|
||||
remoteIndexId uint32
|
||||
localIndexId uint32
|
||||
hostId uint32
|
||||
vpnIp iputil.VpnIp
|
||||
recvError int
|
||||
remoteCidr *CIDRTree
|
||||
remoteCidr *cidr.Tree4
|
||||
|
||||
// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
|
||||
// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
|
||||
@ -66,17 +70,17 @@ type HostInfo struct {
|
||||
lastHandshakeTime uint64
|
||||
|
||||
lastRoam time.Time
|
||||
lastRoamRemote *udpAddr
|
||||
lastRoamRemote *udp.Addr
|
||||
}
|
||||
|
||||
type cachedPacket struct {
|
||||
messageType NebulaMessageType
|
||||
messageSubType NebulaMessageSubType
|
||||
messageType header.MessageType
|
||||
messageSubType header.MessageSubType
|
||||
callback packetCallback
|
||||
packet []byte
|
||||
}
|
||||
|
||||
type packetCallback func(t NebulaMessageType, st NebulaMessageSubType, h *HostInfo, p, nb, out []byte)
|
||||
type packetCallback func(t header.MessageType, st header.MessageSubType, h *HostInfo, p, nb, out []byte)
|
||||
|
||||
type cachedPacketMetrics struct {
|
||||
sent metrics.Counter
|
||||
@ -84,7 +88,7 @@ type cachedPacketMetrics struct {
|
||||
}
|
||||
|
||||
func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
|
||||
h := map[uint32]*HostInfo{}
|
||||
h := map[iputil.VpnIp]*HostInfo{}
|
||||
i := map[uint32]*HostInfo{}
|
||||
r := map[uint32]*HostInfo{}
|
||||
m := HostMap{
|
||||
@ -94,7 +98,7 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang
|
||||
Hosts: h,
|
||||
preferredRanges: preferredRanges,
|
||||
vpnCIDR: vpnCIDR,
|
||||
unsafeRoutes: NewCIDRTree(),
|
||||
unsafeRoutes: cidr.NewTree4(),
|
||||
l: l,
|
||||
}
|
||||
return &m
|
||||
@ -113,9 +117,9 @@ func (hm *HostMap) EmitStats(name string) {
|
||||
metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen))
|
||||
}
|
||||
|
||||
func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) {
|
||||
func (hm *HostMap) GetIndexByVpnIp(vpnIp iputil.VpnIp) (uint32, error) {
|
||||
hm.RLock()
|
||||
if i, ok := hm.Hosts[vpnIP]; ok {
|
||||
if i, ok := hm.Hosts[vpnIp]; ok {
|
||||
index := i.localIndexId
|
||||
hm.RUnlock()
|
||||
return index, nil
|
||||
@ -124,43 +128,43 @@ func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) {
|
||||
return 0, errors.New("vpn IP not found")
|
||||
}
|
||||
|
||||
func (hm *HostMap) Add(ip uint32, hostinfo *HostInfo) {
|
||||
func (hm *HostMap) Add(ip iputil.VpnIp, hostinfo *HostInfo) {
|
||||
hm.Lock()
|
||||
hm.Hosts[ip] = hostinfo
|
||||
hm.Unlock()
|
||||
}
|
||||
|
||||
func (hm *HostMap) AddVpnIP(vpnIP uint32) *HostInfo {
|
||||
func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
|
||||
h := &HostInfo{}
|
||||
hm.RLock()
|
||||
if _, ok := hm.Hosts[vpnIP]; !ok {
|
||||
if _, ok := hm.Hosts[vpnIp]; !ok {
|
||||
hm.RUnlock()
|
||||
h = &HostInfo{
|
||||
promoteCounter: 0,
|
||||
hostId: vpnIP,
|
||||
vpnIp: vpnIp,
|
||||
HandshakePacket: make(map[uint8][]byte, 0),
|
||||
}
|
||||
hm.Lock()
|
||||
hm.Hosts[vpnIP] = h
|
||||
hm.Hosts[vpnIp] = h
|
||||
hm.Unlock()
|
||||
return h
|
||||
} else {
|
||||
h = hm.Hosts[vpnIP]
|
||||
h = hm.Hosts[vpnIp]
|
||||
hm.RUnlock()
|
||||
return h
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HostMap) DeleteVpnIP(vpnIP uint32) {
|
||||
func (hm *HostMap) DeleteVpnIp(vpnIp iputil.VpnIp) {
|
||||
hm.Lock()
|
||||
delete(hm.Hosts, vpnIP)
|
||||
delete(hm.Hosts, vpnIp)
|
||||
if len(hm.Hosts) == 0 {
|
||||
hm.Hosts = map[uint32]*HostInfo{}
|
||||
hm.Hosts = map[iputil.VpnIp]*HostInfo{}
|
||||
}
|
||||
hm.Unlock()
|
||||
|
||||
if hm.l.Level >= logrus.DebugLevel {
|
||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}).
|
||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": vpnIp, "mapTotalSize": len(hm.Hosts)}).
|
||||
Debug("Hostmap vpnIp deleted")
|
||||
}
|
||||
}
|
||||
@ -174,22 +178,22 @@ func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
|
||||
|
||||
if hm.l.Level > logrus.DebugLevel {
|
||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
|
||||
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
|
||||
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": h.vpnIp}}).
|
||||
Debug("Hostmap remoteIndex added")
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) {
|
||||
func (hm *HostMap) AddVpnIpHostInfo(vpnIp iputil.VpnIp, h *HostInfo) {
|
||||
hm.Lock()
|
||||
h.hostId = vpnIP
|
||||
hm.Hosts[vpnIP] = h
|
||||
h.vpnIp = vpnIp
|
||||
hm.Hosts[vpnIp] = h
|
||||
hm.Indexes[h.localIndexId] = h
|
||||
hm.RemoteIndexes[h.remoteIndexId] = h
|
||||
hm.Unlock()
|
||||
|
||||
if hm.l.Level > logrus.DebugLevel {
|
||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts),
|
||||
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
|
||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": vpnIp, "mapTotalSize": len(hm.Hosts),
|
||||
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "vpnIp": h.vpnIp}}).
|
||||
Debug("Hostmap vpnIp added")
|
||||
}
|
||||
}
|
||||
@ -204,9 +208,9 @@ func (hm *HostMap) DeleteIndex(index uint32) {
|
||||
|
||||
// Check if we have an entry under hostId that matches the same hostinfo
|
||||
// instance. Clean it up as well if we do.
|
||||
hostinfo2, ok := hm.Hosts[hostinfo.hostId]
|
||||
hostinfo2, ok := hm.Hosts[hostinfo.vpnIp]
|
||||
if ok && hostinfo2 == hostinfo {
|
||||
delete(hm.Hosts, hostinfo.hostId)
|
||||
delete(hm.Hosts, hostinfo.vpnIp)
|
||||
}
|
||||
}
|
||||
hm.Unlock()
|
||||
@ -228,9 +232,9 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) {
|
||||
// Check if we have an entry under hostId that matches the same hostinfo
|
||||
// instance. Clean it up as well if we do (they might not match in pendingHostmap)
|
||||
var hostinfo2 *HostInfo
|
||||
hostinfo2, ok = hm.Hosts[hostinfo.hostId]
|
||||
hostinfo2, ok = hm.Hosts[hostinfo.vpnIp]
|
||||
if ok && hostinfo2 == hostinfo {
|
||||
delete(hm.Hosts, hostinfo.hostId)
|
||||
delete(hm.Hosts, hostinfo.vpnIp)
|
||||
}
|
||||
}
|
||||
hm.Unlock()
|
||||
@ -251,16 +255,16 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
|
||||
// Check if this same hostId is in the hostmap with a different instance.
|
||||
// This could happen if we have an entry in the pending hostmap with different
|
||||
// index values than the one in the main hostmap.
|
||||
hostinfo2, ok := hm.Hosts[hostinfo.hostId]
|
||||
hostinfo2, ok := hm.Hosts[hostinfo.vpnIp]
|
||||
if ok && hostinfo2 != hostinfo {
|
||||
delete(hm.Hosts, hostinfo2.hostId)
|
||||
delete(hm.Hosts, hostinfo2.vpnIp)
|
||||
delete(hm.Indexes, hostinfo2.localIndexId)
|
||||
delete(hm.RemoteIndexes, hostinfo2.remoteIndexId)
|
||||
}
|
||||
|
||||
delete(hm.Hosts, hostinfo.hostId)
|
||||
delete(hm.Hosts, hostinfo.vpnIp)
|
||||
if len(hm.Hosts) == 0 {
|
||||
hm.Hosts = map[uint32]*HostInfo{}
|
||||
hm.Hosts = map[iputil.VpnIp]*HostInfo{}
|
||||
}
|
||||
delete(hm.Indexes, hostinfo.localIndexId)
|
||||
if len(hm.Indexes) == 0 {
|
||||
@ -273,7 +277,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
|
||||
|
||||
if hm.l.Level >= logrus.DebugLevel {
|
||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
|
||||
"vpnIp": IntIp(hostinfo.hostId), "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
||||
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
||||
Debug("Hostmap hostInfo deleted")
|
||||
}
|
||||
}
|
||||
@ -301,17 +305,17 @@ func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HostMap) QueryVpnIP(vpnIp uint32) (*HostInfo, error) {
|
||||
return hm.queryVpnIP(vpnIp, nil)
|
||||
func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) (*HostInfo, error) {
|
||||
return hm.queryVpnIp(vpnIp, nil)
|
||||
}
|
||||
|
||||
// PromoteBestQueryVpnIP will attempt to lazily switch to the best remote every
|
||||
// PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every
|
||||
// `PromoteEvery` calls to this function for a given host.
|
||||
func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostInfo, error) {
|
||||
return hm.queryVpnIP(vpnIp, ifce)
|
||||
func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) (*HostInfo, error) {
|
||||
return hm.queryVpnIp(vpnIp, ifce)
|
||||
}
|
||||
|
||||
func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) {
|
||||
func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*HostInfo, error) {
|
||||
hm.RLock()
|
||||
if h, ok := hm.Hosts[vpnIp]; ok {
|
||||
hm.RUnlock()
|
||||
@ -327,10 +331,10 @@ func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo,
|
||||
return nil, errors.New("unable to find host")
|
||||
}
|
||||
|
||||
func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 {
|
||||
func (hm *HostMap) queryUnsafeRoute(ip iputil.VpnIp) iputil.VpnIp {
|
||||
r := hm.unsafeRoutes.MostSpecificContains(ip)
|
||||
if r != nil {
|
||||
return r.(uint32)
|
||||
return r.(iputil.VpnIp)
|
||||
} else {
|
||||
return 0
|
||||
}
|
||||
@ -344,13 +348,13 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
|
||||
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
|
||||
}
|
||||
|
||||
hm.Hosts[hostinfo.hostId] = hostinfo
|
||||
hm.Hosts[hostinfo.vpnIp] = hostinfo
|
||||
hm.Indexes[hostinfo.localIndexId] = hostinfo
|
||||
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
|
||||
|
||||
if hm.l.Level >= logrus.DebugLevel {
|
||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts),
|
||||
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}).
|
||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
|
||||
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
|
||||
Debug("Hostmap vpnIp added")
|
||||
}
|
||||
}
|
||||
@ -370,7 +374,7 @@ func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList {
|
||||
}
|
||||
|
||||
// Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them
|
||||
func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) {
|
||||
func (hm *HostMap) Punchy(ctx context.Context, conn *udp.Conn) {
|
||||
var metricsTxPunchy metrics.Counter
|
||||
if hm.metricsEnabled {
|
||||
metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
|
||||
@ -406,7 +410,7 @@ func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) {
|
||||
func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
|
||||
for _, r := range *routes {
|
||||
hm.l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
|
||||
hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via))
|
||||
hm.unsafeRoutes.AddCIDR(r.route, iputil.Ip2VpnIp(*r.via))
|
||||
}
|
||||
}
|
||||
|
||||
@ -431,24 +435,24 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
|
||||
}
|
||||
}
|
||||
|
||||
i.remotes.ForEach(preferredRanges, func(addr *udpAddr, preferred bool) {
|
||||
i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) {
|
||||
if addr == nil || !preferred {
|
||||
return
|
||||
}
|
||||
|
||||
// Try to send a test packet to that host, this should
|
||||
// cause it to detect a roaming event and switch remotes
|
||||
ifce.send(test, testRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
ifce.send(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
})
|
||||
}
|
||||
|
||||
// Re query our lighthouses for new remotes occasionally
|
||||
if c%ReQueryEvery == 0 && ifce.lightHouse != nil {
|
||||
ifce.lightHouse.QueryServer(i.hostId, ifce)
|
||||
ifce.lightHouse.QueryServer(i.vpnIp, ifce)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
|
||||
func (i *HostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
|
||||
//TODO: return the error so we can log with more context
|
||||
if len(i.packetStore) < 100 {
|
||||
tempPacket := make([]byte, len(packet))
|
||||
@ -510,17 +514,17 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *HostInfo) SetRemote(remote *udpAddr) {
|
||||
func (i *HostInfo) SetRemote(remote *udp.Addr) {
|
||||
// We copy here because we likely got this remote from a source that reuses the object
|
||||
if !i.remote.Equals(remote) {
|
||||
i.remote = remote.Copy()
|
||||
i.remotes.LearnRemote(i.hostId, remote.Copy())
|
||||
i.remotes.LearnRemote(i.vpnIp, remote.Copy())
|
||||
}
|
||||
}
|
||||
|
||||
// SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
|
||||
// time on the HostInfo will also be updated.
|
||||
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udpAddr) bool {
|
||||
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
|
||||
currentRemote := i.remote
|
||||
if currentRemote == nil {
|
||||
i.SetRemote(newRemote)
|
||||
@ -572,7 +576,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
|
||||
return
|
||||
}
|
||||
|
||||
remoteCidr := NewCIDRTree()
|
||||
remoteCidr := cidr.NewTree4()
|
||||
for _, ip := range c.Details.Ips {
|
||||
remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
|
||||
}
|
||||
@ -588,8 +592,7 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
||||
return logrus.NewEntry(l)
|
||||
}
|
||||
|
||||
li := l.WithField("vpnIp", IntIp(i.hostId))
|
||||
|
||||
li := l.WithField("vpnIp", i.vpnIp)
|
||||
if connState := i.ConnectionState; connState != nil {
|
||||
if peerCert := connState.peerCert; peerCert != nil {
|
||||
li = li.WithField("certName", peerCert.Details.Name)
|
||||
@ -599,38 +602,6 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
||||
return li
|
||||
}
|
||||
|
||||
//########################
|
||||
|
||||
/*
|
||||
|
||||
func (hm *HostMap) DebugRemotes(vpnIp uint32) string {
|
||||
s := "\n"
|
||||
for _, h := range hm.Hosts {
|
||||
for _, r := range h.Remotes {
|
||||
s += fmt.Sprintf("%s : %d ## %v\n", r.addr.IP.String(), r.addr.Port, r.probes)
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (i *HostInfo) HandleReply(addr *net.UDPAddr, counter int) {
|
||||
for _, r := range i.Remotes {
|
||||
if r.addr.IP.Equal(addr.IP) && r.addr.Port == addr.Port {
|
||||
r.ProbeReceived(counter)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *HostInfo) Probes() []*Probe {
|
||||
p := []*Probe{}
|
||||
for _, d := range i.Remotes {
|
||||
p = append(p, &Probe{Addr: d.addr, Counter: d.Probe()})
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
*/
|
||||
|
||||
// Utility functions
|
||||
|
||||
func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {
|
||||
|
Reference in New Issue
Block a user