Refactor remotes and handshaking to give every address a fair shot (#437)

This commit is contained in:
Nathan Brown 2021-04-14 13:50:09 -05:00 committed by GitHub
parent 20bef975cd
commit 710df6a876
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 1561 additions and 1385 deletions

View File

@ -67,23 +67,11 @@ func (c *Control) RebindUDPServer() {
// ListHostmap returns details about the actual or pending (handshaking) hostmap // ListHostmap returns details about the actual or pending (handshaking) hostmap
func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo { func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
var hm *HostMap
if pendingMap { if pendingMap {
hm = c.f.handshakeManager.pendingHostMap return listHostMap(c.f.handshakeManager.pendingHostMap)
} else { } else {
hm = c.f.hostMap return listHostMap(c.f.hostMap)
} }
hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Hosts))
i := 0
for _, v := range hm.Hosts {
hosts[i] = copyHostInfo(v)
i++
}
hm.RUnlock()
return hosts
} }
// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found // GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found
@ -100,7 +88,7 @@ func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInf
return nil return nil
} }
ch := copyHostInfo(h) ch := copyHostInfo(h, c.f.hostMap.preferredRanges)
return &ch return &ch
} }
@ -112,7 +100,7 @@ func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInf
} }
hostInfo.SetRemote(addr.Copy()) hostInfo.SetRemote(addr.Copy())
ch := copyHostInfo(hostInfo) ch := copyHostInfo(hostInfo, c.f.hostMap.preferredRanges)
return &ch return &ch
} }
@ -163,14 +151,17 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
return return
} }
func copyHostInfo(h *HostInfo) ControlHostInfo { func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
chi := ControlHostInfo{ chi := ControlHostInfo{
VpnIP: int2ip(h.hostId), VpnIP: int2ip(h.hostId),
LocalIndex: h.localIndexId, LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId, RemoteIndex: h.remoteIndexId,
RemoteAddrs: h.CopyRemotes(), RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),
CachedPackets: len(h.packetStore), CachedPackets: len(h.packetStore),
MessageCounter: atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter), }
if h.ConnectionState != nil {
chi.MessageCounter = atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter)
} }
if c := h.GetCert(); c != nil { if c := h.GetCert(); c != nil {
@ -183,3 +174,16 @@ func copyHostInfo(h *HostInfo) ControlHostInfo {
return chi return chi
} }
func listHostMap(hm *HostMap) []ControlHostInfo {
hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Hosts))
i := 0
for _, v := range hm.Hosts {
hosts[i] = copyHostInfo(v, hm.preferredRanges)
i++
}
hm.RUnlock()
return hosts
}

View File

@ -45,10 +45,12 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
Signature: []byte{1, 2, 1, 2, 1, 3}, Signature: []byte{1, 2, 1, 2, 1, 3},
} }
remotes := []*udpAddr{remote1, remote2} remotes := NewRemoteList()
remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
hm.Add(ip2int(ipNet.IP), &HostInfo{ hm.Add(ip2int(ipNet.IP), &HostInfo{
remote: remote1, remote: remote1,
Remotes: remotes, remotes: remotes,
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: crt, peerCert: crt,
}, },
@ -59,7 +61,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
hm.Add(ip2int(ipNet2.IP), &HostInfo{ hm.Add(ip2int(ipNet2.IP), &HostInfo{
remote: remote1, remote: remote1,
Remotes: remotes, remotes: remotes,
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: nil, peerCert: nil,
}, },
@ -81,7 +83,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
VpnIP: net.IPv4(1, 2, 3, 4).To4(), VpnIP: net.IPv4(1, 2, 3, 4).To4(),
LocalIndex: 201, LocalIndex: 201,
RemoteIndex: 200, RemoteIndex: 200,
RemoteAddrs: []*udpAddr{remote1, remote2}, RemoteAddrs: []*udpAddr{remote2, remote1},
CachedPackets: 0, CachedPackets: 0,
Cert: crt.Copy(), Cert: crt.Copy(),
MessageCounter: 0, MessageCounter: 0,

View File

@ -44,7 +44,18 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType NebulaMessageType,
// InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp // InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp
// This is necessary if you did not configure static hosts or are not running a lighthouse // This is necessary if you did not configure static hosts or are not running a lighthouse
func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) { func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
c.f.lightHouse.AddRemote(ip2int(vpnIp), &udpAddr{IP: toAddr.IP, Port: uint16(toAddr.Port)}, false) c.f.lightHouse.Lock()
remoteList := c.f.lightHouse.unlockedGetRemoteList(ip2int(vpnIp))
remoteList.Lock()
defer remoteList.Unlock()
c.f.lightHouse.Unlock()
iVpnIp := ip2int(vpnIp)
if v4 := toAddr.IP.To4(); v4 != nil {
remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
} else {
remoteList.unlockedPrependV6(iVpnIp, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)))
}
} }
// GetFromTun will pull a packet off the tun side of nebula // GetFromTun will pull a packet off the tun side of nebula
@ -84,14 +95,17 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
SrcPort: layers.UDPPort(fromPort), SrcPort: layers.UDPPort(fromPort),
DstPort: layers.UDPPort(toPort), DstPort: layers.UDPPort(toPort),
} }
udp.SetNetworkLayerForChecksum(&ip) err := udp.SetNetworkLayerForChecksum(&ip)
if err != nil {
panic(err)
}
buffer := gopacket.NewSerializeBuffer() buffer := gopacket.NewSerializeBuffer()
opt := gopacket.SerializeOptions{ opt := gopacket.SerializeOptions{
ComputeChecksums: true, ComputeChecksums: true,
FixLengths: true, FixLengths: true,
} }
err := gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data)) err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -102,3 +116,13 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
func (c *Control) GetUDPAddr() string { func (c *Control) GetUDPAddr() string {
return c.f.outside.addr.String() return c.f.outside.addr.String()
} }
func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[ip2int(vpnIp)]
if !ok {
return false
}
c.f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
return true
}

View File

@ -9,6 +9,7 @@ import (
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert"
) )
func TestGoodHandshake(t *testing.T) { func TestGoodHandshake(t *testing.T) {
@ -23,35 +24,35 @@ func TestGoodHandshake(t *testing.T) {
myControl.Start() myControl.Start()
theirControl.Start() theirControl.Start()
// Send a udp packet through to begin standing up the tunnel, this should come out the other side t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
// Have them consume my stage 0 packet. They have a tunnel now t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
// Get their stage 1 packet so that we can play with it t.Log("Get their stage 1 packet so that we can play with it")
stage1Packet := theirControl.GetFromUDP(true) stage1Packet := theirControl.GetFromUDP(true)
// I consume a garbage packet with a proper nebula header for our tunnel t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel // this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
badPacket := stage1Packet.Copy() badPacket := stage1Packet.Copy()
badPacket.Data = badPacket.Data[:len(badPacket.Data)-nebula.HeaderLen] badPacket.Data = badPacket.Data[:len(badPacket.Data)-nebula.HeaderLen]
myControl.InjectUDPPacket(badPacket) myControl.InjectUDPPacket(badPacket)
// Have me consume their real stage 1 packet. I have a tunnel now t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
myControl.InjectUDPPacket(stage1Packet) myControl.InjectUDPPacket(stage1Packet)
// Wait until we see my cached packet come through t.Log("Wait until we see my cached packet come through")
myControl.WaitForType(1, 0, theirControl) myControl.WaitForType(1, 0, theirControl)
// Make sure our host infos are correct t.Log("Make sure our host infos are correct")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl) assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
// Get that cached packet and make sure it looks right t.Log("Get that cached packet and make sure it looks right")
myCachedPacket := theirControl.GetFromTun(true) myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80) assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
// Do a bidirectional tunnel test t.Log("Do a bidirectional tunnel test")
assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, router.NewR(myControl, theirControl)) assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, router.NewR(myControl, theirControl))
myControl.Stop() myControl.Stop()
@ -62,14 +63,17 @@ func TestGoodHandshake(t *testing.T) {
func TestWrongResponderHandshake(t *testing.T) { func TestWrongResponderHandshake(t *testing.T) {
ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}) // The IPs here are chosen on purpose:
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}) // The current remote handling will sort by preference, public, and then lexically.
evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 99}) // So we need them to have a higher address than evil (we could apply a preference though)
myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100})
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99})
evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2})
// Add their real udp addr, which should be tried after evil. Doing this first because learned addresses are prepended // Add their real udp addr, which should be tried after evil.
myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. This will now be the first attempted ip // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
myControl.InjectLightHouseAddr(theirVpnIp, evilUdpAddr) myControl.InjectLightHouseAddr(theirVpnIp, evilUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
@ -80,137 +84,98 @@ func TestWrongResponderHandshake(t *testing.T) {
theirControl.Start() theirControl.Start()
evilControl.Start() evilControl.Start()
t.Log("Stand up the tunnel with evil (because the lighthouse cache is lying to us about who it is)") t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
r.OnceFrom(myControl) r.RouteForAllExitFunc(func(p *nebula.UdpPacket, c *nebula.Control) router.ExitType {
r.OnceFrom(evilControl) h := &nebula.Header{}
err := h.Parse(p.Data)
if err != nil {
panic(err)
}
t.Log("I should have a tunnel with evil now and there should not be a cached packet waiting for us") if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 {
assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r) return router.RouteAndExit
assertHostInfoPair(t, myUdpAddr, evilUdpAddr, myVpnIp, evilVpnIp, myControl, evilControl) }
return router.KeepRouting
})
//TODO: Assert pending hostmap - I should have a correct hostinfo for them now //TODO: Assert pending hostmap - I should have a correct hostinfo for them now
t.Log("Lets let the messages fly, this time we should have a tunnel with them") t.Log("My cached packet should be received by them")
r.OnceFrom(myControl)
r.OnceFrom(theirControl)
t.Log("I should now have a tunnel with them now and my original packet should get there")
r.RouteUntilAfterMsgType(myControl, 1, 0)
myCachedPacket := theirControl.GetFromTun(true) myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80) assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
t.Log("I should now have a proper tunnel with them") t.Log("Test the tunnel with them")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl) assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r) assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
t.Log("Lets make sure evil is still good") t.Log("Flush all packets from all controllers")
assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r) r.FlushAll()
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), true), "My pending hostmap should not contain evil")
assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), false), "My main hostmap should not contain evil")
//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
//TODO: assert hostmaps for everyone //TODO: assert hostmaps for everyone
t.Log("Success!") t.Log("Success!")
//TODO: myControl is attempting to shut down 2 tunnels but is blocked on the udp txChan after the first close message myControl.Stop()
// what we really need here is a way to exit all the go routines loops (there are many) theirControl.Stop()
//myControl.Stop()
//theirControl.Stop()
} }
////TODO: We need to test lies both as the race winner and race loser func Test_Case1_Stage1Race(t *testing.T) {
//func TestManyWrongResponderHandshake(t *testing.T) { ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
// ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1})
// theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2})
// myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 99})
// theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}) // Put their info in our lighthouse and vice versa
// evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 1}) myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
// theirControl.InjectLightHouseAddr(myVpnIp, myUdpAddr)
// t.Log("Build a router so we don't have to reason who gets which packet")
// r := newRouter(myControl, theirControl, evilControl) // Build a router so we don't have to reason who gets which packet
// r := router.NewR(myControl, theirControl)
// t.Log("Lets add more than 10 evil addresses, this exceeds the hostinfo remotes limit")
// for i := 0; i < 10; i++ { // Start the servers
// addr := net.UDPAddr{IP: evilUdpAddr.IP, Port: evilUdpAddr.Port + i} myControl.Start()
// myControl.InjectLightHouseAddr(theirVpnIp, &addr) theirControl.Start()
// // We also need to tell our router about it
// r.AddRoute(addr.IP, uint16(addr.Port), evilControl) t.Log("Trigger a handshake to start on both me and them")
// } myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
// theirControl.InjectTunUDPPacket(myVpnIp, 80, 80, []byte("Hi from them"))
// // Start the servers
// myControl.Start() t.Log("Get both stage 1 handshake packets")
// theirControl.Start() myHsForThem := myControl.GetFromUDP(true)
// evilControl.Start() theirHsForMe := theirControl.GetFromUDP(true)
//
// t.Log("Stand up the tunnel with evil (because the lighthouse cache is lying to us about who it is)") t.Log("Now inject both stage 1 handshake packets")
// myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) myControl.InjectUDPPacket(theirHsForMe)
// theirControl.InjectUDPPacket(myHsForThem)
// t.Log("We need to spin until we get to the right remote for them") //TODO: they should win, grab their index for me and make sure I use it in the end.
// getOut := false
// injected := false t.Log("They should not have a stage 2 (won the race) but I should send one")
// for { theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
// t.Log("Routing for me and evil while we work through the bad ips")
// r.RouteExitFunc(myControl, func(packet *nebula.UdpPacket, receiver *nebula.Control) exitType { t.Log("Route for me until I send a message packet to them")
// // We should stop routing right after we see a packet coming from us to them myControl.WaitForType(1, 0, theirControl)
// if *receiver == *theirControl {
// getOut = true t.Log("My cached packet should be received by them")
// return drainAndExit myCachedPacket := theirControl.GetFromTun(true)
// } assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
//
// // We need to poke our real ip in at some point, this is a well protected check looking for that moment t.Log("Route for them until I send a message packet to me")
// if *receiver == *evilControl { theirControl.WaitForType(1, 0, myControl)
// hi := myControl.GetHostInfoByVpnIP(ip2int(theirVpnIp), true)
// if !injected && len(hi.RemoteAddrs) == 1 { t.Log("Their cached packet should be received by me")
// t.Log("I am on my last ip for them, time to inject the real one into my lighthouse") theirCachedPacket := myControl.GetFromTun(true)
// myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIp, myVpnIp, 80, 80)
// injected = true
// } t.Log("Do a bidirectional tunnel test")
// return drainAndExit assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
// }
// myControl.Stop()
// return keepRouting theirControl.Stop()
// }) //TODO: assert hostmaps
// }
// if getOut {
// break //TODO: add a test with many lies
// }
//
// r.RouteForUntilAfterToAddr(evilControl, myUdpAddr, drainAndExit)
// }
//
// t.Log("I should have a tunnel with evil and them, evil should not have a cached packet")
// assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r)
// evilHostInfo := myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), false)
// realEvilUdpAddr := &net.UDPAddr{IP: evilHostInfo.CurrentRemote.IP, Port: int(evilHostInfo.CurrentRemote.Port)}
//
// t.Log("Assert mine and evil's host pairs", evilUdpAddr, realEvilUdpAddr)
// assertHostInfoPair(t, myUdpAddr, realEvilUdpAddr, myVpnIp, evilVpnIp, myControl, evilControl)
//
// //t.Log("Draining everyones packets")
// //r.Drain(theirControl)
// //r.DrainAll(myControl, theirControl, evilControl)
// //
// //go func() {
// // for {
// // time.Sleep(10 * time.Millisecond)
// // t.Log(len(theirControl.GetUDPTxChan()))
// // t.Log(len(theirControl.GetTunTxChan()))
// // t.Log(len(myControl.GetUDPTxChan()))
// // t.Log(len(evilControl.GetUDPTxChan()))
// // t.Log("=====")
// // }
// //}()
//
// t.Log("I should have a tunnel with them now and my original packet should get there")
// r.RouteUntilAfterMsgType(myControl, 1, 0)
// myCachedPacket := theirControl.GetFromTun(true)
//
// t.Log("Got the cached packet, lets test the tunnel")
// assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
//
// t.Log("Testing tunnels with them")
// assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
// assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
//
// t.Log("Testing tunnels with evil")
// assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r)
//
// //TODO: assert hostmaps for everyone
//}

View File

@ -64,6 +64,9 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
"host": "any", "host": "any",
}}, }},
}, },
//"handshakes": m{
// "try_interval": "1s",
//},
"listen": m{ "listen": m{
"host": udpAddr.IP.String(), "host": udpAddr.IP.String(),
"port": udpAddr.Port, "port": udpAddr.Port,

3
e2e/router/doc.go Normal file
View File

@ -0,0 +1,3 @@
package router
// This file exists to allow `go fmt` to traverse here on its own. The build tags were keeping it out before

View File

@ -5,6 +5,7 @@ package router
import ( import (
"fmt" "fmt"
"net" "net"
"reflect"
"strconv" "strconv"
"sync" "sync"
@ -28,18 +29,18 @@ type R struct {
sync.Mutex sync.Mutex
} }
type exitType int type ExitType int
const ( const (
// Keeps routing, the function will get called again on the next packet // Keeps routing, the function will get called again on the next packet
keepRouting exitType = 0 KeepRouting ExitType = 0
// Does not route this packet and exits immediately // Does not route this packet and exits immediately
exitNow exitType = 1 ExitNow ExitType = 1
// Routes this packet and exits immediately afterwards // Routes this packet and exits immediately afterwards
routeAndExit exitType = 2 RouteAndExit ExitType = 2
) )
type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) exitType type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) ExitType
func NewR(controls ...*nebula.Control) *R { func NewR(controls ...*nebula.Control) *R {
r := &R{ r := &R{
@ -77,8 +78,8 @@ func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) {
// OnceFrom will route a single packet from sender then return // OnceFrom will route a single packet from sender then return
// If the router doesn't have the nebula controller for that address, we panic // If the router doesn't have the nebula controller for that address, we panic
func (r *R) OnceFrom(sender *nebula.Control) { func (r *R) OnceFrom(sender *nebula.Control) {
r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) exitType { r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) ExitType {
return routeAndExit return RouteAndExit
}) })
} }
@ -116,7 +117,6 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
// - exitNow: the packet will not be routed and this call will return immediately // - exitNow: the packet will not be routed and this call will return immediately
// - routeAndExit: this call will return immediately after routing the last packet from sender // - routeAndExit: this call will return immediately after routing the last packet from sender
// - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender // - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
//TODO: is this RouteWhile?
func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
h := &nebula.Header{} h := &nebula.Header{}
for { for {
@ -136,16 +136,16 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
e := whatDo(p, receiver) e := whatDo(p, receiver)
switch e { switch e {
case exitNow: case ExitNow:
r.Unlock() r.Unlock()
return return
case routeAndExit: case RouteAndExit:
receiver.InjectUDPPacket(p) receiver.InjectUDPPacket(p)
r.Unlock() r.Unlock()
return return
case keepRouting: case KeepRouting:
receiver.InjectUDPPacket(p) receiver.InjectUDPPacket(p)
default: default:
@ -160,35 +160,135 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
// If the router doesn't have the nebula controller for that address, we panic // If the router doesn't have the nebula controller for that address, we panic
func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType nebula.NebulaMessageType, subType nebula.NebulaMessageSubType) { func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType nebula.NebulaMessageType, subType nebula.NebulaMessageSubType) {
h := &nebula.Header{} h := &nebula.Header{}
r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) exitType { r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
if err := h.Parse(p.Data); err != nil { if err := h.Parse(p.Data); err != nil {
panic(err) panic(err)
} }
if h.Type == msgType && h.Subtype == subType { if h.Type == msgType && h.Subtype == subType {
return routeAndExit return RouteAndExit
} }
return keepRouting return KeepRouting
}) })
} }
// RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr // RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr
// finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit` // finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit`
// If the router doesn't have the nebula controller for that address, we panic // If the router doesn't have the nebula controller for that address, we panic
func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish exitType) { func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) {
if finish == keepRouting { if finish == KeepRouting {
finish = routeAndExit finish = RouteAndExit
} }
r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) exitType { r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) { if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
return finish return finish
} }
return keepRouting return KeepRouting
}) })
} }
// RouteForAllExitFunc will route for every registered controller and calls the whatDo func with each udp packet from
// whatDo can return:
// - exitNow: the packet will not be routed and this call will return immediately
// - routeAndExit: this call will return immediately after routing the last packet from sender
// - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
sc := make([]reflect.SelectCase, len(r.controls))
cm := make([]*nebula.Control, len(r.controls))
i := 0
for _, c := range r.controls {
sc[i] = reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(c.GetUDPTxChan()),
Send: reflect.Value{},
}
cm[i] = c
i++
}
for {
x, rx, _ := reflect.Select(sc)
r.Lock()
p := rx.Interface().(*nebula.UdpPacket)
outAddr := cm[x].GetUDPAddr()
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
receiver := r.getControl(outAddr, inAddr, p)
if receiver == nil {
r.Unlock()
panic("Can't route for host: " + inAddr)
}
e := whatDo(p, receiver)
switch e {
case ExitNow:
r.Unlock()
return
case RouteAndExit:
receiver.InjectUDPPacket(p)
r.Unlock()
return
case KeepRouting:
receiver.InjectUDPPacket(p)
default:
panic(fmt.Sprintf("Unknown exitFunc return: %v", e))
}
r.Unlock()
}
}
// FlushAll will route for every registered controller, exiting once there are no packets left to route
func (r *R) FlushAll() {
sc := make([]reflect.SelectCase, len(r.controls))
cm := make([]*nebula.Control, len(r.controls))
i := 0
for _, c := range r.controls {
sc[i] = reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(c.GetUDPTxChan()),
Send: reflect.Value{},
}
cm[i] = c
i++
}
// Add a default case to exit when nothing is left to send
sc = append(sc, reflect.SelectCase{
Dir: reflect.SelectDefault,
Chan: reflect.Value{},
Send: reflect.Value{},
})
for {
x, rx, ok := reflect.Select(sc)
if !ok {
return
}
r.Lock()
p := rx.Interface().(*nebula.UdpPacket)
outAddr := cm[x].GetUDPAddr()
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
receiver := r.getControl(outAddr, inAddr, p)
if receiver == nil {
r.Unlock()
panic("Can't route for host: " + inAddr)
}
r.Unlock()
}
}
// getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
// This is an internal router function, the caller must hold the lock // This is an internal router function, the caller must hold the lock
func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Control { func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Control {
@ -216,6 +316,5 @@ func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Con
return c return c
} }
//TODO: call receive hooks!
return r.controls[toAddr] return r.controls[toAddr]
} }

View File

@ -202,16 +202,16 @@ logging:
# Handshake Manger Settings # Handshake Manger Settings
#handshakes: #handshakes:
# Total time to try a handshake = sequence of `try_interval * retries` # Handshakes are sent to all known addresses at each interval with a linear backoff,
# With 100ms interval and 20 retries it is 23.5 seconds # Wait try_interval after the 1st attempt, 2 * try_interval after the 2nd, etc, until the handshake is older than timeout
# A 100ms interval with the default 10 retries will give a handshake 5.5 seconds to resolve before timing out
#try_interval: 100ms #try_interval: 100ms
#retries: 20 #retries: 20
# wait_rotation is the number of handshake attempts to do before starting to try non-local IP addresses
#wait_rotation: 5
# trigger_buffer is the size of the buffer channel for quickly sending handshakes # trigger_buffer is the size of the buffer channel for quickly sending handshakes
# after receiving the response for lighthouse queries # after receiving the response for lighthouse queries
#trigger_buffer: 64 #trigger_buffer: 64
# Nebula security group configuration # Nebula security group configuration
firewall: firewall:
conntrack: conntrack:

View File

@ -14,14 +14,10 @@ import (
// Sending is done by the handshake manager // Sending is done by the handshake manager
func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) { func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
// This queries the lighthouse if we don't know a remote for the host // This queries the lighthouse if we don't know a remote for the host
// We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send
// more quickly, effect is a quicker handshake.
if hostinfo.remote == nil { if hostinfo.remote == nil {
ips, err := f.lightHouse.Query(vpnIp, f) f.lightHouse.QueryServer(vpnIp, f)
if err != nil {
//l.Debugln(err)
}
for _, ip := range ips {
hostinfo.AddRemote(ip)
}
} }
err := f.handshakeManager.AddIndexHostInfo(hostinfo) err := f.handshakeManager.AddIndexHostInfo(hostinfo)
@ -69,7 +65,6 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
hostinfo.HandshakePacket[0] = msg hostinfo.HandshakePacket[0] = msg
hostinfo.HandshakeReady = true hostinfo.HandshakeReady = true
hostinfo.handshakeStart = time.Now() hostinfo.handshakeStart = time.Now()
} }
func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
@ -125,13 +120,15 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
hostinfo := &HostInfo{ hostinfo := &HostInfo{
ConnectionState: ci, ConnectionState: ci,
Remotes: []*udpAddr{},
localIndexId: myIndex, localIndexId: myIndex,
remoteIndexId: hs.Details.InitiatorIndex, remoteIndexId: hs.Details.InitiatorIndex,
hostId: vpnIP, hostId: vpnIP,
HandshakePacket: make(map[uint8][]byte, 0), HandshakePacket: make(map[uint8][]byte, 0),
} }
hostinfo.Lock()
defer hostinfo.Unlock()
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@ -182,16 +179,11 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
ci.peerCert = remoteCert ci.peerCert = remoteCert
ci.dKey = NewNebulaCipherState(dKey) ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey) ci.eKey = NewNebulaCipherState(eKey)
//l.Debugln("got symmetric pairs")
//hostinfo.ClearRemotes() hostinfo.remotes = f.lightHouse.QueryCache(vpnIP)
hostinfo.AddRemote(addr) hostinfo.SetRemote(addr)
hostinfo.ForcePromoteBest(f.hostMap.preferredRanges)
hostinfo.CreateRemoteCIDR(remoteCert) hostinfo.CreateRemoteCIDR(remoteCert)
hostinfo.Lock()
defer hostinfo.Unlock()
// Only overwrite existing record if we should win the handshake race // Only overwrite existing record if we should win the handshake race
overwrite := vpnIP > ip2int(f.certState.certificate.Details.Ips[0].IP) overwrite := vpnIP > ip2int(f.certState.certificate.Details.Ips[0].IP)
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f)
@ -214,6 +206,10 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
case ErrExistingHostInfo: case ErrExistingHostInfo:
// This means there was an existing tunnel and we didn't win // This means there was an existing tunnel and we didn't win
// handshake avoidance // handshake avoidance
//TODO: sprinkle the new protobuf stuff in here, send a reply to get the recv_errors flowing
//TODO: if not new send a test packet like old
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@ -234,6 +230,15 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
WithField("localIndex", hostinfo.localIndexId).WithField("collision", IntIp(existing.hostId)). WithField("localIndex", hostinfo.localIndexId).WithField("collision", IntIp(existing.hostId)).
Error("Failed to add HostInfo due to localIndex collision") Error("Failed to add HostInfo due to localIndex collision")
return return
case ErrExistingHandshake:
// We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Prevented a pending handshake race")
return
default: default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here // And we forget to update it here
@ -286,6 +291,8 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Info("Handshake is already complete") Info("Handshake is already complete")
//TODO: evaluate addr for preference, if we handshook with a less preferred addr we can correct quickly here
// We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets // We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets
return false return false
} }
@ -334,17 +341,13 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
certName := remoteCert.Details.Name certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum() fingerprint, _ := remoteCert.Sha256Sum()
// Ensure the right host responded
if vpnIP != hostinfo.hostId { if vpnIP != hostinfo.hostId {
f.l.WithField("intendedVpnIp", IntIp(hostinfo.hostId)).WithField("haveVpnIp", IntIp(vpnIP)). f.l.WithField("intendedVpnIp", IntIp(hostinfo.hostId)).WithField("haveVpnIp", IntIp(vpnIP)).
WithField("udpAddr", addr).WithField("certName", certName). WithField("udpAddr", addr).WithField("certName", certName).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Incorrect host responded to handshake") Info("Incorrect host responded to handshake")
if ho, _ := f.handshakeManager.pendingHostMap.QueryVpnIP(vpnIP); ho != nil {
// We might have a pending tunnel to this host already, clear out that attempt since we have a tunnel now
f.handshakeManager.pendingHostMap.DeleteHostInfo(ho)
}
// Release our old handshake from pending, it should not continue // Release our old handshake from pending, it should not continue
f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo) f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
@ -354,26 +357,28 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
newHostInfo.Lock() newHostInfo.Lock()
// Block the current used address // Block the current used address
newHostInfo.unlockedBlockRemote(addr) newHostInfo.remotes = hostinfo.remotes
newHostInfo.remotes.BlockRemote(addr)
// If this is an ongoing issue our previous hostmap will have some bad ips too // Get the correct remote list for the host we did handshake with
for _, v := range hostinfo.badRemotes { hostinfo.remotes = f.lightHouse.QueryCache(vpnIP)
newHostInfo.unlockedBlockRemote(v)
}
//TODO: this is me enabling tests
newHostInfo.ForcePromoteBest(f.hostMap.preferredRanges)
f.l.WithField("blockedUdpAddrs", newHostInfo.badRemotes).WithField("vpnIp", IntIp(vpnIP)). f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", IntIp(vpnIP)).
WithField("remotes", newHostInfo.Remotes). WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
Info("Blocked addresses for handshakes") Info("Blocked addresses for handshakes")
// Swap the packet store to benefit the original intended recipient // Swap the packet store to benefit the original intended recipient
hostinfo.ConnectionState.queueLock.Lock()
newHostInfo.packetStore = hostinfo.packetStore newHostInfo.packetStore = hostinfo.packetStore
hostinfo.packetStore = []*cachedPacket{} hostinfo.packetStore = []*cachedPacket{}
hostinfo.ConnectionState.queueLock.Unlock()
// Set the current hostId to the new vpnIp // Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.hostId = vpnIP hostinfo.hostId = vpnIP
f.sendCloseTunnel(hostinfo)
newHostInfo.Unlock() newHostInfo.Unlock()
return true
} }
// Mark packet 2 as seen so it doesn't show up as missed // Mark packet 2 as seen so it doesn't show up as missed

View File

@ -12,12 +12,8 @@ import (
) )
const ( const (
// Total time to try a handshake = sequence of HandshakeTryInterval * HandshakeRetries
// With 100ms interval and 20 retries is 23.5 seconds
DefaultHandshakeTryInterval = time.Millisecond * 100 DefaultHandshakeTryInterval = time.Millisecond * 100
DefaultHandshakeRetries = 20 DefaultHandshakeRetries = 10
// DefaultHandshakeWaitRotation is the number of handshake attempts to do before starting to use other ips addresses
DefaultHandshakeWaitRotation = 5
DefaultHandshakeTriggerBuffer = 64 DefaultHandshakeTriggerBuffer = 64
) )
@ -25,7 +21,6 @@ var (
defaultHandshakeConfig = HandshakeConfig{ defaultHandshakeConfig = HandshakeConfig{
tryInterval: DefaultHandshakeTryInterval, tryInterval: DefaultHandshakeTryInterval,
retries: DefaultHandshakeRetries, retries: DefaultHandshakeRetries,
waitRotation: DefaultHandshakeWaitRotation,
triggerBuffer: DefaultHandshakeTriggerBuffer, triggerBuffer: DefaultHandshakeTriggerBuffer,
} }
) )
@ -33,7 +28,6 @@ var (
type HandshakeConfig struct { type HandshakeConfig struct {
tryInterval time.Duration tryInterval time.Duration
retries int retries int
waitRotation int
triggerBuffer int triggerBuffer int
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
@ -45,15 +39,12 @@ type HandshakeManager struct {
lightHouse *LightHouse lightHouse *LightHouse
outside *udpConn outside *udpConn
config HandshakeConfig config HandshakeConfig
OutboundHandshakeTimer *SystemTimerWheel
messageMetrics *MessageMetrics
l *logrus.Logger
// can be used to trigger outbound handshake for the given vpnIP // can be used to trigger outbound handshake for the given vpnIP
trigger chan uint32 trigger chan uint32
OutboundHandshakeTimer *SystemTimerWheel
InboundHandshakeTimer *SystemTimerWheel
messageMetrics *MessageMetrics
l *logrus.Logger
} }
func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager { func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
@ -62,14 +53,9 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
mainHostMap: mainHostMap, mainHostMap: mainHostMap,
lightHouse: lightHouse, lightHouse: lightHouse,
outside: outside, outside: outside,
config: config, config: config,
trigger: make(chan uint32, config.triggerBuffer), trigger: make(chan uint32, config.triggerBuffer),
OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
messageMetrics: config.messageMetrics, messageMetrics: config.messageMetrics,
l: l, l: l,
} }
@ -84,7 +70,6 @@ func (c *HandshakeManager) Run(f EncWriter) {
c.handleOutbound(vpnIP, f, true) c.handleOutbound(vpnIP, f, true)
case now := <-clockSource: case now := <-clockSource:
c.NextOutboundHandshakeTimerTick(now, f) c.NextOutboundHandshakeTimerTick(now, f)
c.NextInboundHandshakeTimerTick(now)
} }
} }
} }
@ -109,91 +94,92 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
hostinfo.Lock() hostinfo.Lock()
defer hostinfo.Unlock() defer hostinfo.Unlock()
// If we haven't finished the handshake and we haven't hit max retries, query // We may have raced to completion but now that we have a lock we should ensure we have not yet completed.
// lighthouse and then send the handshake packet again. if hostinfo.HandshakeComplete {
if hostinfo.HandshakeCounter < c.config.retries && !hostinfo.HandshakeComplete { // Ensure we don't exist in the pending hostmap anymore since we have completed
if hostinfo.remote == nil { c.pendingHostMap.DeleteHostInfo(hostinfo)
// We continue to query the lighthouse because hosts may
// come online during handshake retries. If the query
// succeeds (no error), add the lighthouse info to hostinfo
ips := c.lightHouse.QueryCache(vpnIP)
// If we have no responses yet, or only one IP (the host hadn't
// finished reporting its own IPs yet), then send another query to
// the LH.
if len(ips) <= 1 {
ips, err = c.lightHouse.Query(vpnIP, f)
}
if err == nil {
for _, ip := range ips {
hostinfo.AddRemote(ip)
}
hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges)
}
} else if lighthouseTriggered {
// We were triggered by a lighthouse HostQueryReply packet, but
// we have already picked a remote for this host (this can happen
// if we are configured with multiple lighthouses). So we can skip
// this trigger and let the timerwheel handle the rest of the
// process
return return
} }
hostinfo.HandshakeCounter++ // Check if we have a handshake packet to transmit yet
if !hostinfo.HandshakeReady {
// We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through // There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly
// all the others until we can stand up a connection. // Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState
if hostinfo.HandshakeCounter > c.config.waitRotation { c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
hostinfo.rotateRemote() return
} }
// Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation // If we are out of time, clean up
if hostinfo.HandshakeReady && hostinfo.remote != nil { if hostinfo.HandshakeCounter >= c.config.retries {
c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1) hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)).
err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
if err != nil {
hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote).
WithField("initiatorIndex", hostinfo.localIndexId). WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId). WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("durationNs", time.Since(hostinfo.handshakeStart).Nanoseconds()).
Info("Handshake timed out")
//TODO: emit metrics
c.pendingHostMap.DeleteHostInfo(hostinfo)
return
}
// We only care about a lighthouse trigger before the first handshake transmit attempt. This is a very specific
// optimization for a fast lighthouse reply
//TODO: it would feel better to do this once, anytime, as our delay increases over time
if lighthouseTriggered && hostinfo.HandshakeCounter > 0 {
// If we didn't return here a lighthouse could cause us to aggressively send handshakes
return
}
// Get a remotes object if we don't already have one.
// This is mainly to protect us as this should never be the case
if hostinfo.remotes == nil {
hostinfo.remotes = c.lightHouse.QueryCache(vpnIP)
}
//TODO: this will generate a load of queries for hosts with only 1 ip (i'm not using a lighthouse, static mapped)
if hostinfo.remotes.Len(c.pendingHostMap.preferredRanges) <= 1 {
// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
// Our vpnIP here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
// the learned public ip for them. Query again to short circuit the promotion counter
c.lightHouse.QueryServer(vpnIP, f)
}
// Send a the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
var sentTo []*udpAddr
hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udpAddr, _ bool) {
c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
if err != nil {
hostinfo.logger(c.l).WithField("udpAddr", addr).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake message") WithError(err).Error("Failed to send handshake message")
} else { } else {
//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should sentTo = append(sentTo, addr)
// keep the real packet struct around for logging purposes }
hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote). })
hostinfo.logger(c.l).WithField("udpAddrs", sentTo).
WithField("initiatorIndex", hostinfo.localIndexId). WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message sent") Info("Handshake message sent")
}
}
// Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try // Increment the counter to increase our delay, linear backoff
hostinfo.HandshakeCounter++
// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
if !lighthouseTriggered { if !lighthouseTriggered {
//l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter)) //TODO: feel like we dupe handshake real fast in a tight loop, why?
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
} }
} else {
c.pendingHostMap.DeleteHostInfo(hostinfo)
}
}
func (c *HandshakeManager) NextInboundHandshakeTimerTick(now time.Time) {
c.InboundHandshakeTimer.advance(now)
for {
ep := c.InboundHandshakeTimer.Purge()
if ep == nil {
break
}
index := ep.(uint32)
c.pendingHostMap.DeleteIndex(index)
}
} }
func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo { func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
hostinfo := c.pendingHostMap.AddVpnIP(vpnIP) hostinfo := c.pendingHostMap.AddVpnIP(vpnIP)
// We lock here and use an array to insert items to prevent locking the // We lock here and use an array to insert items to prevent locking the
// main receive thread for very long by waiting to add items to the pending map // main receive thread for very long by waiting to add items to the pending map
//TODO: what lock?
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval) c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval)
return hostinfo return hostinfo
@ -203,6 +189,7 @@ var (
ErrExistingHostInfo = errors.New("existing hostinfo") ErrExistingHostInfo = errors.New("existing hostinfo")
ErrAlreadySeen = errors.New("already seen") ErrAlreadySeen = errors.New("already seen")
ErrLocalIndexCollision = errors.New("local index collision") ErrLocalIndexCollision = errors.New("local index collision")
ErrExistingHandshake = errors.New("existing handshake")
) )
// CheckAndComplete checks for any conflicts in the main and pending hostmap // CheckAndComplete checks for any conflicts in the main and pending hostmap
@ -217,17 +204,21 @@ var (
// ErrLocalIndexCollision if we already have an entry in the main or pending // ErrLocalIndexCollision if we already have an entry in the main or pending
// hostmap for the hostinfo.localIndexId. // hostmap for the hostinfo.localIndexId.
func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) { func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) {
c.pendingHostMap.RLock() c.pendingHostMap.Lock()
defer c.pendingHostMap.RUnlock() defer c.pendingHostMap.Unlock()
c.mainHostMap.Lock() c.mainHostMap.Lock()
defer c.mainHostMap.Unlock() defer c.mainHostMap.Unlock()
// Check if we already have a tunnel with this vpn ip
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId] existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
if found && existingHostInfo != nil { if found && existingHostInfo != nil {
// Is it just a delayed handshake packet?
if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) { if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
return existingHostInfo, ErrAlreadySeen return existingHostInfo, ErrAlreadySeen
} }
if !overwrite { if !overwrite {
// It's a new handshake and we lost the race
return existingHostInfo, ErrExistingHostInfo return existingHostInfo, ErrExistingHostInfo
} }
} }
@ -237,6 +228,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
// We have a collision, but for a different hostinfo // We have a collision, but for a different hostinfo
return existingIndex, ErrLocalIndexCollision return existingIndex, ErrLocalIndexCollision
} }
existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId] existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId]
if found && existingIndex != hostinfo { if found && existingIndex != hostinfo {
// We have a collision, but for a different hostinfo // We have a collision, but for a different hostinfo
@ -252,7 +244,24 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
Info("New host shadows existing host remoteIndex") Info("New host shadows existing host remoteIndex")
} }
// Check if we are also handshaking with this vpn ip
pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.hostId]
if found && pendingHostInfo != nil {
if !overwrite {
// We won, let our pending handshake win
return pendingHostInfo, ErrExistingHandshake
}
// We lost, take this handshake and move any cached packets over so they get sent
pendingHostInfo.ConnectionState.queueLock.Lock()
hostinfo.packetStore = append(hostinfo.packetStore, pendingHostInfo.packetStore...)
c.pendingHostMap.unlockedDeleteHostInfo(pendingHostInfo)
pendingHostInfo.ConnectionState.queueLock.Unlock()
pendingHostInfo.logger(c.l).Info("Handshake race lost, replacing pending handshake with completed tunnel")
}
if existingHostInfo != nil { if existingHostInfo != nil {
hostinfo.logger(c.l).Info("Race lost, taking new handshake")
// We are going to overwrite this entry, so remove the old references // We are going to overwrite this entry, so remove the old references
delete(c.mainHostMap.Hosts, existingHostInfo.hostId) delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId) delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
@ -267,6 +276,8 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
// won't have a localIndexId collision because we already have an entry in the // won't have a localIndexId collision because we already have an entry in the
// pendingHostMap // pendingHostMap
func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
c.pendingHostMap.Lock()
defer c.pendingHostMap.Unlock()
c.mainHostMap.Lock() c.mainHostMap.Lock()
defer c.mainHostMap.Unlock() defer c.mainHostMap.Unlock()
@ -288,6 +299,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
} }
c.mainHostMap.addHostInfo(hostinfo, f) c.mainHostMap.addHostInfo(hostinfo, f)
c.pendingHostMap.unlockedDeleteHostInfo(hostinfo)
} }
// AddIndexHostInfo generates a unique localIndexId for this HostInfo // AddIndexHostInfo generates a unique localIndexId for this HostInfo
@ -359,3 +371,7 @@ func generateIndex(l *logrus.Logger) (uint32, error) {
} }
return index, nil return index, nil
} }
func hsTimeout(tries int, interval time.Duration) time.Duration {
return time.Duration(tries / 2 * ((2 * int(interval)) + (tries-1)*int(interval)))
}

View File

@ -8,66 +8,12 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
//var ips []uint32 = []uint32{9000, 9999999, 3, 292394923}
var ips []uint32
func Test_NewHandshakeManagerIndex(t *testing.T) {
l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextInboundHandshakeTimerTick(now)
var indexes = make([]uint32, 4)
var hostinfo = make([]*HostInfo, len(indexes))
for i := range indexes {
hostinfo[i] = &HostInfo{ConnectionState: &ConnectionState{}}
}
// Add four indexes
for i := range indexes {
err := blah.AddIndexHostInfo(hostinfo[i])
assert.NoError(t, err)
indexes[i] = hostinfo[i].localIndexId
blah.InboundHandshakeTimer.Add(indexes[i], time.Second*10)
}
// Confirm they are in the pending index list
for _, v := range indexes {
assert.Contains(t, blah.pendingHostMap.Indexes, uint32(v))
}
// Adding something to pending should not affect the main hostmap
assert.Len(t, mainHM.Indexes, 0)
// Jump ahead 8 seconds
for i := 1; i <= DefaultHandshakeRetries; i++ {
next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i))
blah.NextInboundHandshakeTimerTick(next_tick)
}
// Confirm they are still in the pending index list
for _, v := range indexes {
assert.Contains(t, blah.pendingHostMap.Indexes, uint32(v))
}
// Jump ahead 4 more seconds
next_tick := now.Add(12 * time.Second)
blah.NextInboundHandshakeTimerTick(next_tick)
// Confirm they have been removed
for _, v := range indexes {
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(v))
}
}
func Test_NewHandshakeManagerVpnIP(t *testing.T) { func Test_NewHandshakeManagerVpnIP(t *testing.T) {
l := NewTestLogger() l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))} ip := ip2int(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{} mw := &mockEncWriter{}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
@ -77,39 +23,30 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
now := time.Now() now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw) blah.NextOutboundHandshakeTimerTick(now, mw)
// Add four "IPs" - which are just uint32s i := blah.AddVpnIP(ip)
for _, v := range ips { i.remotes = NewRemoteList()
blah.AddVpnIP(v) i.HandshakeReady = true
}
// Adding something to pending should not affect the main hostmap // Adding something to pending should not affect the main hostmap
assert.Len(t, mainHM.Hosts, 0) assert.Len(t, mainHM.Hosts, 0)
// Confirm they are in the pending index list
for _, v := range ips {
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v))
}
// Jump ahead `HandshakeRetries` ticks // Confirm they are in the pending index list
cumulative := time.Duration(0) assert.Contains(t, blah.pendingHostMap.Hosts, ip)
for i := 0; i <= DefaultHandshakeRetries+1; i++ {
cumulative += time.Duration(i)*DefaultHandshakeTryInterval + 1 // Jump ahead `HandshakeRetries` ticks, offset by one to get the sleep logic right
next_tick := now.Add(cumulative) for i := 1; i <= DefaultHandshakeRetries+1; i++ {
//l.Infoln(next_tick) now = now.Add(time.Duration(i) * DefaultHandshakeTryInterval)
blah.NextOutboundHandshakeTimerTick(next_tick, mw) blah.NextOutboundHandshakeTimerTick(now, mw)
} }
// Confirm they are still in the pending index list // Confirm they are still in the pending index list
for _, v := range ips { assert.Contains(t, blah.pendingHostMap.Hosts, ip)
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v))
} // Tick 1 more time, a minute will certainly flush it out
// Jump ahead 1 more second blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute), mw)
cumulative += time.Duration(DefaultHandshakeRetries+1) * DefaultHandshakeTryInterval
next_tick := now.Add(cumulative)
//l.Infoln(next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick, mw)
// Confirm they have been removed // Confirm they have been removed
for _, v := range ips { assert.NotContains(t, blah.pendingHostMap.Hosts, ip)
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(v))
}
} }
func Test_NewHandshakeManagerTrigger(t *testing.T) { func Test_NewHandshakeManagerTrigger(t *testing.T) {
@ -121,7 +58,7 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{} mw := &mockEncWriter{}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
lh := &LightHouse{} lh := &LightHouse{addrMap: make(map[uint32]*RemoteList), l: l}
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig) blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
@ -130,28 +67,25 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
blah.AddVpnIP(ip) hi := blah.AddVpnIP(ip)
hi.HandshakeReady = true
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet")
// Trigger the same method the channel will // Trigger the same method the channel will but, this should set our remotes pointer
blah.handleOutbound(ip, mw, true) blah.handleOutbound(ip, mw, true)
assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should have done a handshake attempt")
assert.NotNil(t, hi.remotes, "Manager should have set my remotes pointer")
// Make sure the trigger doesn't schedule another timer entry // Make sure the trigger doesn't double schedule the timer entry
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
hi := blah.pendingHostMap.Hosts[ip]
assert.Nil(t, hi.remote)
uaddr := NewUDPAddrFromString("10.1.1.1:4242") uaddr := NewUDPAddrFromString("10.1.1.1:4242")
lh.addrMap = map[uint32]*ip4And6{} hi.remotes.unlockedPrependV4(ip, NewIp4AndPort(uaddr.IP, uint32(uaddr.Port)))
lh.addrMap[ip] = &ip4And6{
v4: []*Ip4AndPort{NewIp4AndPort(uaddr.IP, uint32(uaddr.Port))},
v6: []*Ip6AndPort{},
}
// This should trigger the hostmap to populate the hostinfo // We now have remotes but only the first trigger should have pushed things forward
blah.handleOutbound(ip, mw, true) blah.handleOutbound(ip, mw, true)
assert.NotNil(t, hi.remote) assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should have not done a handshake attempt")
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
} }
@ -166,100 +100,9 @@ func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
return c return c
} }
func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
vpnIP = ip2int(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
hostinfo := blah.AddVpnIP(vpnIP)
// Pretned we have an index too
err := blah.AddIndexHostInfo(hostinfo)
assert.NoError(t, err)
blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10)
assert.NotZero(t, hostinfo.localIndexId)
assert.Contains(t, blah.pendingHostMap.Indexes, hostinfo.localIndexId)
// Jump ahead `HandshakeRetries` ticks. Eviction should happen in pending
// but not main hostmap
cumulative := time.Duration(0)
for i := 1; i <= DefaultHandshakeRetries+2; i++ {
cumulative += DefaultHandshakeTryInterval * time.Duration(i)
next_tick := now.Add(cumulative)
blah.NextOutboundHandshakeTimerTick(next_tick, mw)
}
/*
for i := 0; i <= HandshakeRetries+1; i++ {
next_tick := now.Add(cumulative)
//l.Infoln(next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick)
}
*/
/*
for i := 0; i <= HandshakeRetries+1; i++ {
next_tick := now.Add(time.Duration(i) * time.Second)
blah.NextOutboundHandshakeTimerTick(next_tick)
}
*/
/*
cumulative += HandshakeTryInterval*time.Duration(HandshakeRetries) + 3
next_tick := now.Add(cumulative)
l.Infoln(cumulative, next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick)
*/
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(vpnIP))
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234))
}
func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextInboundHandshakeTimerTick(now)
hostinfo := &HostInfo{ConnectionState: &ConnectionState{}}
err := blah.AddIndexHostInfo(hostinfo)
assert.NoError(t, err)
blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10)
// Pretned we have an index too
blah.pendingHostMap.AddVpnIPHostInfo(101010, hostinfo)
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(101010))
for i := 1; i <= DefaultHandshakeRetries+2; i++ {
next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i))
blah.NextInboundHandshakeTimerTick(next_tick)
}
next_tick := now.Add(DefaultHandshakeTryInterval*DefaultHandshakeRetries + 3)
blah.NextInboundHandshakeTimerTick(next_tick)
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010))
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(hostinfo.localIndexId))
}
type mockEncWriter struct { type mockEncWriter struct {
} }
func (mw *mockEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { func (mw *mockEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
return return
} }
func (mw *mockEncWriter) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
return
}

View File

@ -1,7 +1,6 @@
package nebula package nebula
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@ -16,6 +15,7 @@ import (
//const ProbeLen = 100 //const ProbeLen = 100
const PromoteEvery = 1000 const PromoteEvery = 1000
const ReQueryEvery = 5000
const MaxRemotes = 10 const MaxRemotes = 10
// How long we should prevent roaming back to the previous IP. // How long we should prevent roaming back to the previous IP.
@ -30,7 +30,6 @@ type HostMap struct {
Hosts map[uint32]*HostInfo Hosts map[uint32]*HostInfo
preferredRanges []*net.IPNet preferredRanges []*net.IPNet
vpnCIDR *net.IPNet vpnCIDR *net.IPNet
defaultRoute uint32
unsafeRoutes *CIDRTree unsafeRoutes *CIDRTree
metricsEnabled bool metricsEnabled bool
l *logrus.Logger l *logrus.Logger
@ -40,25 +39,21 @@ type HostInfo struct {
sync.RWMutex sync.RWMutex
remote *udpAddr remote *udpAddr
Remotes []*udpAddr remotes *RemoteList
promoteCounter uint32 promoteCounter uint32
ConnectionState *ConnectionState ConnectionState *ConnectionState
handshakeStart time.Time handshakeStart time.Time //todo: this an entry in the handshake manager
HandshakeReady bool HandshakeReady bool //todo: being in the manager means you are ready
HandshakeCounter int HandshakeCounter int //todo: another handshake manager entry
HandshakeComplete bool HandshakeComplete bool //todo: this should go away in favor of ConnectionState.ready
HandshakePacket map[uint8][]byte HandshakePacket map[uint8][]byte //todo: this is other handshake manager entry
packetStore []*cachedPacket packetStore []*cachedPacket //todo: this is other handshake manager entry
remoteIndexId uint32 remoteIndexId uint32
localIndexId uint32 localIndexId uint32
hostId uint32 hostId uint32
recvError int recvError int
remoteCidr *CIDRTree remoteCidr *CIDRTree
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake
badRemotes []*udpAddr
// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH // lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like // for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
// with a handshake // with a handshake
@ -88,7 +83,6 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang
Hosts: h, Hosts: h,
preferredRanges: preferredRanges, preferredRanges: preferredRanges,
vpnCIDR: vpnCIDR, vpnCIDR: vpnCIDR,
defaultRoute: 0,
unsafeRoutes: NewCIDRTree(), unsafeRoutes: NewCIDRTree(),
l: l, l: l,
} }
@ -131,7 +125,6 @@ func (hm *HostMap) AddVpnIP(vpnIP uint32) *HostInfo {
if _, ok := hm.Hosts[vpnIP]; !ok { if _, ok := hm.Hosts[vpnIP]; !ok {
hm.RUnlock() hm.RUnlock()
h = &HostInfo{ h = &HostInfo{
Remotes: []*udpAddr{},
promoteCounter: 0, promoteCounter: 0,
hostId: vpnIP, hostId: vpnIP,
HandshakePacket: make(map[uint8][]byte, 0), HandshakePacket: make(map[uint8][]byte, 0),
@ -239,7 +232,11 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) {
func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) { func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
hm.Lock() hm.Lock()
defer hm.Unlock()
hm.unlockedDeleteHostInfo(hostinfo)
}
func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
// Check if this same hostId is in the hostmap with a different instance. // Check if this same hostId is in the hostmap with a different instance.
// This could happen if we have an entry in the pending hostmap with different // This could happen if we have an entry in the pending hostmap with different
// index values than the one in the main hostmap. // index values than the one in the main hostmap.
@ -262,7 +259,6 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
if len(hm.RemoteIndexes) == 0 { if len(hm.RemoteIndexes) == 0 {
hm.RemoteIndexes = map[uint32]*HostInfo{} hm.RemoteIndexes = map[uint32]*HostInfo{}
} }
hm.Unlock()
if hm.l.Level >= logrus.DebugLevel { if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts), hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
@ -294,30 +290,6 @@ func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) {
} }
} }
func (hm *HostMap) AddRemote(vpnIp uint32, remote *udpAddr) *HostInfo {
hm.Lock()
i, v := hm.Hosts[vpnIp]
if v {
i.AddRemote(remote)
} else {
i = &HostInfo{
Remotes: []*udpAddr{remote.Copy()},
promoteCounter: 0,
hostId: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0),
}
i.remote = i.Remotes[0]
hm.Hosts[vpnIp] = i
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}).
Debug("Hostmap remote ip added")
}
}
i.ForcePromoteBest(hm.preferredRanges)
hm.Unlock()
return i
}
func (hm *HostMap) QueryVpnIP(vpnIp uint32) (*HostInfo, error) { func (hm *HostMap) QueryVpnIP(vpnIp uint32) (*HostInfo, error) {
return hm.queryVpnIP(vpnIp, nil) return hm.queryVpnIP(vpnIp, nil)
} }
@ -331,12 +303,13 @@ func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostIn
func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) { func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) {
hm.RLock() hm.RLock()
if h, ok := hm.Hosts[vpnIp]; ok { if h, ok := hm.Hosts[vpnIp]; ok {
if promoteIfce != nil { // Do not attempt promotion if you are a lighthouse
if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse {
h.TryPromoteBest(hm.preferredRanges, promoteIfce) h.TryPromoteBest(hm.preferredRanges, promoteIfce)
} }
//fmt.Println(h.remote)
hm.RUnlock() hm.RUnlock()
return h, nil return h, nil
} else { } else {
//return &net.UDPAddr{}, nil, errors.New("Unable to find host") //return &net.UDPAddr{}, nil, errors.New("Unable to find host")
hm.RUnlock() hm.RUnlock()
@ -362,11 +335,8 @@ func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 {
// We already have the hm Lock when this is called, so make sure to not call // We already have the hm Lock when this is called, so make sure to not call
// any other methods that might try to grab it again // any other methods that might try to grab it again
func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) { func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
remoteCert := hostinfo.ConnectionState.peerCert
ip := ip2int(remoteCert.Details.Ips[0].IP)
f.lightHouse.AddRemoteAndReset(ip, hostinfo.remote)
if f.serveDns { if f.serveDns {
remoteCert := hostinfo.ConnectionState.peerCert
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
} }
@ -381,38 +351,21 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
} }
} }
func (hm *HostMap) ClearRemotes(vpnIP uint32) { // punchList assembles a list of all non nil RemoteList pointer entries in this hostmap
hm.Lock() // The caller can then do the its work outside of the read lock
i := hm.Hosts[vpnIP] func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList {
if i == nil {
hm.Unlock()
return
}
i.remote = nil
i.Remotes = nil
hm.Unlock()
}
func (hm *HostMap) SetDefaultRoute(ip uint32) {
hm.defaultRoute = ip
}
func (hm *HostMap) PunchList() []*udpAddr {
var list []*udpAddr
hm.RLock() hm.RLock()
defer hm.RUnlock()
for _, v := range hm.Hosts { for _, v := range hm.Hosts {
for _, r := range v.Remotes { if v.remotes != nil {
list = append(list, r) rl = append(rl, v.remotes)
} }
// if h, ok := hm.Hosts[vpnIp]; ok {
// hm.Hosts[vpnIp].PromoteBest(hm.preferredRanges, false)
//fmt.Println(h.remote)
// }
} }
hm.RUnlock() return rl
return list
} }
// Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them
func (hm *HostMap) Punchy(conn *udpConn) { func (hm *HostMap) Punchy(conn *udpConn) {
var metricsTxPunchy metrics.Counter var metricsTxPunchy metrics.Counter
if hm.metricsEnabled { if hm.metricsEnabled {
@ -421,13 +374,18 @@ func (hm *HostMap) Punchy(conn *udpConn) {
metricsTxPunchy = metrics.NilCounter{} metricsTxPunchy = metrics.NilCounter{}
} }
var remotes []*RemoteList
b := []byte{1} b := []byte{1}
for { for {
for _, addr := range hm.PunchList() { remotes = hm.punchList(remotes[:0])
for _, rl := range remotes {
//TODO: CopyAddrs generates garbage but ForEach locks for the work here, figure out which way is better
for _, addr := range rl.CopyAddrs(hm.preferredRanges) {
metricsTxPunchy.Inc(1) metricsTxPunchy.Inc(1)
conn.WriteTo(b, addr) conn.WriteTo(b, addr)
} }
time.Sleep(time.Second * 30) }
time.Sleep(time.Second * 10)
} }
} }
@ -438,38 +396,15 @@ func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
} }
} }
func (i *HostInfo) MarshalJSON() ([]byte, error) {
return json.Marshal(m{
"remote": i.remote,
"remotes": i.Remotes,
"promote_counter": i.promoteCounter,
"connection_state": i.ConnectionState,
"handshake_start": i.handshakeStart,
"handshake_ready": i.HandshakeReady,
"handshake_counter": i.HandshakeCounter,
"handshake_complete": i.HandshakeComplete,
"handshake_packet": i.HandshakePacket,
"packet_store": i.packetStore,
"remote_index": i.remoteIndexId,
"local_index": i.localIndexId,
"host_id": int2ip(i.hostId),
"receive_errors": i.recvError,
"last_roam": i.lastRoam,
"last_roam_remote": i.lastRoamRemote,
})
}
func (i *HostInfo) BindConnectionState(cs *ConnectionState) { func (i *HostInfo) BindConnectionState(cs *ConnectionState) {
i.ConnectionState = cs i.ConnectionState = cs
} }
// TryPromoteBest handles re-querying lighthouses and probing for better paths
// NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) { func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
if i.remote == nil { c := atomic.AddUint32(&i.promoteCounter, 1)
i.ForcePromoteBest(preferredRanges) if c%PromoteEvery == 0 {
return
}
if atomic.AddUint32(&i.promoteCounter, 1)%PromoteEvery == 0 {
// return early if we are already on a preferred remote // return early if we are already on a preferred remote
rIP := i.remote.IP rIP := i.remote.IP
for _, l := range preferredRanges { for _, l := range preferredRanges {
@ -478,87 +413,21 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
} }
} }
// We re-query the lighthouse periodically while sending packets, so i.remotes.ForEach(preferredRanges, func(addr *udpAddr, preferred bool) {
// check for new remotes in our local lighthouse cache if addr == nil || !preferred {
ips := ifce.lightHouse.QueryCache(i.hostId) return
for _, ip := range ips {
i.AddRemote(ip)
} }
best, preferred := i.getBestRemote(preferredRanges)
if preferred && !best.Equals(i.remote) {
// Try to send a test packet to that host, this should // Try to send a test packet to that host, this should
// cause it to detect a roaming event and switch remotes // cause it to detect a roaming event and switch remotes
ifce.send(test, testRequest, i.ConnectionState, i, best, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) ifce.send(test, testRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
} })
}
}
func (i *HostInfo) ForcePromoteBest(preferredRanges []*net.IPNet) {
best, _ := i.getBestRemote(preferredRanges)
if best != nil {
i.remote = best
}
}
func (i *HostInfo) getBestRemote(preferredRanges []*net.IPNet) (best *udpAddr, preferred bool) {
if len(i.Remotes) > 0 {
for _, r := range i.Remotes {
for _, l := range preferredRanges {
if l.Contains(r.IP) {
return r, true
}
} }
if best == nil || !PrivateIP(r.IP) { // Re query our lighthouses for new remotes occasionally
best = r if c%ReQueryEvery == 0 && ifce.lightHouse != nil {
ifce.lightHouse.QueryServer(i.hostId, ifce)
} }
/*
for _, r := range i.Remotes {
// Must have > 80% probe success to be considered.
//fmt.Println("GRADE:", r.addr.IP, r.Grade())
if r.Grade() > float64(.8) {
if localToMe.Contains(r.addr.IP) == true {
best = r.addr
break
//i.remote = i.Remotes[c].addr
} else {
//}
}
*/
}
return best, false
}
return nil, false
}
// rotateRemote will move remote to the next ip in the list of remote ips for this host
// This is different than PromoteBest in that what is algorithmically best may not actually work.
// Only known use case is when sending a stage 0 handshake.
// It may be better to just send stage 0 handshakes to all known ips and sort it out in the receiver.
func (i *HostInfo) rotateRemote() {
// We have 0, can't rotate
if len(i.Remotes) < 1 {
return
}
if i.remote == nil {
i.remote = i.Remotes[0]
return
}
// We want to look at all but the very last entry since that is handled at the end
for x := 0; x < len(i.Remotes)-1; x++ {
// Find our current position and move to the next one in the list
if i.Remotes[x].Equals(i.remote) {
i.remote = i.Remotes[x+1]
return
}
}
// Our current position was likely the last in the list, start over at 0
i.remote = i.Remotes[0]
} }
func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) { func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) {
@ -607,23 +476,13 @@ func (i *HostInfo) handshakeComplete(l *logrus.Logger) {
} }
} }
i.badRemotes = make([]*udpAddr, 0) i.remotes.ResetBlockedRemotes()
i.packetStore = make([]*cachedPacket, 0) i.packetStore = make([]*cachedPacket, 0)
i.ConnectionState.ready = true i.ConnectionState.ready = true
i.ConnectionState.queueLock.Unlock() i.ConnectionState.queueLock.Unlock()
i.ConnectionState.certState = nil i.ConnectionState.certState = nil
} }
func (i *HostInfo) CopyRemotes() []*udpAddr {
i.RLock()
rc := make([]*udpAddr, len(i.Remotes), len(i.Remotes))
for x, addr := range i.Remotes {
rc[x] = addr.Copy()
}
i.RUnlock()
return rc
}
func (i *HostInfo) GetCert() *cert.NebulaCertificate { func (i *HostInfo) GetCert() *cert.NebulaCertificate {
if i.ConnectionState != nil { if i.ConnectionState != nil {
return i.ConnectionState.peerCert return i.ConnectionState.peerCert
@ -631,58 +490,12 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate {
return nil return nil
} }
func (i *HostInfo) AddRemote(remote *udpAddr) *udpAddr {
if i.unlockedIsBadRemote(remote) {
return i.remote
}
for _, r := range i.Remotes {
if r.Equals(remote) {
return r
}
}
// Trim this down if necessary
if len(i.Remotes) > MaxRemotes {
i.Remotes = i.Remotes[len(i.Remotes)-MaxRemotes:]
}
rc := remote.Copy()
i.Remotes = append(i.Remotes, rc)
return rc
}
func (i *HostInfo) SetRemote(remote *udpAddr) { func (i *HostInfo) SetRemote(remote *udpAddr) {
i.remote = i.AddRemote(remote) // We copy here because we likely got this remote from a source that reuses the object
} if !i.remote.Equals(remote) {
i.remote = remote.Copy()
func (i *HostInfo) unlockedBlockRemote(remote *udpAddr) { i.remotes.LearnRemote(i.hostId, remote.Copy())
if !i.unlockedIsBadRemote(remote) {
// We copy here because we are taking something else's memory and we can't trust everything
i.badRemotes = append(i.badRemotes, remote.Copy())
} }
for k, v := range i.Remotes {
if v.Equals(remote) {
i.Remotes[k] = i.Remotes[len(i.Remotes)-1]
i.Remotes = i.Remotes[:len(i.Remotes)-1]
return
}
}
}
func (i *HostInfo) unlockedIsBadRemote(remote *udpAddr) bool {
for _, v := range i.badRemotes {
if v.Equals(remote) {
return true
}
}
return false
}
func (i *HostInfo) ClearRemotes() {
i.remote = nil
i.Remotes = []*udpAddr{}
} }
func (i *HostInfo) ClearConnectionState() { func (i *HostInfo) ClearConnectionState() {
@ -805,13 +618,3 @@ func localIps(l *logrus.Logger, allowList *AllowList) *[]net.IP {
} }
return &ips return &ips
} }
func PrivateIP(ip net.IP) bool {
//TODO: Private for ipv6 or just let it ride?
private := false
_, private24BitBlock, _ := net.ParseCIDR("10.0.0.0/8")
_, private20BitBlock, _ := net.ParseCIDR("172.16.0.0/12")
_, private16BitBlock, _ := net.ParseCIDR("192.168.0.0/16")
private = private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
return private
}

View File

@ -1,169 +1 @@
package nebula package nebula
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
/*
func TestHostInfoDestProbe(t *testing.T) {
a, _ := net.ResolveUDPAddr("udp", "1.0.0.1:22222")
d := NewHostInfoDest(a)
// 999 probes that all return should give a 100% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
d.ProbeReceived(meh)
}
assert.Equal(t, d.Grade(), float64(1))
// 999 probes of which only half return should give a 50% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
if i%2 == 0 {
d.ProbeReceived(meh)
}
}
assert.Equal(t, d.Grade(), float64(.5))
// 999 probes of which none return should give a 0% success rate
for i := 0; i < 999; i++ {
d.Probe()
}
assert.Equal(t, d.Grade(), float64(0))
// 999 probes of which only 1/4 return should give a 25% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
if i%4 == 0 {
d.ProbeReceived(meh)
}
}
assert.Equal(t, d.Grade(), float64(.25))
// 999 probes of which only half return and are duplicates should give a 50% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
if i%2 == 0 {
d.ProbeReceived(meh)
d.ProbeReceived(meh)
}
}
assert.Equal(t, d.Grade(), float64(.5))
// 999 probes of which only way old replies return should give a 0% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
d.ProbeReceived(meh - 101)
}
assert.Equal(t, d.Grade(), float64(0))
}
*/
func TestHostmap(t *testing.T) {
l := NewTestLogger()
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
myNets := []*net.IPNet{myNet}
preferredRanges := []*net.IPNet{localToMe}
m := NewHostMap(l, "test", myNet, preferredRanges)
a := NewUDPAddrFromString("10.127.0.3:11111")
b := NewUDPAddrFromString("1.0.0.1:22222")
y := NewUDPAddrFromString("10.128.0.3:11111")
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
info, _ := m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1")))
// There should be three remotes in the host map
assert.Equal(t, 3, len(info.Remotes))
// Adding an identical remote should not change the count
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
assert.Equal(t, 3, len(info.Remotes))
// Adding a fresh remote should add one
y = NewUDPAddrFromString("10.18.0.3:11111")
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
assert.Equal(t, 4, len(info.Remotes))
// Query and reference remote should get the first one (and not nil)
info, _ = m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1")))
assert.NotNil(t, info.remote)
// Promotion should ensure that the best remote is chosen (y)
info.ForcePromoteBest(myNets)
assert.True(t, myNet.Contains(info.remote.IP))
}
func TestHostmapdebug(t *testing.T) {
l := NewTestLogger()
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
preferredRanges := []*net.IPNet{localToMe}
m := NewHostMap(l, "test", myNet, preferredRanges)
a := NewUDPAddrFromString("10.127.0.3:11111")
b := NewUDPAddrFromString("1.0.0.1:22222")
y := NewUDPAddrFromString("10.128.0.3:11111")
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
//t.Errorf("%s", m.DebugRemotes(1))
}
func TestHostMap_rotateRemote(t *testing.T) {
h := HostInfo{}
// 0 remotes, no panic
h.rotateRemote()
assert.Nil(t, h.remote)
// 1 remote, no panic
h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 1}, 0))
h.rotateRemote()
assert.Equal(t, h.remote.IP, net.IP{1, 1, 1, 1})
h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 2}, 0))
h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 3}, 0))
h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 4}, 0))
//TODO: ensure we are copying and not storing the slice!
// Rotate through those 3
h.rotateRemote()
assert.Equal(t, h.remote.IP, net.IP{1, 1, 1, 2})
h.rotateRemote()
assert.Equal(t, h.remote.IP, net.IP{1, 1, 1, 3})
h.rotateRemote()
assert.Equal(t, h.remote, &udpAddr{IP: net.IP{1, 1, 1, 4}, Port: 0})
// Finally, we should start over
h.rotateRemote()
assert.Equal(t, h.remote, &udpAddr{IP: net.IP{1, 1, 1, 1}, Port: 0})
}
func BenchmarkHostmappromote2(b *testing.B) {
l := NewTestLogger()
for n := 0; n < b.N; n++ {
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
preferredRanges := []*net.IPNet{localToMe}
m := NewHostMap(l, "test", myNet, preferredRanges)
y := NewUDPAddrFromString("10.128.0.3:11111")
a := NewUDPAddrFromString("10.127.0.3:11111")
g := NewUDPAddrFromString("1.0.0.1:22222")
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), g)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
}
}

View File

@ -54,10 +54,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache) dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
if dropReason == nil { if dropReason == nil {
mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q) f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
if f.lightHouse != nil && mc%5000 == 0 {
f.lightHouse.Query(fwPacket.RemoteIP, f)
}
} else if f.l.Level >= logrus.DebugLevel { } else if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l). hostinfo.logger(f.l).
@ -84,15 +81,13 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
hostinfo = f.handshakeManager.AddVpnIP(vpnIp) hostinfo = f.handshakeManager.AddVpnIP(vpnIp)
} }
} }
ci := hostinfo.ConnectionState ci := hostinfo.ConnectionState
if ci != nil && ci.eKey != nil && ci.ready { if ci != nil && ci.eKey != nil && ci.ready {
return hostinfo return hostinfo
} }
// Handshake is not ready, we need to grab the lock now before we start // Handshake is not ready, we need to grab the lock now before we start the handshake process
// the handshake process
hostinfo.Lock() hostinfo.Lock()
defer hostinfo.Unlock() defer hostinfo.Unlock()
@ -150,10 +145,7 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
return return
} }
messageCounter := f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0) f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
if f.lightHouse != nil && messageCounter%5000 == 0 {
f.lightHouse.Query(fp.RemoteIP, f)
}
} }
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
@ -187,50 +179,15 @@ func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
f.send(t, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out) f.send(t, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
} }
// SendMessageToAll handles real ip:port lookup and sends to all known addresses for vpnIp
func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
hostInfo := f.getOrHandshake(vpnIp)
if hostInfo == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", IntIp(vpnIp)).
Debugln("dropping SendMessageToAll, vpnIp not in our CIDR or in unsafe routes")
}
return
}
if hostInfo.ConnectionState.ready == false {
// Because we might be sending stored packets, lock here to stop new things going to
// the packet queue.
hostInfo.ConnectionState.queueLock.Lock()
if !hostInfo.ConnectionState.ready {
hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToAll)
hostInfo.ConnectionState.queueLock.Unlock()
return
}
hostInfo.ConnectionState.queueLock.Unlock()
}
f.sendMessageToAll(t, st, hostInfo, p, nb, out)
return
}
func (f *Interface) sendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, b []byte) {
hostInfo.RLock()
for _, r := range hostInfo.Remotes {
f.send(t, st, hostInfo.ConnectionState, hostInfo, r, p, nb, b)
}
hostInfo.RUnlock()
}
func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) { func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) {
f.messageMetrics.Tx(t, st, 1) f.messageMetrics.Tx(t, st, 1)
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
} }
func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte, q int) uint64 { func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte, q int) {
if ci.eKey == nil { if ci.eKey == nil {
//TODO: log warning //TODO: log warning
return 0 return
} }
var err error var err error
@ -262,7 +219,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
WithField("udpAddr", remote).WithField("counter", c). WithField("udpAddr", remote).WithField("counter", c).
WithField("attemptedCounter", c). WithField("attemptedCounter", c).
Error("Failed to encrypt outgoing packet") Error("Failed to encrypt outgoing packet")
return c return
} }
err = f.writers[q].WriteTo(out, remote) err = f.writers[q].WriteTo(out, remote)
@ -270,7 +227,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
hostinfo.logger(f.l).WithError(err). hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet") WithField("udpAddr", remote).Error("Failed to write outgoing packet")
} }
return c return
} }
func isMulticast(ip uint32) bool { func isMulticast(ip uint32) bool {

View File

@ -13,26 +13,11 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
//TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake?
//TODO: nodes are roaming lighthouses, this is bad. How are they learning? //TODO: nodes are roaming lighthouses, this is bad. How are they learning?
var ErrHostNotKnown = errors.New("host not known") var ErrHostNotKnown = errors.New("host not known")
// The maximum number of ip addresses to store for a given vpnIp per address family
const maxAddrs = 10
type ip4And6 struct {
//TODO: adding a lock here could allow us to release the lock on lh.addrMap quicker
// v4 and v6 store addresses that have been self reported by the client in a server or where all addresses are stored on a client
v4 []*Ip4AndPort
v6 []*Ip6AndPort
// Learned addresses are ones that a client does not know about but a lighthouse learned from as a result of the received packet
// This is only used if you are a lighthouse server
learnedV4 []*Ip4AndPort
learnedV6 []*Ip6AndPort
}
type LightHouse struct { type LightHouse struct {
//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time //TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
sync.RWMutex //Because we concurrently read and write to our maps sync.RWMutex //Because we concurrently read and write to our maps
@ -42,7 +27,8 @@ type LightHouse struct {
punchConn *udpConn punchConn *udpConn
// Local cache of answers from light houses // Local cache of answers from light houses
addrMap map[uint32]*ip4And6 // map of vpn Ip to answers
addrMap map[uint32]*RemoteList
// filters remote addresses allowed for each host // filters remote addresses allowed for each host
// - When we are a lighthouse, this filters what addresses we store and // - When we are a lighthouse, this filters what addresses we store and
@ -81,7 +67,7 @@ func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, i
amLighthouse: amLighthouse, amLighthouse: amLighthouse,
myVpnIp: ip2int(myVpnIpNet.IP), myVpnIp: ip2int(myVpnIpNet.IP),
myVpnZeros: uint32(32 - ones), myVpnZeros: uint32(32 - ones),
addrMap: make(map[uint32]*ip4And6), addrMap: make(map[uint32]*RemoteList),
nebulaPort: nebulaPort, nebulaPort: nebulaPort,
lighthouses: make(map[uint32]struct{}), lighthouses: make(map[uint32]struct{}),
staticList: make(map[uint32]struct{}), staticList: make(map[uint32]struct{}),
@ -130,23 +116,29 @@ func (lh *LightHouse) ValidateLHStaticEntries() error {
return nil return nil
} }
func (lh *LightHouse) Query(ip uint32, f EncWriter) ([]*udpAddr, error) { func (lh *LightHouse) Query(ip uint32, f EncWriter) *RemoteList {
//TODO: we need to hold the lock through the next func
if !lh.IsLighthouseIP(ip) { if !lh.IsLighthouseIP(ip) {
lh.QueryServer(ip, f) lh.QueryServer(ip, f)
} }
lh.RLock() lh.RLock()
if v, ok := lh.addrMap[ip]; ok { if v, ok := lh.addrMap[ip]; ok {
lh.RUnlock() lh.RUnlock()
return TransformLHReplyToUdpAddrs(v), nil return v
} }
lh.RUnlock() lh.RUnlock()
return nil, ErrHostNotKnown return nil
} }
// This is asynchronous so no reply should be expected // This is asynchronous so no reply should be expected
func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) { func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
if !lh.amLighthouse { if lh.amLighthouse {
return
}
if lh.IsLighthouseIP(ip) {
return
}
// Send a query to the lighthouses and hope for the best next time // Send a query to the lighthouses and hope for the best next time
query, err := proto.Marshal(NewLhQueryByInt(ip)) query, err := proto.Marshal(NewLhQueryByInt(ip))
if err != nil { if err != nil {
@ -160,28 +152,44 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
for n := range lh.lighthouses { for n := range lh.lighthouses {
f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out) f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out)
} }
}
} }
func (lh *LightHouse) QueryCache(ip uint32) []*udpAddr { func (lh *LightHouse) QueryCache(ip uint32) *RemoteList {
//TODO: we need to hold the lock through the next func
lh.RLock() lh.RLock()
if v, ok := lh.addrMap[ip]; ok { if v, ok := lh.addrMap[ip]; ok {
lh.RUnlock() lh.RUnlock()
return TransformLHReplyToUdpAddrs(v) return v
} }
lh.RUnlock() lh.RUnlock()
return nil
lh.Lock()
defer lh.Unlock()
// Add an entry if we don't already have one
return lh.unlockedGetRemoteList(ip)
} }
// // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
func (lh *LightHouse) queryAndPrepMessage(ip uint32, f func(*ip4And6) (int, error)) (bool, int, error) { // details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp
// If one is found then f() is called with proper locking, f() must return result of n.MarshalTo()
func (lh *LightHouse) queryAndPrepMessage(vpnIp uint32, f func(*cache) (int, error)) (bool, int, error) {
lh.RLock() lh.RLock()
if v, ok := lh.addrMap[ip]; ok { // Do we have an entry in the main cache?
n, err := f(v) if v, ok := lh.addrMap[vpnIp]; ok {
// Swap lh lock for remote list lock
v.RLock()
defer v.RUnlock()
lh.RUnlock() lh.RUnlock()
// vpnIp should also be the owner here since we are a lighthouse.
c := v.cache[vpnIp]
// Make sure we have
if c != nil {
n, err := f(c)
return true, n, err return true, n, err
} }
return false, 0, nil
}
lh.RUnlock() lh.RUnlock()
return false, 0, nil return false, 0, nil
} }
@ -203,70 +211,47 @@ func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
lh.Unlock() lh.Unlock()
} }
// AddRemote is correct way for non LightHouse members to add an address. toAddr will be placed in the learned map // AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner
// static means this is a static host entry from the config file, it should only be used on start up // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
func (lh *LightHouse) AddRemote(vpnIP uint32, toAddr *udpAddr, static bool) { // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) {
lh.Lock()
am := lh.unlockedGetRemoteList(vpnIp)
am.Lock()
defer am.Unlock()
lh.Unlock()
if ipv4 := toAddr.IP.To4(); ipv4 != nil { if ipv4 := toAddr.IP.To4(); ipv4 != nil {
lh.addRemoteV4(vpnIP, NewIp4AndPort(ipv4, uint32(toAddr.Port)), static, true) to := NewIp4AndPort(ipv4, uint32(toAddr.Port))
if !lh.unlockedShouldAddV4(to) {
return
}
am.unlockedPrependV4(lh.myVpnIp, to)
} else { } else {
lh.addRemoteV6(vpnIP, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)), static, true) to := NewIp6AndPort(toAddr.IP, uint32(toAddr.Port))
if !lh.unlockedShouldAddV6(to) {
return
}
am.unlockedPrependV6(lh.myVpnIp, to)
} }
//TODO: if we do not add due to a config filter we may end up not having any addresses here // Mark it as static
if static { lh.staticList[vpnIp] = struct{}{}
lh.staticList[vpnIP] = struct{}{}
}
} }
// unlockedGetAddrs assumes you have the lh lock // unlockedGetRemoteList assumes you have the lh lock
func (lh *LightHouse) unlockedGetAddrs(vpnIP uint32) *ip4And6 { func (lh *LightHouse) unlockedGetRemoteList(vpnIP uint32) *RemoteList {
am, ok := lh.addrMap[vpnIP] am, ok := lh.addrMap[vpnIP]
if !ok { if !ok {
am = &ip4And6{} am = NewRemoteList()
lh.addrMap[vpnIP] = am lh.addrMap[vpnIP] = am
} }
return am return am
} }
// addRemoteV4 is a lighthouse internal method that prepends a remote if it is allowed by the allow list and not duplicated // unlockedShouldAddV4 checks if to is allowed by our allow list
func (lh *LightHouse) addRemoteV4(vpnIP uint32, to *Ip4AndPort, static bool, learned bool) { func (lh *LightHouse) unlockedShouldAddV4(to *Ip4AndPort) bool {
// First we check if the sender thinks this is a static entry
// and do nothing if it is not, but should be considered static
if static == false {
if _, ok := lh.staticList[vpnIP]; ok {
return
}
}
lh.Lock()
defer lh.Unlock()
am := lh.unlockedGetAddrs(vpnIP)
if learned {
if !lh.unlockedShouldAddV4(am.learnedV4, to) {
return
}
am.learnedV4 = prependAndLimitV4(am.learnedV4, to)
} else {
if !lh.unlockedShouldAddV4(am.v4, to) {
return
}
am.v4 = prependAndLimitV4(am.v4, to)
}
}
func prependAndLimitV4(cache []*Ip4AndPort, to *Ip4AndPort) []*Ip4AndPort {
cache = append(cache, nil)
copy(cache[1:], cache)
cache[0] = to
if len(cache) > MaxRemotes {
cache = cache[:maxAddrs]
}
return cache
}
// unlockedShouldAddV4 checks if to is allowed by our allow list and is not already present in the cache
func (lh *LightHouse) unlockedShouldAddV4(am []*Ip4AndPort, to *Ip4AndPort) bool {
allow := lh.remoteAllowList.AllowIpV4(to.Ip) allow := lh.remoteAllowList.AllowIpV4(to.Ip)
if lh.l.Level >= logrus.TraceLevel { if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", IntIp(to.Ip)).WithField("allow", allow).Trace("remoteAllowList.Allow") lh.l.WithField("remoteIp", IntIp(to.Ip)).WithField("allow", allow).Trace("remoteAllowList.Allow")
@ -276,69 +261,21 @@ func (lh *LightHouse) unlockedShouldAddV4(am []*Ip4AndPort, to *Ip4AndPort) bool
return false return false
} }
for _, v := range am {
if v.Ip == to.Ip && v.Port == to.Port {
return false
}
}
return true return true
} }
// addRemoteV6 is a lighthouse internal method that prepends a remote if it is allowed by the allow list and not duplicated // unlockedShouldAddV6 checks if to is allowed by our allow list
func (lh *LightHouse) addRemoteV6(vpnIP uint32, to *Ip6AndPort, static bool, learned bool) { func (lh *LightHouse) unlockedShouldAddV6(to *Ip6AndPort) bool {
// First we check if the sender thinks this is a static entry
// and do nothing if it is not, but should be considered static
if static == false {
if _, ok := lh.staticList[vpnIP]; ok {
return
}
}
lh.Lock()
defer lh.Unlock()
am := lh.unlockedGetAddrs(vpnIP)
if learned {
if !lh.unlockedShouldAddV6(am.learnedV6, to) {
return
}
am.learnedV6 = prependAndLimitV6(am.learnedV6, to)
} else {
if !lh.unlockedShouldAddV6(am.v6, to) {
return
}
am.v6 = prependAndLimitV6(am.v6, to)
}
}
func prependAndLimitV6(cache []*Ip6AndPort, to *Ip6AndPort) []*Ip6AndPort {
cache = append(cache, nil)
copy(cache[1:], cache)
cache[0] = to
if len(cache) > MaxRemotes {
cache = cache[:maxAddrs]
}
return cache
}
// unlockedShouldAddV6 checks if to is allowed by our allow list and is not already present in the cache
func (lh *LightHouse) unlockedShouldAddV6(am []*Ip6AndPort, to *Ip6AndPort) bool {
allow := lh.remoteAllowList.AllowIpV6(to.Hi, to.Lo) allow := lh.remoteAllowList.AllowIpV6(to.Hi, to.Lo)
if lh.l.Level >= logrus.TraceLevel { if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow") lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
} }
// We don't check our vpn network here because nebula does not support ipv6 on the inside
if !allow { if !allow {
return false return false
} }
for _, v := range am {
if v.Hi == to.Hi && v.Lo == to.Lo && v.Port == to.Port {
return false
}
}
return true return true
} }
@ -349,13 +286,6 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP {
return ip return ip
} }
func (lh *LightHouse) AddRemoteAndReset(vpnIP uint32, toIp *udpAddr) {
if lh.amLighthouse {
lh.DeleteVpnIP(vpnIP)
lh.AddRemote(vpnIP, toIp, false)
}
}
func (lh *LightHouse) IsLighthouseIP(vpnIP uint32) bool { func (lh *LightHouse) IsLighthouseIP(vpnIP uint32) bool {
if _, ok := lh.lighthouses[vpnIP]; ok { if _, ok := lh.lighthouses[vpnIP]; ok {
return true return true
@ -496,7 +426,6 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
return lhh.meta return lhh.meta
} }
//TODO: do we need c here?
func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, w EncWriter) { func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, w EncWriter) {
n := lhh.resetMeta() n := lhh.resetMeta()
err := n.Unmarshal(p) err := n.Unmarshal(p)
@ -544,13 +473,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
//TODO: we can DRY this further //TODO: we can DRY this further
reqVpnIP := n.Details.VpnIp reqVpnIP := n.Details.VpnIp
//TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data //TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data
//TODO: If we use a lock on cache we can avoid holding it on lh.addrMap and keep things moving better found, ln, err := lhh.lh.queryAndPrepMessage(n.Details.VpnIp, func(c *cache) (int, error) {
found, ln, err := lhh.lh.queryAndPrepMessage(n.Details.VpnIp, func(cache *ip4And6) (int, error) {
n = lhh.resetMeta() n = lhh.resetMeta()
n.Type = NebulaMeta_HostQueryReply n.Type = NebulaMeta_HostQueryReply
n.Details.VpnIp = reqVpnIP n.Details.VpnIp = reqVpnIP
lhh.coalesceAnswers(cache, n) lhh.coalesceAnswers(c, n)
return n.MarshalTo(lhh.pb) return n.MarshalTo(lhh.pb)
}) })
@ -568,12 +496,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
w.SendMessageToVpnIp(lightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) w.SendMessageToVpnIp(lightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
// This signals the other side to punch some zero byte udp packets // This signals the other side to punch some zero byte udp packets
found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(cache *ip4And6) (int, error) { found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) {
n = lhh.resetMeta() n = lhh.resetMeta()
n.Type = NebulaMeta_HostPunchNotification n.Type = NebulaMeta_HostPunchNotification
n.Details.VpnIp = vpnIp n.Details.VpnIp = vpnIp
lhh.coalesceAnswers(cache, n) lhh.coalesceAnswers(c, n)
return n.MarshalTo(lhh.pb) return n.MarshalTo(lhh.pb)
}) })
@ -591,12 +519,24 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
w.SendMessageToVpnIp(lightHouse, 0, reqVpnIP, lhh.pb[:ln], lhh.nb, lhh.out[:0]) w.SendMessageToVpnIp(lightHouse, 0, reqVpnIP, lhh.pb[:ln], lhh.nb, lhh.out[:0])
} }
func (lhh *LightHouseHandler) coalesceAnswers(cache *ip4And6, n *NebulaMeta) { func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, cache.v4...) if c.v4 != nil {
n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, cache.learnedV4...) if c.v4.learned != nil {
n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.learned)
}
if c.v4.reported != nil && len(c.v4.reported) > 0 {
n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.reported...)
}
}
n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, cache.v6...) if c.v6 != nil {
n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, cache.learnedV6...) if c.v6.learned != nil {
n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.learned)
}
if c.v6.reported != nil && len(c.v6.reported) > 0 {
n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.reported...)
}
}
} }
func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32) { func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32) {
@ -604,14 +544,14 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32)
return return
} }
// We can't just slam the responses in as they may come from multiple lighthouses and we should coalesce the answers lhh.lh.Lock()
for _, to := range n.Details.Ip4AndPorts { am := lhh.lh.unlockedGetRemoteList(n.Details.VpnIp)
lhh.lh.addRemoteV4(n.Details.VpnIp, to, false, false) am.Lock()
} lhh.lh.Unlock()
for _, to := range n.Details.Ip6AndPorts { am.unlockedSetV4(vpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
lhh.lh.addRemoteV6(n.Details.VpnIp, to, false, false) am.unlockedSetV6(vpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
} am.Unlock()
// Non-blocking attempt to trigger, skip if it would block // Non-blocking attempt to trigger, skip if it would block
select { select {
@ -637,35 +577,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
} }
lhh.lh.Lock() lhh.lh.Lock()
defer lhh.lh.Unlock() am := lhh.lh.unlockedGetRemoteList(vpnIp)
am := lhh.lh.unlockedGetAddrs(vpnIp) am.Lock()
lhh.lh.Unlock()
//TODO: other note on a lock for am so we can release more quickly and lock our real unit of change which is far less contended am.unlockedSetV4(vpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(vpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
// We don't accumulate addresses being told to us am.Unlock()
am.v4 = am.v4[:0]
am.v6 = am.v6[:0]
for _, v := range n.Details.Ip4AndPorts {
if lhh.lh.unlockedShouldAddV4(am.v4, v) {
am.v4 = append(am.v4, v)
}
}
for _, v := range n.Details.Ip6AndPorts {
if lhh.lh.unlockedShouldAddV6(am.v6, v) {
am.v6 = append(am.v6, v)
}
}
// We prefer the first n addresses if we got too big
if len(am.v4) > MaxRemotes {
am.v4 = am.v4[:MaxRemotes]
}
if len(am.v6) > MaxRemotes {
am.v6 = am.v6[:MaxRemotes]
}
} }
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp uint32, w EncWriter) { func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp uint32, w EncWriter) {
@ -716,33 +634,6 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp u
} }
} }
func TransformLHReplyToUdpAddrs(ips *ip4And6) []*udpAddr {
addrs := make([]*udpAddr, len(ips.v4)+len(ips.v6)+len(ips.learnedV4)+len(ips.learnedV6))
i := 0
for _, v := range ips.learnedV4 {
addrs[i] = NewUDPAddrFromLH4(v)
i++
}
for _, v := range ips.v4 {
addrs[i] = NewUDPAddrFromLH4(v)
i++
}
for _, v := range ips.learnedV6 {
addrs[i] = NewUDPAddrFromLH6(v)
i++
}
for _, v := range ips.v6 {
addrs[i] = NewUDPAddrFromLH6(v)
i++
}
return addrs
}
// ipMaskContains checks if testIp is contained by ip after applying a cidr // ipMaskContains checks if testIp is contained by ip after applying a cidr
// zeros is 32 - bits from net.IPMask.Size() // zeros is 32 - bits from net.IPMask.Size()
func ipMaskContains(ip uint32, zeros uint32, testIp uint32) bool { func ipMaskContains(ip uint32, zeros uint32, testIp uint32) bool {

View File

@ -48,16 +48,16 @@ func Test_lhStaticMapping(t *testing.T) {
udpServer, _ := NewListener(l, "0.0.0.0", 0, true) udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true) meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)))
err := meh.ValidateLHStaticEntries() err := meh.ValidateLHStaticEntries()
assert.Nil(t, err) assert.Nil(t, err)
lh2 := "10.128.0.3" lh2 := "10.128.0.3"
lh2IP := net.ParseIP(lh2) lh2IP := net.ParseIP(lh2)
meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false) meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true) meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)))
err = meh.ValidateLHStaticEntries() err = meh.ValidateLHStaticEntries()
assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry") assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
} }
@ -73,17 +73,27 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
hAddr := NewUDPAddrFromString("4.5.6.7:12345") hAddr := NewUDPAddrFromString("4.5.6.7:12345")
hAddr2 := NewUDPAddrFromString("4.5.6.7:12346") hAddr2 := NewUDPAddrFromString("4.5.6.7:12346")
lh.addrMap[3] = &ip4And6{v4: []*Ip4AndPort{ lh.addrMap[3] = NewRemoteList()
lh.addrMap[3].unlockedSetV4(
3,
[]*Ip4AndPort{
NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)), NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port))}, NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)),
} },
func(*Ip4AndPort) bool { return true },
)
rAddr := NewUDPAddrFromString("1.2.2.3:12345") rAddr := NewUDPAddrFromString("1.2.2.3:12345")
rAddr2 := NewUDPAddrFromString("1.2.2.3:12346") rAddr2 := NewUDPAddrFromString("1.2.2.3:12346")
lh.addrMap[2] = &ip4And6{v4: []*Ip4AndPort{ lh.addrMap[2] = NewRemoteList()
lh.addrMap[2].unlockedSetV4(
3,
[]*Ip4AndPort{
NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)), NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port))}, NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)),
} },
func(*Ip4AndPort) bool { return true },
)
mw := &mockEncWriter{} mw := &mockEncWriter{}
@ -173,7 +183,7 @@ func TestLighthouse_Memory(t *testing.T) {
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
// Ensure proper ordering and limiting // Ensure proper ordering and limiting
// Send 12 addrs, get 10 back, one removed on a dupe check the other by limiting // Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
newLHHostUpdate( newLHHostUpdate(
myUdpAddr0, myUdpAddr0,
myVpnIp, myVpnIp,
@ -191,11 +201,12 @@ func TestLighthouse_Memory(t *testing.T) {
myUdpAddr10, myUdpAddr10,
myUdpAddr11, // This should get cut myUdpAddr11, // This should get cut
}, lhh) }, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray( assertIp4InArray(
t, t,
r.msg.Details.Ip4AndPorts, r.msg.Details.Ip4AndPorts,
myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9, myUdpAddr10, myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
) )
// Make sure we won't add ips in our vpn network // Make sure we won't add ips in our vpn network
@ -247,71 +258,71 @@ func newLHHostUpdate(fromAddr *udpAddr, vpnIp uint32, addrs []*udpAddr, lhh *Lig
lhh.HandleRequest(fromAddr, vpnIp, b, w) lhh.HandleRequest(fromAddr, vpnIp, b, w)
} }
func Test_lhRemoteAllowList(t *testing.T) { //TODO: this is a RemoteList test
l := NewTestLogger() //func Test_lhRemoteAllowList(t *testing.T) {
c := NewConfig(l) // l := NewTestLogger()
c.Settings["remoteallowlist"] = map[interface{}]interface{}{ // c := NewConfig(l)
"10.20.0.0/12": false, // c.Settings["remoteallowlist"] = map[interface{}]interface{}{
} // "10.20.0.0/12": false,
allowList, err := c.GetAllowList("remoteallowlist", false) // }
assert.Nil(t, err) // allowList, err := c.GetAllowList("remoteallowlist", false)
// assert.Nil(t, err)
lh1 := "10.128.0.2" //
lh1IP := net.ParseIP(lh1) // lh1 := "10.128.0.2"
// lh1IP := net.ParseIP(lh1)
udpServer, _ := NewListener(l, "0.0.0.0", 0, true) //
// udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) //
lh.SetRemoteAllowList(allowList) // lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
// lh.SetRemoteAllowList(allowList)
// A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap //
remote1IP := net.ParseIP("10.20.0.3") // // A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap
lh.AddRemote(ip2int(remote1IP), NewUDPAddr(remote1IP, uint16(4242)), true) // remote1IP := net.ParseIP("10.20.0.3")
assert.NotNil(t, lh.addrMap[ip2int(remote1IP)]) // remotes := lh.unlockedGetRemoteList(ip2int(remote1IP))
assert.Empty(t, lh.addrMap[ip2int(remote1IP)].v4) // remotes.unlockedPrependV4(ip2int(remote1IP), NewIp4AndPort(remote1IP, 4242))
assert.Empty(t, lh.addrMap[ip2int(remote1IP)].v6) // assert.NotNil(t, lh.addrMap[ip2int(remote1IP)])
// assert.Empty(t, lh.addrMap[ip2int(remote1IP)].CopyAddrs([]*net.IPNet{}))
// Make sure a good ip enters the cache and addrMap //
remote2IP := net.ParseIP("10.128.0.3") // // Make sure a good ip enters the cache and addrMap
remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242)) // remote2IP := net.ParseIP("10.128.0.3")
lh.AddRemote(ip2int(remote2IP), remote2UDPAddr, true) // remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242))
assertIp4InArray(t, lh.addrMap[ip2int(remote2IP)].learnedV4, remote2UDPAddr) // lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote2UDPAddr.IP, uint32(remote2UDPAddr.Port)), false, false)
// assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr)
// Another good ip gets into the cache, ordering is inverted //
remote3IP := net.ParseIP("10.128.0.4") // // Another good ip gets into the cache, ordering is inverted
remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243)) // remote3IP := net.ParseIP("10.128.0.4")
lh.AddRemote(ip2int(remote2IP), remote3UDPAddr, true) // remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243))
assertIp4InArray(t, lh.addrMap[ip2int(remote2IP)].learnedV4, remote3UDPAddr, remote2UDPAddr) // lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote3UDPAddr.IP, uint32(remote3UDPAddr.Port)), false, false)
// assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr, remote3UDPAddr)
// If we exceed the length limit we should only have the most recent addresses //
addedAddrs := []*udpAddr{} // // If we exceed the length limit we should only have the most recent addresses
for i := 0; i < 11; i++ { // addedAddrs := []*udpAddr{}
remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i)) // for i := 0; i < 11; i++ {
lh.AddRemote(ip2int(remote2IP), remoteUDPAddr, true) // remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i))
// The first entry here is a duplicate, don't add it to the assert list // lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remoteUDPAddr.IP, uint32(remoteUDPAddr.Port)), false, false)
if i != 0 { // // The first entry here is a duplicate, don't add it to the assert list
addedAddrs = append(addedAddrs, remoteUDPAddr) // if i != 0 {
} // addedAddrs = append(addedAddrs, remoteUDPAddr)
} // }
// }
// We should only have the last 10 of what we tried to add //
assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses") // // We should only have the last 10 of what we tried to add
ln := len(addedAddrs) // assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses")
assertIp4InArray( // assertUdpAddrInArray(
t, // t,
lh.addrMap[ip2int(remote2IP)].learnedV4, // lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}),
addedAddrs[ln-1], // addedAddrs[0],
addedAddrs[ln-2], // addedAddrs[1],
addedAddrs[ln-3], // addedAddrs[2],
addedAddrs[ln-4], // addedAddrs[3],
addedAddrs[ln-5], // addedAddrs[4],
addedAddrs[ln-6], // addedAddrs[5],
addedAddrs[ln-7], // addedAddrs[6],
addedAddrs[ln-8], // addedAddrs[7],
addedAddrs[ln-9], // addedAddrs[8],
addedAddrs[ln-10], // addedAddrs[9],
) // )
} //}
func Test_ipMaskContains(t *testing.T) { func Test_ipMaskContains(t *testing.T) {
assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.0.255")))) assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.0.255"))))
@ -354,6 +365,16 @@ func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udpAddr) {
} }
} }
// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
func assertUdpAddrInArray(t *testing.T, have []*udpAddr, want ...*udpAddr) {
assert.Len(t, have, len(want))
for k, w := range want {
if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have))
}
}
}
func translateV4toUdpAddr(ips []*Ip4AndPort) []*udpAddr { func translateV4toUdpAddr(ips []*Ip4AndPort) []*udpAddr {
addrs := make([]*udpAddr, len(ips)) addrs := make([]*udpAddr, len(ips))
for k, v := range ips { for k, v := range ips {

View File

@ -221,7 +221,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
} }
hostMap := NewHostMap(l, "main", tunCidr, preferredRanges) hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
hostMap.addUnsafeRoutes(&unsafeRoutes) hostMap.addUnsafeRoutes(&unsafeRoutes)
hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false) hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
@ -302,14 +302,14 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
if err != nil { if err != nil {
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
} }
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip, port), true) lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
} }
} else { } else {
ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v)) ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v))
if err != nil { if err != nil {
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
} }
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip, port), true) lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
} }
} }
@ -328,7 +328,6 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
handshakeConfig := HandshakeConfig{ handshakeConfig := HandshakeConfig{
tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries), retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries),
waitRotation: config.GetInt("handshakes.wait_rotation", DefaultHandshakeWaitRotation),
triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
messageMetrics: messageMetrics, messageMetrics: messageMetrics,

View File

@ -132,6 +132,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
f.connectionManager.In(hostinfo.hostId) f.connectionManager.In(hostinfo.hostId)
} }
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
func (f *Interface) closeTunnel(hostInfo *HostInfo) { func (f *Interface) closeTunnel(hostInfo *HostInfo) {
//TODO: this would be better as a single function in ConnectionManager that handled locks appropriately //TODO: this would be better as a single function in ConnectionManager that handled locks appropriately
f.connectionManager.ClearIP(hostInfo.hostId) f.connectionManager.ClearIP(hostInfo.hostId)
@ -140,6 +141,11 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) {
f.hostMap.DeleteHostInfo(hostInfo) f.hostMap.DeleteHostInfo(hostInfo)
} }
// sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote
func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
}
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
if hostDidRoam(hostinfo.remote, addr) { if hostDidRoam(hostinfo.remote, addr) {
if !f.lightHouse.remoteAllowList.Allow(addr.IP) { if !f.lightHouse.remoteAllowList.Allow(addr.IP) {
@ -160,9 +166,6 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
remoteCopy := *hostinfo.remote remoteCopy := *hostinfo.remote
hostinfo.lastRoamRemote = &remoteCopy hostinfo.lastRoamRemote = &remoteCopy
hostinfo.SetRemote(addr) hostinfo.SetRemote(addr)
if f.lightHouse.amLighthouse {
f.lightHouse.AddRemote(hostinfo.hostId, addr, false)
}
} }
} }

500
remote_list.go Normal file
View File

@ -0,0 +1,500 @@
package nebula
import (
"bytes"
"net"
"sort"
"sync"
)
// forEachFunc is used to benefit folks that want to do work inside the lock
type forEachFunc func(addr *udpAddr, preferred bool)
// The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
type checkFuncV4 func(to *Ip4AndPort) bool
type checkFuncV6 func(to *Ip6AndPort) bool
// CacheMap is a struct that better represents the lighthouse cache for humans
// The string key is the owners vpnIp
type CacheMap map[string]*Cache
// Cache is the other part of CacheMap to better represent the lighthouse cache for humans
// We don't reason about ipv4 vs ipv6 here
type Cache struct {
Learned []*udpAddr `json:"learned,omitempty"`
Reported []*udpAddr `json:"reported,omitempty"`
}
//TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion
// We will never clean learned/reported information for them as it stands today
// cache is an internal struct that splits v4 and v6 addresses inside the cache map
type cache struct {
v4 *cacheV4
v6 *cacheV6
}
// cacheV4 stores learned and reported ipv4 records under cache
type cacheV4 struct {
learned *Ip4AndPort
reported []*Ip4AndPort
}
// cacheV4 stores learned and reported ipv6 records under cache
type cacheV6 struct {
learned *Ip6AndPort
reported []*Ip6AndPort
}
// RemoteList is a unifying concept for lighthouse servers and clients as well as hostinfos.
// It serves as a local cache of query replies, host update notifications, and locally learned addresses
type RemoteList struct {
// Every interaction with internals requires a lock!
sync.RWMutex
// A deduplicated set of addresses. Any accessor should lock beforehand.
addrs []*udpAddr
// These are maps to store v4 and v6 addresses per lighthouse
// Map key is the vpnIp of the person that told us about this the cached entries underneath.
// For learned addresses, this is the vpnIp that sent the packet
cache map[uint32]*cache
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake
badRemotes []*udpAddr
// A flag that the cache may have changed and addrs needs to be rebuilt
shouldRebuild bool
}
// NewRemoteList creates a new empty RemoteList
func NewRemoteList() *RemoteList {
return &RemoteList{
addrs: make([]*udpAddr, 0),
cache: make(map[uint32]*cache),
}
}
// Len locks and reports the size of the deduplicated address list
// The deduplication work may need to occur here, so you must pass preferredRanges
func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
r.Rebuild(preferredRanges)
r.RLock()
defer r.RUnlock()
return len(r.addrs)
}
// ForEach locks and will call the forEachFunc for every deduplicated address in the list
// The deduplication work may need to occur here, so you must pass preferredRanges
func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc) {
r.Rebuild(preferredRanges)
r.RLock()
for _, v := range r.addrs {
forEach(v, isPreferred(v.IP, preferredRanges))
}
r.RUnlock()
}
// CopyAddrs locks and makes a deep copy of the deduplicated address list
// The deduplication work may need to occur here, so you must pass preferredRanges
func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr {
r.Rebuild(preferredRanges)
r.RLock()
defer r.RUnlock()
c := make([]*udpAddr, len(r.addrs))
for i, v := range r.addrs {
c[i] = v.Copy()
}
return c
}
// LearnRemote locks and sets the learned slot for the owner vpn ip to the provided addr
// Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
// It will mark the deduplicated address list as dirty, so do not call it unless new information is available
//TODO: this needs to support the allow list list
func (r *RemoteList) LearnRemote(ownerVpnIp uint32, addr *udpAddr) {
r.Lock()
defer r.Unlock()
if v4 := addr.IP.To4(); v4 != nil {
r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPort(v4, uint32(addr.Port)))
} else {
r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPort(addr.IP, uint32(addr.Port)))
}
}
// CopyCache locks and creates a more human friendly form of the internal address cache.
// This may contain duplicates and blocked addresses
func (r *RemoteList) CopyCache() *CacheMap {
r.RLock()
defer r.RUnlock()
cm := make(CacheMap)
getOrMake := func(vpnIp string) *Cache {
c := cm[vpnIp]
if c == nil {
c = &Cache{
Learned: make([]*udpAddr, 0),
Reported: make([]*udpAddr, 0),
}
cm[vpnIp] = c
}
return c
}
for owner, mc := range r.cache {
c := getOrMake(IntIp(owner).String())
if mc.v4 != nil {
if mc.v4.learned != nil {
c.Learned = append(c.Learned, NewUDPAddrFromLH4(mc.v4.learned))
}
for _, a := range mc.v4.reported {
c.Reported = append(c.Reported, NewUDPAddrFromLH4(a))
}
}
if mc.v6 != nil {
if mc.v6.learned != nil {
c.Learned = append(c.Learned, NewUDPAddrFromLH6(mc.v6.learned))
}
for _, a := range mc.v6.reported {
c.Reported = append(c.Reported, NewUDPAddrFromLH6(a))
}
}
}
return &cm
}
// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
func (r *RemoteList) BlockRemote(bad *udpAddr) {
r.Lock()
defer r.Unlock()
// Check if we already blocked this addr
if r.unlockedIsBad(bad) {
return
}
// We copy here because we are taking something else's memory and we can't trust everything
r.badRemotes = append(r.badRemotes, bad.Copy())
// Mark the next interaction must recollect/dedupe
r.shouldRebuild = true
}
// CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list
func (r *RemoteList) CopyBlockedRemotes() []*udpAddr {
r.RLock()
defer r.RUnlock()
c := make([]*udpAddr, len(r.badRemotes))
for i, v := range r.badRemotes {
c[i] = v.Copy()
}
return c
}
// ResetBlockedRemotes locks and clears the blocked remotes list
func (r *RemoteList) ResetBlockedRemotes() {
r.Lock()
r.badRemotes = nil
r.Unlock()
}
// Rebuild locks and generates the deduplicated address list only if there is work to be done
// There is generally no reason to call this directly but it is safe to do so
func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
r.Lock()
defer r.Unlock()
// Only rebuild if the cache changed
//TODO: shouldRebuild is probably pointless as we don't check for actual change when lighthouse updates come in
if r.shouldRebuild {
r.unlockedCollect()
r.shouldRebuild = false
}
// Always re-sort, preferredRanges can change via HUP
r.unlockedSort(preferredRanges)
}
// unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
func (r *RemoteList) unlockedIsBad(remote *udpAddr) bool {
for _, v := range r.badRemotes {
if v.Equals(remote) {
return true
}
}
return false
}
// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty
func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp uint32, to *Ip4AndPort) {
r.shouldRebuild = true
r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
}
// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty
func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, to []*Ip4AndPort, check checkFuncV4) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp)
// Reset the slice
c.reported = c.reported[:0]
// We can't take their array but we can take their pointers
for _, v := range to[:minInt(len(to), MaxRemotes)] {
if check(v) {
c.reported = append(c.reported, v)
}
}
}
// unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts
func (r *RemoteList) unlockedPrependV4(ownerVpnIp uint32, to *Ip4AndPort) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp)
// We are doing the easy append because this is rarely called
c.reported = append([]*Ip4AndPort{to}, c.reported...)
if len(c.reported) > MaxRemotes {
c.reported = c.reported[:MaxRemotes]
}
}
// unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty
func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp uint32, to *Ip6AndPort) {
r.shouldRebuild = true
r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
}
// unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty
func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, to []*Ip6AndPort, check checkFuncV6) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp)
// Reset the slice
c.reported = c.reported[:0]
// We can't take their array but we can take their pointers
for _, v := range to[:minInt(len(to), MaxRemotes)] {
if check(v) {
c.reported = append(c.reported, v)
}
}
}
// unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts
func (r *RemoteList) unlockedPrependV6(ownerVpnIp uint32, to *Ip6AndPort) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp)
// We are doing the easy append because this is rarely called
c.reported = append([]*Ip6AndPort{to}, c.reported...)
if len(c.reported) > MaxRemotes {
c.reported = c.reported[:MaxRemotes]
}
}
// unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
// The caller must dirty the learned address cache if required
func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp uint32) *cacheV4 {
am := r.cache[ownerVpnIp]
if am == nil {
am = &cache{}
r.cache[ownerVpnIp] = am
}
// Avoid occupying memory for v6 addresses if we never have any
if am.v4 == nil {
am.v4 = &cacheV4{}
}
return am.v4
}
// unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
// The caller must dirty the learned address cache if required
func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp uint32) *cacheV6 {
am := r.cache[ownerVpnIp]
if am == nil {
am = &cache{}
r.cache[ownerVpnIp] = am
}
// Avoid occupying memory for v4 addresses if we never have any
if am.v6 == nil {
am.v6 = &cacheV6{}
}
return am.v6
}
// unlockedCollect assumes you have the write lock and collects/transforms the cache into the deduped address list.
// The result of this function can contain duplicates. unlockedSort handles cleaning it.
func (r *RemoteList) unlockedCollect() {
addrs := r.addrs[:0]
for _, c := range r.cache {
if c.v4 != nil {
if c.v4.learned != nil {
u := NewUDPAddrFromLH4(c.v4.learned)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
}
for _, v := range c.v4.reported {
u := NewUDPAddrFromLH4(v)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
}
}
if c.v6 != nil {
if c.v6.learned != nil {
u := NewUDPAddrFromLH6(c.v6.learned)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
}
for _, v := range c.v6.reported {
u := NewUDPAddrFromLH6(v)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
}
}
}
r.addrs = addrs
}
// unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list
func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
n := len(r.addrs)
if n < 2 {
return
}
lessFunc := func(i, j int) bool {
a := r.addrs[i]
b := r.addrs[j]
// Preferred addresses first
aPref := isPreferred(a.IP, preferredRanges)
bPref := isPreferred(b.IP, preferredRanges)
switch {
case aPref && !bPref:
// If i is preferred and j is not, i is less than j
return true
case !aPref && bPref:
// If j is preferred then i is not due to the else, i is not less than j
return false
default:
// Both i an j are either preferred or not, sort within that
}
// ipv6 addresses 2nd
a4 := a.IP.To4()
b4 := b.IP.To4()
switch {
case a4 == nil && b4 != nil:
// If i is v6 and j is v4, i is less than j
return true
case a4 != nil && b4 == nil:
// If j is v6 and i is v4, i is not less than j
return false
case a4 != nil && b4 != nil:
// Special case for ipv4, a4 and b4 are not nil
aPrivate := isPrivateIP(a4)
bPrivate := isPrivateIP(b4)
switch {
case !aPrivate && bPrivate:
// If i is a public ip (not private) and j is a private ip, i is less then j
return true
case aPrivate && !bPrivate:
// If j is public (not private) then i is private due to the else, i is not less than j
return false
default:
// Both i an j are either public or private, sort within that
}
default:
// Both i an j are either ipv4 or ipv6, sort within that
}
// lexical order of ips 3rd
c := bytes.Compare(a.IP, b.IP)
if c == 0 {
// Ips are the same, Lexical order of ports 4th
return a.Port < b.Port
}
// Ip wasn't the same
return c < 0
}
// Sort it
sort.Slice(r.addrs, lessFunc)
// Deduplicate
a, b := 0, 1
for b < n {
if !r.addrs[a].Equals(r.addrs[b]) {
a++
if a != b {
r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a]
}
}
b++
}
r.addrs = r.addrs[:a+1]
return
}
// minInt returns the minimum integer of a or b
func minInt(a, b int) int {
if a < b {
return a
}
return b
}
// isPreferred returns true of the ip is contained in the preferredRanges list
func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool {
//TODO: this would be better in a CIDR6Tree
for _, p := range preferredRanges {
if p.Contains(ip) {
return true
}
}
return false
}
var _, private24BitBlock, _ = net.ParseCIDR("10.0.0.0/8")
var _, private20BitBlock, _ = net.ParseCIDR("172.16.0.0/12")
var _, private16BitBlock, _ = net.ParseCIDR("192.168.0.0/16")
// isPrivateIP returns true if the ip is contained by a rfc 1918 private range
func isPrivateIP(ip net.IP) bool {
//TODO: another great cidrtree option
//TODO: Private for ipv6 or just let it ride?
return private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
}

