Refactor handshake_ix (#401)

There are some subtle race conditions with the previous handshake_ix implementation, mostly around collisions with localIndexId. This change refactors it so that we have a "commit" phase during the handshake where we grab the lock for the hostmap and ensure that we have a unique local index before storing it. We also now avoid using the pending hostmap at all for receiving stage1 packets, since we have everything we need to just store the completed handshake.

Co-authored-by: Nate Brown <nbrown.us@gmail.com>
Co-authored-by: Ryan Huber <rhuber@gmail.com>
Co-authored-by: forfuncsake <drussell@slack-corp.com>
This commit is contained in:
Wade Simmons 2021-03-12 14:16:25 -05:00 committed by GitHub
parent 64d8035d09
commit 6c55d67f18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 345 additions and 315 deletions

View File

@ -6,30 +6,23 @@ const (
) )
func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) { func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) {
newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex)
//TODO: For stage 1 we won't have hostinfo yet but stage 2 and above would require it, this check may be helpful in those cases
//if err != nil {
// l.WithError(err).WithField("udpAddr", addr).Error("Error while finding host info for handshake message")
// return
//}
if !f.lightHouse.remoteAllowList.Allow(udp2ipInt(addr)) { if !f.lightHouse.remoteAllowList.Allow(udp2ipInt(addr)) {
l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return return
} }
tearDown := false
switch h.Subtype { switch h.Subtype {
case handshakeIXPSK0: case handshakeIXPSK0:
switch h.MessageCounter { switch h.MessageCounter {
case 1: case 1:
tearDown = ixHandshakeStage1(f, addr, newHostinfo, packet, h) ixHandshakeStage1(f, addr, packet, h)
case 2: case 2:
tearDown = ixHandshakeStage2(f, addr, newHostinfo, packet, h) newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex)
} tearDown := ixHandshakeStage2(f, addr, newHostinfo, packet, h)
}
if tearDown && newHostinfo != nil { if tearDown && newHostinfo != nil {
f.handshakeManager.DeleteHostInfo(newHostinfo) f.handshakeManager.DeleteHostInfo(newHostinfo)
} }
}
}
} }

View File

