Don't use a global ca pool (#426)

This commit is contained in:
Nathan Brown 2021-03-29 12:10:19 -05:00 committed by GitHub
parent 4603b5b2dd
commit 883e09a392
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 16 additions and 14 deletions

View File

@ -11,8 +11,6 @@ import (
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
) )
var trustedCAs *cert.NebulaCAPool
type CertState struct { type CertState struct {
certificate *cert.NebulaCertificate certificate *cert.NebulaCertificate
rawCertificate []byte rawCertificate []byte

View File

@ -96,7 +96,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
return return
} }
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert) remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
if err != nil { if err != nil {
f.l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
@ -318,7 +318,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
return true return true
} }
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert) remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).

View File

@ -52,7 +52,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
ci.queueLock.Unlock() ci.queueLock.Unlock()
} }
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs, localCache) dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
if dropReason == nil { if dropReason == nil {
mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q) mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
if f.lightHouse != nil && mc%5000 == 0 { if f.lightHouse != nil && mc%5000 == 0 {
@ -140,7 +140,7 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
} }
// check if packet is in outbound fw rules // check if packet is in outbound fw rules
dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs, nil) dropReason := f.firewall.Drop(p, *fp, false, hostInfo, f.caPool, nil)
if dropReason != nil { if dropReason != nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("fwPacket", fp). f.l.WithField("fwPacket", fp).

View File

@ -10,6 +10,7 @@ import (
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
) )
const mtu = 9001 const mtu = 9001
@ -41,6 +42,7 @@ type InterfaceConfig struct {
routines int routines int
MessageMetrics *MessageMetrics MessageMetrics *MessageMetrics
version string version string
caPool *cert.NebulaCAPool
ConntrackCacheTimeout time.Duration ConntrackCacheTimeout time.Duration
l *logrus.Logger l *logrus.Logger
@ -63,6 +65,7 @@ type Interface struct {
dropMulticast bool dropMulticast bool
udpBatchSize int udpBatchSize int
routines int routines int
caPool *cert.NebulaCAPool
// rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse // rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
rebindCount int8 rebindCount int8
@ -111,6 +114,7 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
version: c.version, version: c.version,
writers: make([]*udpConn, c.routines), writers: make([]*udpConn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines), readers: make([]io.ReadWriteCloser, c.routines),
caPool: c.caPool,
conntrackCacheTimeout: c.ConntrackCacheTimeout, conntrackCacheTimeout: c.ConntrackCacheTimeout,
@ -218,8 +222,8 @@ func (f *Interface) reloadCA(c *Config) {
return return
} }
trustedCAs = newCAs f.caPool = newCAs
f.l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed") f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed")
} }
func (f *Interface) reloadCertKey(c *Config) { func (f *Interface) reloadCertKey(c *Config) {

View File

@ -42,13 +42,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
} }
}) })
// trustedCAs is currently a global, so loadCA operates on that global directly caPool, err := loadCAFromConfig(l, config)
trustedCAs, err = loadCAFromConfig(l, config)
if err != nil { if err != nil {
//The errors coming out of loadCA are already nicely formatted //The errors coming out of loadCA are already nicely formatted
return nil, NewContextualError("Failed to load ca from config", nil, err) return nil, NewContextualError("Failed to load ca from config", nil, err)
} }
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints") l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
cs, err := NewCertStateFromConfig(config) cs, err := NewCertStateFromConfig(config)
if err != nil { if err != nil {
@ -365,6 +364,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
routines: routines, routines: routines,
MessageMetrics: messageMetrics, MessageMetrics: messageMetrics,
version: buildVersion, version: buildVersion,
caPool: caPool,
ConntrackCacheTimeout: conntrackCacheTimeout, ConntrackCacheTimeout: conntrackCacheTimeout,
l: l, l: l,

View File

@ -280,7 +280,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return return
} }
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs, localCache) dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.caPool, localCache)
if dropReason != nil { if dropReason != nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket). hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
@ -368,7 +368,7 @@ func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *N
} }
*/ */
func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte) (*cert.NebulaCertificate, error) { func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.NebulaCAPool) (*cert.NebulaCertificate, error) {
pk := h.PeerStatic() pk := h.PeerStatic()
if pk == nil { if pk == nil {
@ -397,7 +397,7 @@ func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte) (*ce
} }
c, _ := cert.UnmarshalNebulaCertificate(recombined) c, _ := cert.UnmarshalNebulaCertificate(recombined)
isValid, err := c.Verify(time.Now(), trustedCAs) isValid, err := c.Verify(time.Now(), caPool)
if err != nil { if err != nil {
return c, fmt.Errorf("certificate validation failed: %s", err) return c, fmt.Errorf("certificate validation failed: %s", err)
} else if !isValid { } else if !isValid {