diff --git a/examples/config.yml b/examples/config.yml index 9c43bf6..63f454b 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -194,6 +194,9 @@ logging: #retries: 20 # wait_rotation is the number of handshake attempts to do before starting to try non-local IP addresses #wait_rotation: 5 + # trigger_buffer is the size of the buffer channel for quickly sending handshakes + # after receiving the response for lighthouse queries + #trigger_buffer: 64 # Nebula security group configuration firewall: diff --git a/handshake_manager.go b/handshake_manager.go index 1d23013..a0b04d4 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -16,21 +16,24 @@ const ( DefaultHandshakeTryInterval = time.Millisecond * 100 DefaultHandshakeRetries = 20 // DefaultHandshakeWaitRotation is the number of handshake attempts to do before starting to use other ips addresses - DefaultHandshakeWaitRotation = 5 + DefaultHandshakeWaitRotation = 5 + DefaultHandshakeTriggerBuffer = 64 ) var ( defaultHandshakeConfig = HandshakeConfig{ - tryInterval: DefaultHandshakeTryInterval, - retries: DefaultHandshakeRetries, - waitRotation: DefaultHandshakeWaitRotation, + tryInterval: DefaultHandshakeTryInterval, + retries: DefaultHandshakeRetries, + waitRotation: DefaultHandshakeWaitRotation, + triggerBuffer: DefaultHandshakeTriggerBuffer, } ) type HandshakeConfig struct { - tryInterval time.Duration - retries int - waitRotation int + tryInterval time.Duration + retries int + waitRotation int + triggerBuffer int messageMetrics *MessageMetrics } @@ -42,6 +45,9 @@ type HandshakeManager struct { outside *udpConn config HandshakeConfig + // can be used to trigger outbound handshake for the given vpnIP + trigger chan uint32 + OutboundHandshakeTimer *SystemTimerWheel InboundHandshakeTimer *SystemTimerWheel @@ -57,6 +63,8 @@ func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainH config: config, + trigger: make(chan uint32, config.triggerBuffer), + OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)), InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)), @@ -66,9 +74,15 @@ func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainH func (c *HandshakeManager) Run(f EncWriter) { clockSource := time.Tick(c.config.tryInterval) - for now := range clockSource { - c.NextOutboundHandshakeTimerTick(now, f) - c.NextInboundHandshakeTimerTick(now) + for { + select { + case vpnIP := <-c.trigger: + l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered") + c.handleOutbound(vpnIP, f, true) + case now := <-clockSource: + c.NextOutboundHandshakeTimerTick(now, f) + c.NextInboundHandshakeTimerTick(now) + } } } @@ -80,69 +94,86 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWr break } vpnIP := ep.(uint32) + c.handleOutbound(vpnIP, f, false) + } +} - index, err := c.pendingHostMap.GetIndexByVpnIP(vpnIP) - if err != nil { - continue +func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseTriggered bool) { + index, err := c.pendingHostMap.GetIndexByVpnIP(vpnIP) + if err != nil { + return + } + hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP) + if err != nil { + return + } + + // If we haven't finished the handshake and we haven't hit max retries, query + // lighthouse and then send the handshake packet again. + if hostinfo.HandshakeCounter < c.config.retries && !hostinfo.HandshakeComplete { + if hostinfo.remote == nil { + // We continue to query the lighthouse because hosts may + // come online during handshake retries. If the query + // succeeds (no error), add the lighthouse info to hostinfo + ips := c.lightHouse.QueryCache(vpnIP) + // If we have no responses yet, or only one IP (the host hadn't + // finished reporting its own IPs yet), then send another query to + // the LH. + if len(ips) <= 1 { + ips, err = c.lightHouse.Query(vpnIP, f) + } + if err == nil { + for _, ip := range ips { + hostinfo.AddRemote(ip) + } + hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges) + } + } else if lighthouseTriggered { + // We were triggered by a lighthouse HostQueryReply packet, but + // we have already picked a remote for this host (this can happen + // if we are configured with multiple lighthouses). So we can skip + // this trigger and let the timerwheel handle the rest of the + // process + return } - hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP) - if err != nil { - continue + hostinfo.HandshakeCounter++ + + // We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through + // all the others until we can stand up a connection. + if hostinfo.HandshakeCounter > c.config.waitRotation { + hostinfo.rotateRemote() } - // If we haven't finished the handshake and we haven't hit max retries, query - // lighthouse and then send the handshake packet again. - if hostinfo.HandshakeCounter < c.config.retries && !hostinfo.HandshakeComplete { - if hostinfo.remote == nil { - // We continue to query the lighthouse because hosts may - // come online during handshake retries. If the query - // succeeds (no error), add the lighthouse info to hostinfo - ips, err := c.lightHouse.Query(vpnIP, f) - if err == nil { - for _, ip := range ips { - hostinfo.AddRemote(ip) - } - hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges) - } + // Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation + if hostinfo.HandshakeReady && hostinfo.remote != nil { + c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1) + err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote) + if err != nil { + hostinfo.logger().WithField("udpAddr", hostinfo.remote). + WithField("initiatorIndex", hostinfo.localIndexId). + WithField("remoteIndex", hostinfo.remoteIndexId). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + WithError(err).Error("Failed to send handshake message") + } else { + //TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should + // keep the real packet struct around for logging purposes + hostinfo.logger().WithField("udpAddr", hostinfo.remote). + WithField("initiatorIndex", hostinfo.localIndexId). + WithField("remoteIndex", hostinfo.remoteIndexId). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Info("Handshake message sent") } + } - hostinfo.HandshakeCounter++ - - // We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through - // all the others until we can stand up a connection. - if hostinfo.HandshakeCounter > c.config.waitRotation { - hostinfo.rotateRemote() - } - - // Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation - if hostinfo.HandshakeReady && hostinfo.remote != nil { - c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1) - err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote) - if err != nil { - hostinfo.logger().WithField("udpAddr", hostinfo.remote). - WithField("initiatorIndex", hostinfo.localIndexId). - WithField("remoteIndex", hostinfo.remoteIndexId). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithError(err).Error("Failed to send handshake message") - } else { - //TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should - // keep the real packet struct around for logging purposes - hostinfo.logger().WithField("udpAddr", hostinfo.remote). - WithField("initiatorIndex", hostinfo.localIndexId). - WithField("remoteIndex", hostinfo.remoteIndexId). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Info("Handshake message sent") - } - } - - // Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try + // Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try + if !lighthouseTriggered { //l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter)) c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) - } else { - c.pendingHostMap.DeleteVpnIP(vpnIP) - c.pendingHostMap.DeleteIndex(index) } + } else { + c.pendingHostMap.DeleteVpnIP(vpnIP) + c.pendingHostMap.DeleteIndex(index) } } @@ -169,6 +200,15 @@ func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo { // We lock here and use an array to insert items to prevent locking the // main receive thread for very long by waiting to add items to the pending map c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval) + + // If this is a static host, we don't need to wait for the HostQueryReply + // We can trigger the handshake right now + if _, ok := c.lightHouse.staticList[vpnIP]; ok { + select { + case c.trigger <- vpnIP: + default: + } + } return hostinfo } diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 99eb586..c4f1685 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -103,6 +103,56 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) { } } +func Test_NewHandshakeManagerTrigger(t *testing.T) { + _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") + _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") + _, localrange, _ := net.ParseCIDR("10.1.1.1/24") + ip := ip2int(net.ParseIP("172.1.1.2")) + preferredRanges := []*net.IPNet{localrange} + mw := &mockEncWriter{} + mainHM := NewHostMap("test", vpncidr, preferredRanges) + lh := &LightHouse{} + + blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig) + + now := time.Now() + blah.NextOutboundHandshakeTimerTick(now, mw) + + assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) + + blah.AddVpnIP(ip) + + assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) + + // Trigger the same method the channel will + blah.handleOutbound(ip, mw, true) + + // Make sure the trigger doesn't schedule another timer entry + assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) + hi := blah.pendingHostMap.Hosts[ip] + assert.Nil(t, hi.remote) + + lh.addrMap = map[uint32][]udpAddr{ + ip: {*NewUDPAddrFromString("10.1.1.1:4242")}, + } + + // This should trigger the hostmap to populate the hostinfo + blah.handleOutbound(ip, mw, true) + assert.NotNil(t, hi.remote) + assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) +} + +func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) { + for _, i := range tw.wheel { + n := i.Head + for n != nil { + c++ + n = n.Next + } + } + return c +} + func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) { _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") diff --git a/lighthouse.go b/lighthouse.go index 9b5b1c7..3251ef1 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -30,6 +30,9 @@ type LightHouse struct { // filters local addresses that we advertise to lighthouses localAllowList *AllowList + // used to trigger the HandshakeManager when we receive HostQueryReply + handshakeTrigger chan<- uint32 + // staticList exists to avoid having a bool in each addrMap entry // since static should be rare staticList map[uint32]struct{} @@ -358,6 +361,11 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c ans := NewUDPAddr(a.Ip, uint16(a.Port)) lh.AddRemote(n.Details.VpnIp, ans, false) } + // Non-blocking attempt to trigger, skip if it would block + select { + case lh.handshakeTrigger <- n.Details.VpnIp: + default: + } case NebulaMeta_HostUpdateNotification: //Simple check that the host sent this not someone else diff --git a/main.go b/main.go index a89f0e8..09ad578 100644 --- a/main.go +++ b/main.go @@ -318,14 +318,16 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg } handshakeConfig := HandshakeConfig{ - tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), - retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries), - waitRotation: config.GetInt("handshakes.wait_rotation", DefaultHandshakeWaitRotation), + tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), + retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries), + waitRotation: config.GetInt("handshakes.wait_rotation", DefaultHandshakeWaitRotation), + triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), messageMetrics: messageMetrics, } handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpServer, handshakeConfig) + lightHouse.handshakeTrigger = handshakeManager.trigger //TODO: These will be reused for psk //handshakeMACKey := config.GetString("handshake_mac.key", "")