create ConnectionState before adding to HostMap (#535)
We have a few small race conditions with creating the HostInfo.ConnectionState since we add the host info to the pendingHostMap before we set this field. We can make everything a lot easier if we just add an "init" function so that we can set this field in the hostinfo before we add it to the hostmap.
This commit is contained in:
parent
16be0ce566
commit
304b12f63f
|
@ -57,6 +57,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
- A race condition when `punchy.respond` is enabled and ensures the correct
|
- A race condition when `punchy.respond` is enabled and ensures the correct
|
||||||
vpn ip is sent a punch back response in highly queried node. (#566)
|
vpn ip is sent a punch back response in highly queried node. (#566)
|
||||||
|
|
||||||
|
- Fix a rare crash during handshake due to a race condition. (#535)
|
||||||
|
|
||||||
## [1.4.0] - 2021-05-11
|
## [1.4.0] - 2021-05-11
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|
|
@ -57,7 +57,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
nc.HandleMonitorTick(now, p, nb, out)
|
nc.HandleMonitorTick(now, p, nb, out)
|
||||||
// Add an ip we have established a connection w/ to hostmap
|
// Add an ip we have established a connection w/ to hostmap
|
||||||
hostinfo := nc.hostMap.AddVpnIp(vpnIp)
|
hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil)
|
||||||
hostinfo.ConnectionState = &ConnectionState{
|
hostinfo.ConnectionState = &ConnectionState{
|
||||||
certState: cs,
|
certState: cs,
|
||||||
H: &noise.HandshakeState{},
|
H: &noise.HandshakeState{},
|
||||||
|
@ -126,7 +126,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
nc.HandleMonitorTick(now, p, nb, out)
|
nc.HandleMonitorTick(now, p, nb, out)
|
||||||
// Add an ip we have established a connection w/ to hostmap
|
// Add an ip we have established a connection w/ to hostmap
|
||||||
hostinfo := nc.hostMap.AddVpnIp(vpnIp)
|
hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil)
|
||||||
hostinfo.ConnectionState = &ConnectionState{
|
hostinfo.ConnectionState = &ConnectionState{
|
||||||
certState: cs,
|
certState: cs,
|
||||||
H: &noise.HandshakeState{},
|
H: &noise.HandshakeState{},
|
||||||
|
@ -232,7 +232,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
nc := newConnectionManager(ctx, l, ifce, 5, 10)
|
nc := newConnectionManager(ctx, l, ifce, 5, 10)
|
||||||
ifce.connectionManager = nc
|
ifce.connectionManager = nc
|
||||||
hostinfo := nc.hostMap.AddVpnIp(vpnIp)
|
hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil)
|
||||||
hostinfo.ConnectionState = &ConnectionState{
|
hostinfo.ConnectionState = &ConnectionState{
|
||||||
certState: cs,
|
certState: cs,
|
||||||
peerCert: &peerCert,
|
peerCert: &peerCert,
|
||||||
|
|
|
@ -191,13 +191,13 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
|
func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo {
|
||||||
hostinfo := c.pendingHostMap.AddVpnIp(vpnIp)
|
hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init)
|
||||||
// We lock here and use an array to insert items to prevent locking the
|
|
||||||
// main receive thread for very long by waiting to add items to the pending map
|
if created {
|
||||||
//TODO: what lock?
|
|
||||||
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
|
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
|
||||||
c.metricInitiated.Inc(1)
|
c.metricInitiated.Inc(1)
|
||||||
|
}
|
||||||
|
|
||||||
return hostinfo
|
return hostinfo
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,7 +27,19 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
blah.NextOutboundHandshakeTimerTick(now, mw)
|
blah.NextOutboundHandshakeTimerTick(now, mw)
|
||||||
|
|
||||||
i := blah.AddVpnIp(ip)
|
var initCalled bool
|
||||||
|
initFunc := func(*HostInfo) {
|
||||||
|
initCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
i := blah.AddVpnIp(ip, initFunc)
|
||||||
|
assert.True(t, initCalled)
|
||||||
|
|
||||||
|
initCalled = false
|
||||||
|
i2 := blah.AddVpnIp(ip, initFunc)
|
||||||
|
assert.False(t, initCalled)
|
||||||
|
assert.Same(t, i, i2)
|
||||||
|
|
||||||
i.remotes = NewRemoteList()
|
i.remotes = NewRemoteList()
|
||||||
i.HandshakeReady = true
|
i.HandshakeReady = true
|
||||||
|
|
||||||
|
@ -71,7 +83,7 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
|
||||||
|
|
||||||
assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
|
assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
|
||||||
|
|
||||||
hi := blah.AddVpnIp(ip)
|
hi := blah.AddVpnIp(ip, nil)
|
||||||
hi.HandshakeReady = true
|
hi.HandshakeReady = true
|
||||||
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
|
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
|
||||||
assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet")
|
assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet")
|
||||||
|
|
13
hostmap.go
13
hostmap.go
|
@ -134,24 +134,25 @@ func (hm *HostMap) Add(ip iputil.VpnIp, hostinfo *HostInfo) {
|
||||||
hm.Unlock()
|
hm.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
|
func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp, init func(hostinfo *HostInfo)) (hostinfo *HostInfo, created bool) {
|
||||||
h := &HostInfo{}
|
|
||||||
hm.RLock()
|
hm.RLock()
|
||||||
if _, ok := hm.Hosts[vpnIp]; !ok {
|
if h, ok := hm.Hosts[vpnIp]; !ok {
|
||||||
hm.RUnlock()
|
hm.RUnlock()
|
||||||
h = &HostInfo{
|
h = &HostInfo{
|
||||||
promoteCounter: 0,
|
promoteCounter: 0,
|
||||||
vpnIp: vpnIp,
|
vpnIp: vpnIp,
|
||||||
HandshakePacket: make(map[uint8][]byte, 0),
|
HandshakePacket: make(map[uint8][]byte, 0),
|
||||||
}
|
}
|
||||||
|
if init != nil {
|
||||||
|
init(h)
|
||||||
|
}
|
||||||
hm.Lock()
|
hm.Lock()
|
||||||
hm.Hosts[vpnIp] = h
|
hm.Hosts[vpnIp] = h
|
||||||
hm.Unlock()
|
hm.Unlock()
|
||||||
return h
|
return h, true
|
||||||
} else {
|
} else {
|
||||||
h = hm.Hosts[vpnIp]
|
|
||||||
hm.RUnlock()
|
hm.RUnlock()
|
||||||
return h
|
return h, false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
18
inside.go
18
inside.go
|
@ -83,7 +83,7 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
|
hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo = f.handshakeManager.AddVpnIp(vpnIp)
|
hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ci := hostinfo.ConnectionState
|
ci := hostinfo.ConnectionState
|
||||||
|
@ -102,16 +102,6 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
|
||||||
return hostinfo
|
return hostinfo
|
||||||
}
|
}
|
||||||
|
|
||||||
if ci == nil {
|
|
||||||
// if we don't have a connection state, then send a handshake initiation
|
|
||||||
ci = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
|
|
||||||
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
|
|
||||||
//ci = f.newConnectionState(true, noise.HandshakeXX, []byte{}, 0)
|
|
||||||
hostinfo.ConnectionState = ci
|
|
||||||
} else if ci.eKey == nil {
|
|
||||||
// if we don't have any state at all, create it
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we have already created the handshake packet, we don't want to call the function at all.
|
// If we have already created the handshake packet, we don't want to call the function at all.
|
||||||
if !hostinfo.HandshakeReady {
|
if !hostinfo.HandshakeReady {
|
||||||
ixHandshakeStage0(f, vpnIp, hostinfo)
|
ixHandshakeStage0(f, vpnIp, hostinfo)
|
||||||
|
@ -131,6 +121,12 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
|
||||||
return hostinfo
|
return hostinfo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// initHostInfo is the init function to pass to (*HandshakeManager).AddVpnIP that
|
||||||
|
// will create the initial Noise ConnectionState
|
||||||
|
func (f *Interface) initHostInfo(hostinfo *HostInfo) {
|
||||||
|
hostinfo.ConnectionState = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
|
||||||
|
}
|
||||||
|
|
||||||
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
|
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
|
||||||
fp := &firewall.Packet{}
|
fp := &firewall.Packet{}
|
||||||
err := newPacket(p, false, fp)
|
err := newPacket(p, false, fp)
|
||||||
|
|
2
ssh.go
2
ssh.go
|
@ -569,7 +569,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp)
|
hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp, ifce.initHostInfo)
|
||||||
if addr != nil {
|
if addr != nil {
|
||||||
hostInfo.SetRemote(addr)
|
hostInfo.SetRemote(addr)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue