create ConnectionState before adding to HostMap (#535)

We have a few small race conditions with creating the HostInfo.ConnectionState
since we add the host info to the pendingHostMap before we set this
field. We can make everything a lot easier if we just add an "init"
function so that we can set this field in the hostinfo before we add it
to the hostmap.
This commit is contained in:
Wade Simmons 2021-11-08 14:46:22 -05:00 committed by GitHub
parent 16be0ce566
commit 304b12f63f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 42 additions and 31 deletions

View File

@ -57,6 +57,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- A race condition when `punchy.respond` is enabled and ensures the correct
vpn ip is sent a punch back response in highly queried node. (#566)
- Fix a rare crash during handshake due to a race condition. (#535)
## [1.4.0] - 2021-05-11
### Added

View File

@ -57,7 +57,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
out := make([]byte, mtu)
nc.HandleMonitorTick(now, p, nb, out)
// Add an ip we have established a connection w/ to hostmap
hostinfo := nc.hostMap.AddVpnIp(vpnIp)
hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil)
hostinfo.ConnectionState = &ConnectionState{
certState: cs,
H: &noise.HandshakeState{},
@ -126,7 +126,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
out := make([]byte, mtu)
nc.HandleMonitorTick(now, p, nb, out)
// Add an ip we have established a connection w/ to hostmap
hostinfo := nc.hostMap.AddVpnIp(vpnIp)
hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil)
hostinfo.ConnectionState = &ConnectionState{
certState: cs,
H: &noise.HandshakeState{},
@ -232,7 +232,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
defer cancel()
nc := newConnectionManager(ctx, l, ifce, 5, 10)
ifce.connectionManager = nc
hostinfo := nc.hostMap.AddVpnIp(vpnIp)
hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil)
hostinfo.ConnectionState = &ConnectionState{
certState: cs,
peerCert: &peerCert,

View File

@ -191,13 +191,13 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l
}
}
func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
hostinfo := c.pendingHostMap.AddVpnIp(vpnIp)
// We lock here and use an array to insert items to prevent locking the
// main receive thread for very long by waiting to add items to the pending map
//TODO: what lock?
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
c.metricInitiated.Inc(1)
func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo {
hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init)
if created {
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
c.metricInitiated.Inc(1)
}
return hostinfo
}

View File

@ -27,7 +27,19 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
i := blah.AddVpnIp(ip)
var initCalled bool
initFunc := func(*HostInfo) {
initCalled = true
}
i := blah.AddVpnIp(ip, initFunc)
assert.True(t, initCalled)
initCalled = false
i2 := blah.AddVpnIp(ip, initFunc)
assert.False(t, initCalled)
assert.Same(t, i, i2)
i.remotes = NewRemoteList()
i.HandshakeReady = true
@ -71,7 +83,7 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
hi := blah.AddVpnIp(ip)
hi := blah.AddVpnIp(ip, nil)
hi.HandshakeReady = true
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet")

View File

@ -134,24 +134,25 @@ func (hm *HostMap) Add(ip iputil.VpnIp, hostinfo *HostInfo) {
hm.Unlock()
}
func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
h := &HostInfo{}
func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp, init func(hostinfo *HostInfo)) (hostinfo *HostInfo, created bool) {
hm.RLock()
if _, ok := hm.Hosts[vpnIp]; !ok {
if h, ok := hm.Hosts[vpnIp]; !ok {
hm.RUnlock()
h = &HostInfo{
promoteCounter: 0,
vpnIp: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0),
}
if init != nil {
init(h)
}
hm.Lock()
hm.Hosts[vpnIp] = h
hm.Unlock()
return h
return h, true
} else {
h = hm.Hosts[vpnIp]
hm.RUnlock()
return h
return h, false
}
}

View File

@ -83,7 +83,7 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
if err != nil {
hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
if err != nil {
hostinfo = f.handshakeManager.AddVpnIp(vpnIp)
hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo)
}
}
ci := hostinfo.ConnectionState
@ -102,16 +102,6 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
return hostinfo
}
if ci == nil {
// if we don't have a connection state, then send a handshake initiation
ci = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
//ci = f.newConnectionState(true, noise.HandshakeXX, []byte{}, 0)
hostinfo.ConnectionState = ci
} else if ci.eKey == nil {
// if we don't have any state at all, create it
}
// If we have already created the handshake packet, we don't want to call the function at all.
if !hostinfo.HandshakeReady {
ixHandshakeStage0(f, vpnIp, hostinfo)
@ -131,6 +121,12 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
return hostinfo
}
// initHostInfo is the init function to pass to (*HandshakeManager).AddVpnIP that
// will create the initial Noise ConnectionState
func (f *Interface) initHostInfo(hostinfo *HostInfo) {
hostinfo.ConnectionState = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
}
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
fp := &firewall.Packet{}
err := newPacket(p, false, fp)

2
ssh.go
View File

@ -569,7 +569,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
}
}
hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp)
hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp, ifce.initHostInfo)
if addr != nil {
hostInfo.SetRemote(addr)
}