@ -25,17 +25,17 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
} }
} }
myIndex, err := generateIndex() err := f.handshakeManager.AddIndexHostInfo(hostinfo)
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
return return
} }
ci := hostinfo.ConnectionState ci := hostinfo.ConnectionState
f.handshakeManager.AddIndexHostInfo(myIndex, hostinfo)
hsProto := &NebulaHandshakeDetails{ hsProto := &NebulaHandshakeDetails{
InitiatorIndex: myIndex, InitiatorIndex: hostinfo.localIndexId,
Time: uint64(time.Now().Unix()), Time: uint64(time.Now().Unix()),
Cert: ci.certState.rawCertificateNoKey, Cert: ci.certState.rawCertificateNoKey,
} }
@ -73,9 +73,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
} }
func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool { func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
var ip uint32
if h.RemoteIndex == 0 {
ci := f.newConnectionState(false, noise.HandshakeIX, []byte{}, 0) ci := f.newConnectionState(false, noise.HandshakeIX, []byte{}, 0)
// Mark packet 1 as seen so it doesn't show up as missed // Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(1) ci.window.Update(1)
@ -84,7 +82,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
if err != nil { if err != nil {
l.WithError(err).WithField("udpAddr", addr). l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
return true return
} }
hs := &NebulaHandshake{} hs := &NebulaHandshake{}
@ -95,35 +93,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
if err != nil || hs.Details == nil { if err != nil || hs.Details == nil {
l.WithError(err).WithField("udpAddr", addr). l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
return true return
}
hostinfo, _ := f.handshakeManager.pendingHostMap.QueryReverseIndex(hs.Details.InitiatorIndex)
if hostinfo != nil {
hostinfo.RLock()
defer hostinfo.RUnlock()
if bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) {
if msg, ok := hostinfo.HandshakePacket[2]; ok {
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err := f.outside.WriteTo(msg, addr)
if err != nil {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message")
} else {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
}
return false
}
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cached", true).
WithField("packets", hostinfo.HandshakePacket).
Error("Seen this handshake packet already but don't have a cached packet to return")
}
} }
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert) remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
@ -131,7 +101,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
l.WithError(err).WithField("udpAddr", addr). l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
Info("Invalid certificate from host") Info("Invalid certificate from host")
return true return
} }
vpnIP := ip2int(remoteCert.Details.Ips[0].IP) vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
certName := remoteCert.Details.Name certName := remoteCert.Details.Name
@ -143,20 +113,17 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
return true return
} }
hostinfo, err = f.handshakeManager.AddIndex(myIndex, ci) hostinfo := &HostInfo{
if err != nil { ConnectionState: ci,
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). Remotes: []*HostInfoDest{},
WithField("certName", certName). localIndexId: myIndex,
WithField("fingerprint", fingerprint). remoteIndexId: hs.Details.InitiatorIndex,
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Error adding index to connection manager") hostId: vpnIP,
HandshakePacket: make(map[uint8][]byte, 0),
return true
} }
hostinfo.Lock()
defer hostinfo.Unlock()
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
@ -165,7 +132,6 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message received") Info("Handshake message received")
f.handshakeManager.addRemoteIndexHostInfo(hs.Details.InitiatorIndex, hostinfo)
hs.Details.ResponderIndex = myIndex hs.Details.ResponderIndex = myIndex
hs.Details.Cert = ci.certState.rawCertificateNoKey hs.Details.Cert = ci.certState.rawCertificateNoKey
@ -175,7 +141,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return true return
} }
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2) header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2)
@ -185,10 +151,62 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return true return
} else if dKey == nil || eKey == nil {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
return
} }
if f.hostMap.CheckHandshakeCompleteIP(vpnIP) && vpnIP < ip2int(f.certState.certificate.Details.Ips[0].IP) { hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:]))
copy(hostinfo.HandshakePacket[0], packet[HeaderLen:])
// Regardless of whether you are the sender or receiver, you should arrive here
// and complete standing up the connection.
hostinfo.HandshakePacket[2] = make([]byte, len(msg))
copy(hostinfo.HandshakePacket[2], msg)
// We are sending handshake packet 2, so we don't expect to receive
// handshake packet 2 from the initiator.
ci.window.Update(2)
ci.peerCert = remoteCert
ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey)
//l.Debugln("got symmetric pairs")
//hostinfo.ClearRemotes()
hostinfo.AddRemote(*addr)
hostinfo.ForcePromoteBest(f.hostMap.preferredRanges)
hostinfo.CreateRemoteCIDR(remoteCert)
hostinfo.Lock()
defer hostinfo.Unlock()
// Only overwrite existing record if we should win the handshake race
overwrite := vpnIP > ip2int(f.certState.certificate.Details.Ips[0].IP)
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f)
if err != nil {
switch err {
case ErrAlreadySeen:
msg = existing.HandshakePacket[2]
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err := f.outside.WriteTo(msg, addr)
if err != nil {
l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message")
} else {
l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
}
return
case ErrExistingHostInfo:
// This means there was an existing tunnel and we didn't win
// handshake avoidance
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@ -198,24 +216,33 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
return true return
case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("localIndex", hostinfo.localIndexId).WithField("collision", IntIp(existing.hostId)).
Error("Failed to add HostInfo due to localIndex collision")
return
default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to add HostInfo to HostMap")
return
}
} }
hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:])) // Do the send
copy(hostinfo.HandshakePacket[0], packet[HeaderLen:])
// Regardless of whether you are the sender or receiver, you should arrive here
// and complete standing up the connection.
if dKey != nil && eKey != nil {
hostinfo.HandshakePacket[2] = make([]byte, len(msg))
copy(hostinfo.HandshakePacket[2], msg)
// We are sending handshake packet 2, so we don't expect to receive
// handshake packet 2 from the initiator.
ci.window.Update(2)
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err := f.outside.WriteTo(msg, addr) err = f.outside.WriteTo(msg, addr)
if err != nil { if err != nil {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
@ -232,48 +259,9 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
Info("Handshake message sent") Info("Handshake message sent")
} }
ip = ip2int(remoteCert.Details.Ips[0].IP)
ci.peerCert = remoteCert
ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey)
//l.Debugln("got symmetric pairs")
//hostinfo.ClearRemotes()
hostinfo.AddRemote(*addr)
hostinfo.CreateRemoteCIDR(remoteCert)
f.lightHouse.AddRemoteAndReset(ip, addr)
if f.serveDns {
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
}
ho, err := f.hostMap.QueryVpnIP(vpnIP)
if err == nil && ho.localIndexId != 0 {
l.WithField("vpnIp", vpnIP).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("action", "removing stale index").
WithField("index", ho.localIndexId).
WithField("remoteIndex", ho.remoteIndexId).
Debug("Handshake processing")
f.hostMap.DeleteHostInfo(ho)
}
f.hostMap.AddVpnIPHostInfo(vpnIP, hostinfo)
hostinfo.handshakeComplete() hostinfo.handshakeComplete()
} else {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
return true
}
} return
f.hostMap.AddRemote(ip, addr)
return false
} }
func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool { func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool {
@ -286,7 +274,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) { if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Already seen this handshake packet") Info("Already seen this handshake packet")
return false return false
} }
@ -307,6 +295,11 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
// to DOS us. Every other error condition after should to allow a possible good handshake to complete in the // to DOS us. Every other error condition after should to allow a possible good handshake to complete in the
// near future // near future
return false return false
} else if dKey == nil || eKey == nil {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
return true
} }
hs := &NebulaHandshake{} hs := &NebulaHandshake{}
@ -351,45 +344,20 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
// Regardless of whether you are the sender or receiver, you should arrive here // Regardless of whether you are the sender or receiver, you should arrive here
// and complete standing up the connection. // and complete standing up the connection.
if dKey != nil && eKey != nil {
ip := ip2int(remoteCert.Details.Ips[0].IP)
ci.peerCert = remoteCert ci.peerCert = remoteCert
ci.dKey = NewNebulaCipherState(dKey) ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey) ci.eKey = NewNebulaCipherState(eKey)
//l.Debugln("got symmetric pairs") //l.Debugln("got symmetric pairs")
//hostinfo.ClearRemotes() //hostinfo.ClearRemotes()
f.hostMap.AddRemote(ip, addr) hostinfo.AddRemote(*addr)
hostinfo.ForcePromoteBest(f.hostMap.preferredRanges)
hostinfo.CreateRemoteCIDR(remoteCert) hostinfo.CreateRemoteCIDR(remoteCert)
f.lightHouse.AddRemoteAndReset(ip, addr)
if f.serveDns {
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
}
ho, err := f.hostMap.QueryVpnIP(vpnIP)
if err == nil && ho.localIndexId != 0 {
l.WithField("vpnIp", vpnIP).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("action", "removing stale index").
WithField("index", ho.localIndexId).
WithField("remoteIndex", ho.remoteIndexId).
Debug("Handshake processing")
f.hostMap.DeleteHostInfo(ho)
}
f.hostMap.AddVpnIPHostInfo(vpnIP, hostinfo)
f.handshakeManager.Complete(hostinfo, f)
hostinfo.handshakeComplete() hostinfo.handshakeComplete()
f.metricHandshakes.Update(duration) f.metricHandshakes.Update(duration)
} else {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
return true
}
return false return false
} }

View File

@ -1,9 +1,10 @@
package nebula package nebula
import ( import (
"bytes"
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
"fmt" "errors"
"net" "net"
"time" "time"
@ -196,18 +197,123 @@ func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
return hostinfo return hostinfo
} }
func (c *HandshakeManager) AddIndex(index uint32, ci *ConnectionState) (*HostInfo, error) { var (
hostinfo, err := c.pendingHostMap.AddIndex(index, ci) ErrExistingHostInfo = errors.New("existing hostinfo")
if err != nil { ErrAlreadySeen = errors.New("already seen")
return nil, fmt.Errorf("Issue adding index: %d", index) ErrLocalIndexCollision = errors.New("local index collision")
)
// CheckAndComplete checks for any conflicts in the main and pending hostmap
// before adding hostinfo to main. If err is nil, it was added. Otherwise err will be:
// ErrAlreadySeen if we already have an entry in the hostmap that has seen the
// exact same handshake packet
//
// ErrExistingHostInfo if we already have an entry in the hostmap for this
// VpnIP and overwrite was false.
//
// ErrLocalIndexCollision if we already have an entry in the main or pending
// hostmap for the hostinfo.localIndexId.
func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) {
c.pendingHostMap.RLock()
defer c.pendingHostMap.RUnlock()
c.mainHostMap.Lock()
defer c.mainHostMap.Unlock()
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
if found && existingHostInfo != nil {
if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
return existingHostInfo, ErrAlreadySeen
} }
//c.mainHostMap.AddIndexHostInfo(index, hostinfo) if !overwrite {
c.InboundHandshakeTimer.Add(index, time.Second*10) return existingHostInfo, ErrExistingHostInfo
return hostinfo, nil }
}
existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
if found {
// We have a collision, but for a different hostinfo
return existingIndex, ErrLocalIndexCollision
}
existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId]
if found && existingIndex != hostinfo {
// We have a collision, but for a different hostinfo
return existingIndex, ErrLocalIndexCollision
}
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger().
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
Info("New host shadows existing host remoteIndex")
}
if existingHostInfo != nil {
// We are going to overwrite this entry, so remove the old references
delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
}
c.mainHostMap.addHostInfo(hostinfo, f)
return existingHostInfo, nil
} }
func (c *HandshakeManager) AddIndexHostInfo(index uint32, h *HostInfo) { // Complete is a simpler version of CheckAndComplete when we already know we
c.pendingHostMap.AddIndexHostInfo(index, h) // won't have a localIndexId collision because we already have an entry in the
// pendingHostMap
func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
c.mainHostMap.Lock()
defer c.mainHostMap.Unlock()
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
if found && existingHostInfo != nil {
// We are going to overwrite this entry, so remove the old references
delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
}
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
if found && existingRemoteIndex != nil {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger().
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
Info("New host shadows existing host remoteIndex")
}
c.mainHostMap.addHostInfo(hostinfo, f)
}
// AddIndexHostInfo generates a unique localIndexId for this HostInfo
// and adds it to the pendingHostMap. Will error if we are unable to generate
// a unique localIndexId
func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
c.pendingHostMap.Lock()
defer c.pendingHostMap.Unlock()
c.mainHostMap.RLock()
defer c.mainHostMap.RUnlock()
for i := 0; i < 32; i++ {
index, err := generateIndex()
if err != nil {
return err
}
_, inPending := c.pendingHostMap.Indexes[index]
_, inMain := c.mainHostMap.Indexes[index]
if !inMain && !inPending {
h.localIndexId = index
c.pendingHostMap.Indexes[index] = h
return nil
}
}
return errors.New("failed to generate unique localIndexId")
} }
func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) { func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) {

View File

@ -8,12 +8,11 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
var indexes []uint32 = []uint32{1000, 2000, 3000, 4000}
//var ips []uint32 = []uint32{9000, 9999999, 3, 292394923} //var ips []uint32 = []uint32{9000, 9999999, 3, 292394923}
var ips []uint32 var ips []uint32
func Test_NewHandshakeManagerIndex(t *testing.T) { func Test_NewHandshakeManagerIndex(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
@ -26,9 +25,18 @@ func Test_NewHandshakeManagerIndex(t *testing.T) {
now := time.Now() now := time.Now()
blah.NextInboundHandshakeTimerTick(now) blah.NextInboundHandshakeTimerTick(now)
var indexes = make([]uint32, 4)
var hostinfo = make([]*HostInfo, len(indexes))
for i := range indexes {
hostinfo[i] = &HostInfo{ConnectionState: &ConnectionState{}}
}
// Add four indexes // Add four indexes
for _, v := range indexes { for i := range indexes {
blah.AddIndex(v, &ConnectionState{}) err := blah.AddIndexHostInfo(hostinfo[i])
assert.NoError(t, err)
indexes[i] = hostinfo[i].localIndexId
blah.InboundHandshakeTimer.Add(indexes[i], time.Second*10)
} }
// Confirm they are in the pending index list // Confirm they are in the pending index list
for _, v := range indexes { for _, v := range indexes {
@ -169,8 +177,11 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
hostinfo := blah.AddVpnIP(vpnIP) hostinfo := blah.AddVpnIP(vpnIP)
// Pretned we have an index too // Pretned we have an index too
blah.AddIndexHostInfo(12341234, hostinfo) err := blah.AddIndexHostInfo(hostinfo)
assert.Contains(t, blah.pendingHostMap.Indexes, uint32(12341234)) assert.NoError(t, err)
blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10)
assert.NotZero(t, hostinfo.localIndexId)
assert.Contains(t, blah.pendingHostMap.Indexes, hostinfo.localIndexId)
// Jump ahead `HandshakeRetries` ticks. Eviction should happen in pending // Jump ahead `HandshakeRetries` ticks. Eviction should happen in pending
// but not main hostmap // but not main hostmap
@ -216,7 +227,10 @@ func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
now := time.Now() now := time.Now()
blah.NextInboundHandshakeTimerTick(now) blah.NextInboundHandshakeTimerTick(now)
hostinfo, _ := blah.AddIndex(12341234, &ConnectionState{}) hostinfo := &HostInfo{ConnectionState: &ConnectionState{}}
err := blah.AddIndexHostInfo(hostinfo)
assert.NoError(t, err)
blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10)
// Pretned we have an index too // Pretned we have an index too
blah.pendingHostMap.AddVpnIPHostInfo(101010, hostinfo) blah.pendingHostMap.AddVpnIPHostInfo(101010, hostinfo)
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(101010)) assert.Contains(t, blah.pendingHostMap.Hosts, uint32(101010))
@ -229,7 +243,7 @@ func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
next_tick := now.Add(DefaultHandshakeTryInterval*DefaultHandshakeRetries + 3) next_tick := now.Add(DefaultHandshakeTryInterval*DefaultHandshakeRetries + 3)
blah.NextInboundHandshakeTimerTick(next_tick) blah.NextInboundHandshakeTimerTick(next_tick)
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010)) assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010))
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234)) assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(hostinfo.localIndexId))
} }
type mockEncWriter struct { type mockEncWriter struct {

View File

@ -166,40 +166,6 @@ func (hm *HostMap) DeleteVpnIP(vpnIP uint32) {
} }
} }
func (hm *HostMap) AddIndex(index uint32, ci *ConnectionState) (*HostInfo, error) {
hm.Lock()
if _, ok := hm.Indexes[index]; !ok {
h := &HostInfo{
ConnectionState: ci,
Remotes: []*HostInfoDest{},
localIndexId: index,
HandshakePacket: make(map[uint8][]byte, 0),
}
hm.Indexes[index] = h
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
"hostinfo": m{"existing": false, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
Debug("Hostmap index added")
hm.Unlock()
return h, nil
}
hm.Unlock()
return nil, fmt.Errorf("refusing to overwrite existing index: %d", index)
}
func (hm *HostMap) AddIndexHostInfo(index uint32, h *HostInfo) {
hm.Lock()
h.localIndexId = index
hm.Indexes[index] = h
hm.Unlock()
if l.Level > logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
Debug("Hostmap index added")
}
}
// Only used by pendingHostMap when the remote index is not initially known // Only used by pendingHostMap when the remote index is not initially known
func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) { func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
hm.Lock() hm.Lock()
@ -234,16 +200,12 @@ func (hm *HostMap) DeleteIndex(index uint32) {
hm.Lock() hm.Lock()
hostinfo, ok := hm.Indexes[index] hostinfo, ok := hm.Indexes[index]
if ok { if ok {
hostinfo.Lock()
defer hostinfo.Unlock()
delete(hm.Indexes, index) delete(hm.Indexes, index)
delete(hm.RemoteIndexes, hostinfo.remoteIndexId) delete(hm.RemoteIndexes, hostinfo.remoteIndexId)
// Check if we have an entry under hostId that matches the same hostinfo // Check if we have an entry under hostId that matches the same hostinfo
// instance. Clean it up as well if we do. // instance. Clean it up as well if we do.
var hostinfo2 *HostInfo hostinfo2, ok := hm.Hosts[hostinfo.hostId]
hostinfo2, ok = hm.Hosts[hostinfo.hostId]
if ok && hostinfo2 == hostinfo { if ok && hostinfo2 == hostinfo {
delete(hm.Hosts, hostinfo.hostId) delete(hm.Hosts, hostinfo.hostId)
} }
@ -400,36 +362,26 @@ func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 {
} }
} }
func (hm *HostMap) CheckHandshakeCompleteIP(vpnIP uint32) bool { // We already have the hm Lock when this is called, so make sure to not call
hm.RLock() // any other methods that might try to grab it again
if i, ok := hm.Hosts[vpnIP]; ok { func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
if i == nil { remoteCert := hostinfo.ConnectionState.peerCert
hm.RUnlock() ip := ip2int(remoteCert.Details.Ips[0].IP)
return false
}
complete := i.HandshakeComplete
hm.RUnlock()
return complete
f.lightHouse.AddRemoteAndReset(ip, hostinfo.remote)
if f.serveDns {
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
} }
hm.RUnlock()
return false
}
func (hm *HostMap) CheckHandshakeCompleteIndex(index uint32) bool { hm.Hosts[hostinfo.hostId] = hostinfo
hm.RLock() hm.Indexes[hostinfo.localIndexId] = hostinfo
if i, ok := hm.Indexes[index]; ok { hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
if i == nil {
hm.RUnlock()
return false
}
complete := i.HandshakeComplete
hm.RUnlock()
return complete
if l.Level >= logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}).
Debug("Hostmap vpnIp added")
} }
hm.RUnlock()
return false
} }
func (hm *HostMap) ClearRemotes(vpnIP uint32) { func (hm *HostMap) ClearRemotes(vpnIP uint32) {

View File

@ -106,7 +106,6 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
case recvError: case recvError:
f.messageMetrics.Rx(header.Type, header.Subtype, 1) f.messageMetrics.Rx(header.Type, header.Subtype, 1)
// TODO: Remove this with recv_error deprecation
f.handleRecvError(addr, header) f.handleRecvError(addr, header)
return return
@ -312,8 +311,6 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
} }
func (f *Interface) handleRecvError(addr *udpAddr, h *Header) { func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
// This flag is to stop caring about recv_error from old versions
// This should go away when the old version is gone from prod
if l.Level >= logrus.DebugLevel { if l.Level >= logrus.DebugLevel {
l.WithField("index", h.RemoteIndex). l.WithField("index", h.RemoteIndex).
WithField("udpAddr", addr). WithField("udpAddr", addr).