Fix most known data races (#396)

This change fixes all of the known data races that `make smoke-docker-race` finds, except for one.

Most of these races are around the handshake phase for a hostinfo, so we add a RWLock to the hostinfo and Lock during each of the handshake stages.

Some of the other races are around consistently using `atomic` around the `messageCounter` field. To make this harder to mess up, I have renamed the field to `atomicMessageCounter` (I also removed the unnecessary extra pointer deference as we can just point directly to the struct field).

The last remaining data race is around reading `ConnectionInfo.ready`, which is a boolean that is only written to once when the handshake has finished. Due to it being in the hot path for packets and the rare case that this could actually be an issue, holding off on fixing that one for now.

here is the results of `make smoke-docker-race`:

before:

    lighthouse1: Found 2 data race(s)
    host2:       Found 36 data race(s)
    host3:       Found 17 data race(s)
    host4:       Found 31 data race(s)

after:

    host2: Found 1 data race(s)
    host4: Found 1 data race(s)

Fixes: #147
Fixes: #226
Fixes: #283
Fixes: #316
This commit is contained in:
Wade Simmons 2021-03-05 21:18:33 -05:00 committed by GitHub
parent 29c5f31f90
commit d604270966
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 97 additions and 76 deletions

View File

@ -51,7 +51,6 @@ func Test_NewConnectionManagerTest(t *testing.T) {
hostinfo.ConnectionState = &ConnectionState{ hostinfo.ConnectionState = &ConnectionState{
certState: cs, certState: cs,
H: &noise.HandshakeState{}, H: &noise.HandshakeState{},
messageCounter: new(uint64),
} }
// We saw traffic out to vpnIP // We saw traffic out to vpnIP
@ -117,7 +116,6 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
hostinfo.ConnectionState = &ConnectionState{ hostinfo.ConnectionState = &ConnectionState{
certState: cs, certState: cs,
H: &noise.HandshakeState{}, H: &noise.HandshakeState{},
messageCounter: new(uint64),
} }
// We saw traffic out to vpnIP // We saw traffic out to vpnIP

View File

@ -4,6 +4,7 @@ import (
"crypto/rand" "crypto/rand"
"encoding/json" "encoding/json"
"sync" "sync"
"sync/atomic"
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
@ -18,7 +19,7 @@ type ConnectionState struct {
certState *CertState certState *CertState
peerCert *cert.NebulaCertificate peerCert *cert.NebulaCertificate
initiator bool initiator bool
messageCounter *uint64 atomicMessageCounter uint64
window *Bits window *Bits
queueLock sync.Mutex queueLock sync.Mutex
writeLock sync.Mutex writeLock sync.Mutex
@ -59,7 +60,6 @@ func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePa
window: b, window: b,
ready: false, ready: false,
certState: curCertState, certState: curCertState,
messageCounter: new(uint64),
} }
return ci return ci
@ -69,7 +69,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
return json.Marshal(m{ return json.Marshal(m{
"certificate": cs.peerCert, "certificate": cs.peerCert,
"initiator": cs.initiator, "initiator": cs.initiator,
"message_counter": cs.messageCounter, "message_counter": atomic.LoadUint64(&cs.atomicMessageCounter),
"ready": cs.ready, "ready": cs.ready,
}) })
} }

View File

@ -4,6 +4,7 @@ import (
"net" "net"
"os" "os"
"os/signal" "os/signal"
"sync/atomic"
"syscall" "syscall"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -156,7 +157,7 @@ func copyHostInfo(h *HostInfo) ControlHostInfo {
RemoteIndex: h.remoteIndexId, RemoteIndex: h.remoteIndexId,
RemoteAddrs: make([]udpAddr, len(addrs), len(addrs)), RemoteAddrs: make([]udpAddr, len(addrs), len(addrs)),
CachedPackets: len(h.packetStore), CachedPackets: len(h.packetStore),
MessageCounter: *h.ConnectionState.messageCounter, MessageCounter: atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter),
} }
if c := h.GetCert(); c != nil { if c := h.GetCert(); c != nil {

View File

@ -43,7 +43,6 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
}, },
Signature: []byte{1, 2, 1, 2, 1, 3}, Signature: []byte{1, 2, 1, 2, 1, 3},
} }
counter := uint64(0)
remotes := []*HostInfoDest{NewHostInfoDest(remote1), NewHostInfoDest(remote2)} remotes := []*HostInfoDest{NewHostInfoDest(remote1), NewHostInfoDest(remote2)}
hm.Add(ip2int(ipNet.IP), &HostInfo{ hm.Add(ip2int(ipNet.IP), &HostInfo{
@ -51,7 +50,6 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
Remotes: remotes, Remotes: remotes,
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: crt, peerCert: crt,
messageCounter: &counter,
}, },
remoteIndexId: 200, remoteIndexId: 200,
localIndexId: 201, localIndexId: 201,
@ -63,7 +61,6 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
Remotes: remotes, Remotes: remotes,
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: nil, peerCert: nil,
messageCounter: &counter,
}, },
remoteIndexId: 200, remoteIndexId: 200,
localIndexId: 201, localIndexId: 201,

View File

@ -54,7 +54,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
} }
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, 0, 1) header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, 0, 1)
atomic.AddUint64(ci.messageCounter, 1) atomic.AddUint64(&ci.atomicMessageCounter, 1)
msg, _, _, err := ci.H.WriteMessage(header, hsBytes) msg, _, _, err := ci.H.WriteMessage(header, hsBytes)
if err != nil { if err != nil {
@ -99,7 +99,11 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
} }
hostinfo, _ := f.handshakeManager.pendingHostMap.QueryReverseIndex(hs.Details.InitiatorIndex) hostinfo, _ := f.handshakeManager.pendingHostMap.QueryReverseIndex(hs.Details.InitiatorIndex)
if hostinfo != nil && bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) { if hostinfo != nil {
hostinfo.RLock()
defer hostinfo.RUnlock()
if bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) {
if msg, ok := hostinfo.HandshakePacket[2]; ok { if msg, ok := hostinfo.HandshakePacket[2]; ok {
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err := f.outside.WriteTo(msg, addr) err := f.outside.WriteTo(msg, addr)
@ -120,6 +124,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
WithField("packets", hostinfo.HandshakePacket). WithField("packets", hostinfo.HandshakePacket).
Error("Seen this handshake packet already but don't have a cached packet to return") Error("Seen this handshake packet already but don't have a cached packet to return")
} }
}
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert) remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
if err != nil { if err != nil {
@ -150,6 +155,9 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
return true return true
} }
hostinfo.Lock()
defer hostinfo.Unlock()
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@ -272,6 +280,8 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
if hostinfo == nil { if hostinfo == nil {
return true return true
} }
hostinfo.Lock()
defer hostinfo.Unlock()
if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) { if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).

View File

@ -103,6 +103,8 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
if err != nil { if err != nil {
return return
} }
hostinfo.Lock()
defer hostinfo.Unlock()
// If we haven't finished the handshake and we haven't hit max retries, query // If we haven't finished the handshake and we haven't hit max retries, query
// lighthouse and then send the handshake packet again. // lighthouse and then send the handshake packet again.

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net" "net"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
@ -35,6 +36,8 @@ type HostMap struct {
} }
type HostInfo struct { type HostInfo struct {
sync.RWMutex
remote *udpAddr remote *udpAddr
Remotes []*HostInfoDest Remotes []*HostInfoDest
promoteCounter uint32 promoteCounter uint32
@ -231,6 +234,9 @@ func (hm *HostMap) DeleteIndex(index uint32) {
hm.Lock() hm.Lock()
hostinfo, ok := hm.Indexes[index] hostinfo, ok := hm.Indexes[index]
if ok { if ok {
hostinfo.Lock()
defer hostinfo.Unlock()
delete(hm.Indexes, index) delete(hm.Indexes, index)
delete(hm.RemoteIndexes, hostinfo.remoteIndexId) delete(hm.RemoteIndexes, hostinfo.remoteIndexId)
@ -513,8 +519,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
return return
} }
i.promoteCounter++ if atomic.AddUint32(&i.promoteCounter, 1)&PromoteEvery == 0 {
if i.promoteCounter%PromoteEvery == 0 {
// return early if we are already on a preferred remote // return early if we are already on a preferred remote
rIP := udp2ip(i.remote) rIP := udp2ip(i.remote)
for _, l := range preferredRanges { for _, l := range preferredRanges {
@ -615,10 +620,12 @@ func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, pac
copy(tempPacket, packet) copy(tempPacket, packet)
//l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket) //l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket)
i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket}) i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket})
if l.Level >= logrus.DebugLevel {
i.logger(). i.logger().
WithField("length", len(i.packetStore)). WithField("length", len(i.packetStore)).
WithField("stored", true). WithField("stored", true).
Debugf("Packet store") Debugf("Packet store")
}
} else if l.Level >= logrus.DebugLevel { } else if l.Level >= logrus.DebugLevel {
i.logger(). i.logger().
@ -638,7 +645,7 @@ func (i *HostInfo) handshakeComplete() {
i.HandshakeComplete = true i.HandshakeComplete = true
//TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen. //TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen.
// Clamping it to 2 gets us out of the woods for now // Clamping it to 2 gets us out of the woods for now
*i.ConnectionState.messageCounter = 2 atomic.StoreUint64(&i.ConnectionState.atomicMessageCounter, 2)
i.logger().Debugf("Sending %d stored packets", len(i.packetStore)) i.logger().Debugf("Sending %d stored packets", len(i.packetStore))
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)

View File

@ -102,6 +102,10 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
} }
// If we have already created the handshake packet, we don't want to call the function at all. // If we have already created the handshake packet, we don't want to call the function at all.
if !hostinfo.HandshakeReady {
hostinfo.Lock()
defer hostinfo.Unlock()
if !hostinfo.HandshakeReady { if !hostinfo.HandshakeReady {
ixHandshakeStage0(f, vpnIp, hostinfo) ixHandshakeStage0(f, vpnIp, hostinfo)
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us. // FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
@ -116,6 +120,7 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
} }
} }
} }
}
return hostinfo return hostinfo
} }
@ -139,8 +144,8 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
return return
} }
f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0) messageCounter := f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
if f.lightHouse != nil && *hostInfo.ConnectionState.messageCounter%5000 == 0 { if f.lightHouse != nil && messageCounter%5000 == 0 {
f.lightHouse.Query(fp.RemoteIP, f) f.lightHouse.Query(fp.RemoteIP, f)
} }
} }
@ -223,7 +228,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
var err error var err error
//TODO: enable if we do more than 1 tun queue //TODO: enable if we do more than 1 tun queue
//ci.writeLock.Lock() //ci.writeLock.Lock()
c := atomic.AddUint64(ci.messageCounter, 1) c := atomic.AddUint64(&ci.atomicMessageCounter, 1)
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p) //l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
out = HeaderEncode(out, Version, uint8(t), uint8(st), hostinfo.remoteIndexId, c) out = HeaderEncode(out, Version, uint8(t), uint8(st), hostinfo.remoteIndexId, c)
@ -247,7 +252,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
if err != nil { if err != nil {
hostinfo.logger().WithError(err). hostinfo.logger().WithError(err).
WithField("udpAddr", remote).WithField("counter", c). WithField("udpAddr", remote).WithField("counter", c).
WithField("attemptedCounter", ci.messageCounter). WithField("attemptedCounter", c).
Error("Failed to encrypt outgoing packet") Error("Failed to encrypt outgoing packet")
return c return c
} }

View File

@ -134,11 +134,6 @@ func (f *Interface) run() {
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
// Launch n queues to read packets from udp
for i := 0; i < f.routines; i++ {
go f.listenOut(i)
}
// Prepare n tun queues // Prepare n tun queues
var reader io.ReadWriteCloser = f.inside var reader io.ReadWriteCloser = f.inside
for i := 0; i < f.routines; i++ { for i := 0; i < f.routines; i++ {
@ -155,6 +150,11 @@ func (f *Interface) run() {
l.Fatal(err) l.Fatal(err)
} }
// Launch n queues to read packets from udp
for i := 0; i < f.routines; i++ {
go f.listenOut(i)
}
// Launch n queues to read packets from tun dev // Launch n queues to read packets from tun dev
for i := 0; i < f.routines; i++ { for i := 0; i < f.routines; i++ {
go f.listenIn(f.readers[i], i) go f.listenIn(f.readers[i], i)

3
ssh.go
View File

@ -11,6 +11,7 @@ import (
"reflect" "reflect"
"runtime/pprof" "runtime/pprof"
"strings" "strings"
"sync/atomic"
"syscall" "syscall"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -353,7 +354,7 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
} }
if v.ConnectionState != nil { if v.ConnectionState != nil {
h["messageCounter"] = v.ConnectionState.messageCounter h["messageCounter"] = atomic.LoadUint64(&v.ConnectionState.atomicMessageCounter)
} }
d[x] = h d[x] = h