228
remote_list_test.go Normal file
View File

@ -0,0 +1,228 @@
package nebula
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestRemoteList_Rebuild(t *testing.T) {
rl := NewRemoteList()
rl.unlockedSetV4(
0,
[]*Ip4AndPort{
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is duped
{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is duped
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is duped
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // almost dupe of 0 with a diff port
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is a dupe
},
func(*Ip4AndPort) bool { return true },
)
rl.unlockedSetV6(
1,
[]*Ip6AndPort{
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is duped
NewIp6AndPort(net.ParseIP("1::1"), 2), // almost dupe of 0 with a diff port, also gets duped
NewIp6AndPort(net.ParseIP("1:100::1"), 1),
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe
},
func(*Ip6AndPort) bool { return true },
)
rl.Rebuild([]*net.IPNet{})
assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
// ipv6 first, sorted lexically within
assert.Equal(t, "[1::1]:1", rl.addrs[0].String())
assert.Equal(t, "[1::1]:2", rl.addrs[1].String())
assert.Equal(t, "[1:100::1]:1", rl.addrs[2].String())
// ipv4 last, sorted by public first, then private, lexically within them
assert.Equal(t, "70.199.182.92:1475", rl.addrs[3].String())
assert.Equal(t, "70.199.182.92:1476", rl.addrs[4].String())
assert.Equal(t, "172.17.0.182:10101", rl.addrs[5].String())
assert.Equal(t, "172.17.1.1:10101", rl.addrs[6].String())
assert.Equal(t, "172.18.0.1:10101", rl.addrs[7].String())
assert.Equal(t, "172.19.0.1:10101", rl.addrs[8].String())
assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
// Now ensure we can hoist ipv4 up
_, ipNet, err := net.ParseCIDR("0.0.0.0/0")
assert.NoError(t, err)
rl.Rebuild([]*net.IPNet{ipNet})
assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
// ipv4 first, public then private, lexically within them
assert.Equal(t, "70.199.182.92:1475", rl.addrs[0].String())
assert.Equal(t, "70.199.182.92:1476", rl.addrs[1].String())
assert.Equal(t, "172.17.0.182:10101", rl.addrs[2].String())
assert.Equal(t, "172.17.1.1:10101", rl.addrs[3].String())
assert.Equal(t, "172.18.0.1:10101", rl.addrs[4].String())
assert.Equal(t, "172.19.0.1:10101", rl.addrs[5].String())
assert.Equal(t, "172.31.0.1:10101", rl.addrs[6].String())
// ipv6 last, sorted by public first, then private, lexically within them
assert.Equal(t, "[1::1]:1", rl.addrs[7].String())
assert.Equal(t, "[1::1]:2", rl.addrs[8].String())
assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String())
// Ensure we can hoist a specific ipv4 range over anything else
_, ipNet, err = net.ParseCIDR("172.17.0.0/16")
assert.NoError(t, err)
rl.Rebuild([]*net.IPNet{ipNet})
assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
// Preferred ipv4 first
assert.Equal(t, "172.17.0.182:10101", rl.addrs[0].String())
assert.Equal(t, "172.17.1.1:10101", rl.addrs[1].String())
// ipv6 next
assert.Equal(t, "[1::1]:1", rl.addrs[2].String())
assert.Equal(t, "[1::1]:2", rl.addrs[3].String())
assert.Equal(t, "[1:100::1]:1", rl.addrs[4].String())
// the remaining ipv4 last
assert.Equal(t, "70.199.182.92:1475", rl.addrs[5].String())
assert.Equal(t, "70.199.182.92:1476", rl.addrs[6].String())
assert.Equal(t, "172.18.0.1:10101", rl.addrs[7].String())
assert.Equal(t, "172.19.0.1:10101", rl.addrs[8].String())
assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
}
func BenchmarkFullRebuild(b *testing.B) {
rl := NewRemoteList()
rl.unlockedSetV4(
0,
[]*Ip4AndPort{
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
},
func(*Ip4AndPort) bool { return true },
)
rl.unlockedSetV6(
0,
[]*Ip6AndPort{
NewIp6AndPort(net.ParseIP("1::1"), 1),
NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
NewIp6AndPort(net.ParseIP("1:100::1"), 1),
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
},
func(*Ip6AndPort) bool { return true },
)
b.Run("no preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{})
}
})
_, ipNet, err := net.ParseCIDR("172.17.0.0/16")
assert.NoError(b, err)
b.Run("1 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{ipNet})
}
})
_, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
assert.NoError(b, err)
b.Run("2 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
}
})
_, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
assert.NoError(b, err)
b.Run("3 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
}
})
}
func BenchmarkSortRebuild(b *testing.B) {
rl := NewRemoteList()
rl.unlockedSetV4(
0,
[]*Ip4AndPort{
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
},
func(*Ip4AndPort) bool { return true },
)
rl.unlockedSetV6(
0,
[]*Ip6AndPort{
NewIp6AndPort(net.ParseIP("1::1"), 1),
NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
NewIp6AndPort(net.ParseIP("1:100::1"), 1),
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
},
func(*Ip6AndPort) bool { return true },
)
b.Run("no preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{})
}
})
_, ipNet, err := net.ParseCIDR("172.17.0.0/16")
rl.Rebuild([]*net.IPNet{ipNet})
assert.NoError(b, err)
b.Run("1 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.Rebuild([]*net.IPNet{ipNet})
}
})
_, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
assert.NoError(b, err)
b.Run("2 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
}
})
_, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
assert.NoError(b, err)
b.Run("3 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
}
})
}

87
ssh.go
View File

@ -10,8 +10,8 @@ import (
"os" "os"
"reflect" "reflect"
"runtime/pprof" "runtime/pprof"
"sort"
"strings" "strings"
"sync/atomic"
"syscall" "syscall"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -335,8 +335,10 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
return nil return nil
} }
hostMap.RLock() hm := listHostMap(hostMap)
defer hostMap.RUnlock() sort.Slice(hm, func(i, j int) bool {
return bytes.Compare(hm[i].VpnIP, hm[j].VpnIP) < 0
})
if fs.Json || fs.Pretty { if fs.Json || fs.Pretty {
js := json.NewEncoder(w.GetWriter()) js := json.NewEncoder(w.GetWriter())
@ -344,35 +346,15 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
js.SetIndent("", " ") js.SetIndent("", " ")
} }
d := make([]m, len(hostMap.Hosts)) err := js.Encode(hm)
x := 0
var h m
for _, v := range hostMap.Hosts {
h = m{
"vpnIp": int2ip(v.hostId),
"localIndex": v.localIndexId,
"remoteIndex": v.remoteIndexId,
"remoteAddrs": v.CopyRemotes(),
"cachedPackets": len(v.packetStore),
"cert": v.GetCert(),
}
if v.ConnectionState != nil {
h["messageCounter"] = atomic.LoadUint64(&v.ConnectionState.atomicMessageCounter)
}
d[x] = h
x++
}
err := js.Encode(d)
if err != nil { if err != nil {
//TODO //TODO
return nil return nil
} }
} else { } else {
for i, v := range hostMap.Hosts { for _, v := range hm {
err := w.WriteLine(fmt.Sprintf("%s: %s", int2ip(i), v.CopyRemotes())) err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, v.RemoteAddrs))
if err != nil { if err != nil {
return err return err
} }
@ -389,8 +371,26 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
return nil return nil
} }
type lighthouseInfo struct {
VpnIP net.IP `json:"vpnIp"`
Addrs *CacheMap `json:"addrs"`
}
lightHouse.RLock() lightHouse.RLock()
defer lightHouse.RUnlock() addrMap := make([]lighthouseInfo, len(lightHouse.addrMap))
x := 0
for k, v := range lightHouse.addrMap {
addrMap[x] = lighthouseInfo{
VpnIP: int2ip(k),
Addrs: v.CopyCache(),
}
x++
}
lightHouse.RUnlock()
sort.Slice(addrMap, func(i, j int) bool {
return bytes.Compare(addrMap[i].VpnIP, addrMap[j].VpnIP) < 0
})
if fs.Json || fs.Pretty { if fs.Json || fs.Pretty {
js := json.NewEncoder(w.GetWriter()) js := json.NewEncoder(w.GetWriter())
@ -398,27 +398,19 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
js.SetIndent("", " ") js.SetIndent("", " ")
} }
d := make([]m, len(lightHouse.addrMap)) err := js.Encode(addrMap)
x := 0
var h m
for vpnIp, v := range lightHouse.addrMap {
h = m{
"vpnIp": int2ip(vpnIp),
"addrs": TransformLHReplyToUdpAddrs(v),
}
d[x] = h
x++
}
err := js.Encode(d)
if err != nil { if err != nil {
//TODO //TODO
return nil return nil
} }
} else { } else {
for vpnIp, v := range lightHouse.addrMap { for _, v := range addrMap {
err := w.WriteLine(fmt.Sprintf("%s: %s", int2ip(vpnIp), TransformLHReplyToUdpAddrs(v))) b, err := json.Marshal(v.Addrs)
if err != nil {
return err
}
err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, string(b)))
if err != nil { if err != nil {
return err return err
} }
@ -469,8 +461,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
ips, _ := ifce.lightHouse.Query(vpnIp, ifce) return json.NewEncoder(w.GetWriter()).Encode(ifce.lightHouse.Query(vpnIp, ifce).CopyCache())
return json.NewEncoder(w.GetWriter()).Encode(ips)
} }
func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
@ -727,7 +718,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp)) hostInfo, err := ifce.hostMap.QueryVpnIP(vpnIp)
if err != nil { if err != nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
} }
@ -737,7 +728,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
enc.SetIndent("", " ") enc.SetIndent("", " ")
} }
return enc.Encode(hostInfo) return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.preferredRanges))
} }
func sshReload(fs interface{}, a []string, w sshd.StringWriter) error { func sshReload(fs interface{}, a []string, w sshd.StringWriter) error {

View File

@ -41,9 +41,7 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []route, _ []r
// These are unencrypted ip layer frames destined for another nebula node. // These are unencrypted ip layer frames destined for another nebula node.
// packets should exit the udp side, capture them with udpConn.Get // packets should exit the udp side, capture them with udpConn.Get
func (c *Tun) Send(packet []byte) { func (c *Tun) Send(packet []byte) {
if c.l.Level >= logrus.DebugLevel { c.l.WithField("dataLen", len(packet)).Info("Tun receiving injected packet")
c.l.Debug("Tun injecting packet")
}
c.rxPackets <- packet c.rxPackets <- packet
} }

View File

@ -13,8 +13,8 @@ type udpAddr struct {
} }
func NewUDPAddr(ip net.IP, port uint16) *udpAddr { func NewUDPAddr(ip net.IP, port uint16) *udpAddr {
addr := udpAddr{IP: make([]byte, len(ip)), Port: port} addr := udpAddr{IP: make([]byte, net.IPv6len), Port: port}
copy(addr.IP, ip) copy(addr.IP, ip.To16())
return &addr return &addr
} }
@ -22,7 +22,7 @@ func NewUDPAddrFromString(s string) *udpAddr {
ip, port, err := parseIPAndPort(s) ip, port, err := parseIPAndPort(s)
//TODO: handle err //TODO: handle err
_ = err _ = err
return &udpAddr{IP: ip, Port: port} return &udpAddr{IP: ip.To16(), Port: port}
} }
func (ua *udpAddr) Equals(t *udpAddr) bool { func (ua *udpAddr) Equals(t *udpAddr) bool {

View File

@ -97,40 +97,21 @@ func (u *udpConn) GetSendBuffer() (int, error) {
} }
func (u *udpConn) LocalAddr() (*udpAddr, error) { func (u *udpConn) LocalAddr() (*udpAddr, error) {
var rsa unix.RawSockaddrAny sa, err := unix.Getsockname(u.sysFd)
var rLen = unix.SizeofSockaddrAny if err != nil {
_, _, err := unix.Syscall(
unix.SYS_GETSOCKNAME,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&rsa)),
uintptr(unsafe.Pointer(&rLen)),
)
if err != 0 {
return nil, err return nil, err
} }
addr := &udpAddr{} addr := &udpAddr{}
if rsa.Addr.Family == unix.AF_INET { switch sa := sa.(type) {
pp := (*unix.RawSockaddrInet4)(unsafe.Pointer(&rsa)) case *unix.SockaddrInet4:
addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1]) addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16()
copy(addr.IP, pp.Addr[:]) addr.Port = uint16(sa.Port)
case *unix.SockaddrInet6:
} else if rsa.Addr.Family == unix.AF_INET6 { addr.IP = sa.Addr[0:]
//TODO: this cast sucks and we can do better addr.Port = uint16(sa.Port)
pp := (*unix.RawSockaddrInet6)(unsafe.Pointer(&rsa))
addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1])
copy(addr.IP, pp.Addr[:])
} else {
addr.Port = 0
addr.IP = []byte{}
} }
//TODO: Just use this instead?
//a, b := unix.Getsockname(u.sysFd)
return addr, nil return addr, nil
} }

View File

@ -3,6 +3,7 @@
package nebula package nebula
import ( import (
"fmt"
"net" "net"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -53,7 +54,14 @@ func NewListener(l *logrus.Logger, ip string, port int, _ bool) (*udpConn, error
// this is an encrypted packet or a handshake message in most cases // this is an encrypted packet or a handshake message in most cases
// packets were transmitted from another nebula node, you can send them with Tun.Send // packets were transmitted from another nebula node, you can send them with Tun.Send
func (u *udpConn) Send(packet *UdpPacket) { func (u *udpConn) Send(packet *UdpPacket) {
u.l.Infof("UDP injecting packet %+v", packet) h := &Header{}
if err := h.Parse(packet.Data); err != nil {
panic(err)
}
u.l.WithField("header", h).
WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
WithField("dataLen", len(packet.Data)).
Info("UDP receiving injected packet")
u.rxPackets <- packet u.rxPackets <- packet
} }