Dont apply race avoidance to existing handshakes, use the handshake time to determine who wins (#451)
Co-authored-by: Wade Simmons <wadey@slack-corp.com>
This commit is contained in:
parent
df7c7eec4a
commit
db23fdf9bc
|
@ -124,6 +124,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
remoteIndexId: hs.Details.InitiatorIndex,
|
remoteIndexId: hs.Details.InitiatorIndex,
|
||||||
hostId: vpnIP,
|
hostId: vpnIP,
|
||||||
HandshakePacket: make(map[uint8][]byte, 0),
|
HandshakePacket: make(map[uint8][]byte, 0),
|
||||||
|
lastHandshakeTime: hs.Details.Time,
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.Lock()
|
hostinfo.Lock()
|
||||||
|
@ -138,6 +139,8 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
|
|
||||||
hs.Details.ResponderIndex = myIndex
|
hs.Details.ResponderIndex = myIndex
|
||||||
hs.Details.Cert = ci.certState.rawCertificateNoKey
|
hs.Details.Cert = ci.certState.rawCertificateNoKey
|
||||||
|
// Update the time in case their clock is way off from ours
|
||||||
|
hs.Details.Time = uint64(time.Now().Unix())
|
||||||
|
|
||||||
hsBytes, err := proto.Marshal(hs)
|
hsBytes, err := proto.Marshal(hs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -204,18 +207,15 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
case ErrExistingHostInfo:
|
case ErrExistingHostInfo:
|
||||||
// This means there was an existing tunnel and we didn't win
|
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
||||||
// handshake avoidance
|
|
||||||
|
|
||||||
//TODO: sprinkle the new protobuf stuff in here, send a reply to get the recv_errors flowing
|
|
||||||
//TODO: if not new send a test packet like old
|
|
||||||
|
|
||||||
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("oldHandshakeTime", existing.lastHandshakeTime).
|
||||||
|
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
Info("Prevented a handshake race")
|
Info("Handshake too old")
|
||||||
|
|
||||||
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
||||||
f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||||
|
@ -394,7 +394,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
Info("Handshake message received")
|
Info("Handshake message received")
|
||||||
|
|
||||||
hostinfo.remoteIndexId = hs.Details.ResponderIndex
|
hostinfo.remoteIndexId = hs.Details.ResponderIndex
|
||||||
hs.Details.Cert = ci.certState.rawCertificateNoKey
|
hostinfo.lastHandshakeTime = hs.Details.Time
|
||||||
|
|
||||||
// Store their cert and our symmetric keys
|
// Store their cert and our symmetric keys
|
||||||
ci.peerCert = remoteCert
|
ci.peerCert = remoteCert
|
||||||
|
|
|
@ -199,7 +199,7 @@ var (
|
||||||
// exact same handshake packet
|
// exact same handshake packet
|
||||||
//
|
//
|
||||||
// ErrExistingHostInfo if we already have an entry in the hostmap for this
|
// ErrExistingHostInfo if we already have an entry in the hostmap for this
|
||||||
// VpnIP and overwrite was false.
|
// VpnIP and the new handshake was older than the one we currently have
|
||||||
//
|
//
|
||||||
// ErrLocalIndexCollision if we already have an entry in the main or pending
|
// ErrLocalIndexCollision if we already have an entry in the main or pending
|
||||||
// hostmap for the hostinfo.localIndexId.
|
// hostmap for the hostinfo.localIndexId.
|
||||||
|
@ -217,10 +217,12 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
|
||||||
return existingHostInfo, ErrAlreadySeen
|
return existingHostInfo, ErrAlreadySeen
|
||||||
}
|
}
|
||||||
|
|
||||||
if !overwrite {
|
// Is this a newer handshake?
|
||||||
// It's a new handshake and we lost the race
|
if existingHostInfo.lastHandshakeTime >= hostinfo.lastHandshakeTime {
|
||||||
return existingHostInfo, ErrExistingHostInfo
|
return existingHostInfo, ErrExistingHostInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
existingHostInfo.logger(c.l).Info("Taking new handshake")
|
||||||
}
|
}
|
||||||
|
|
||||||
existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
|
existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
|
||||||
|
@ -261,7 +263,6 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
|
||||||
}
|
}
|
||||||
|
|
||||||
if existingHostInfo != nil {
|
if existingHostInfo != nil {
|
||||||
hostinfo.logger(c.l).Info("Race lost, taking new handshake")
|
|
||||||
// We are going to overwrite this entry, so remove the old references
|
// We are going to overwrite this entry, so remove the old references
|
||||||
delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
|
delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
|
||||||
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
|
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
|
||||||
|
|
|
@ -59,6 +59,11 @@ type HostInfo struct {
|
||||||
// with a handshake
|
// with a handshake
|
||||||
lastRebindCount int8
|
lastRebindCount int8
|
||||||
|
|
||||||
|
// lastHandshakeTime records the time the remote side told us about at the stage when the handshake was completed locally
|
||||||
|
// Stage 1 packet will contain it if I am a responder, stage 2 packet if I am an initiator
|
||||||
|
// This is used to avoid an attack where a handshake packet is replayed after some time
|
||||||
|
lastHandshakeTime uint64
|
||||||
|
|
||||||
lastRoam time.Time
|
lastRoam time.Time
|
||||||
lastRoamRemote *udpAddr
|
lastRoamRemote *udpAddr
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue