Don't use a global ca pool (#426)
This commit is contained in:
parent
4603b5b2dd
commit
883e09a392
2
cert.go
2
cert.go
|
@ -11,8 +11,6 @@ import (
|
|||
"github.com/slackhq/nebula/cert"
|
||||
)
|
||||
|
||||
var trustedCAs *cert.NebulaCAPool
|
||||
|
||||
type CertState struct {
|
||||
certificate *cert.NebulaCertificate
|
||||
rawCertificate []byte
|
||||
|
|
|
@ -96,7 +96,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
|||
return
|
||||
}
|
||||
|
||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
|
||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
||||
|
@ -318,7 +318,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||
return true
|
||||
}
|
||||
|
||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
|
||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
|
|
|
@ -52,7 +52,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
|
|||
ci.queueLock.Unlock()
|
||||
}
|
||||
|
||||
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs, localCache)
|
||||
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
|
||||
if dropReason == nil {
|
||||
mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
|
||||
if f.lightHouse != nil && mc%5000 == 0 {
|
||||
|
@ -140,7 +140,7 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
|
|||
}
|
||||
|
||||
// check if packet is in outbound fw rules
|
||||
dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs, nil)
|
||||
dropReason := f.firewall.Drop(p, *fp, false, hostInfo, f.caPool, nil)
|
||||
if dropReason != nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("fwPacket", fp).
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
)
|
||||
|
||||
const mtu = 9001
|
||||
|
@ -41,6 +42,7 @@ type InterfaceConfig struct {
|
|||
routines int
|
||||
MessageMetrics *MessageMetrics
|
||||
version string
|
||||
caPool *cert.NebulaCAPool
|
||||
|
||||
ConntrackCacheTimeout time.Duration
|
||||
l *logrus.Logger
|
||||
|
@ -63,6 +65,7 @@ type Interface struct {
|
|||
dropMulticast bool
|
||||
udpBatchSize int
|
||||
routines int
|
||||
caPool *cert.NebulaCAPool
|
||||
|
||||
// rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
|
||||
rebindCount int8
|
||||
|
@ -111,6 +114,7 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
|||
version: c.version,
|
||||
writers: make([]*udpConn, c.routines),
|
||||
readers: make([]io.ReadWriteCloser, c.routines),
|
||||
caPool: c.caPool,
|
||||
|
||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||
|
||||
|
@ -218,8 +222,8 @@ func (f *Interface) reloadCA(c *Config) {
|
|||
return
|
||||
}
|
||||
|
||||
trustedCAs = newCAs
|
||||
f.l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed")
|
||||
f.caPool = newCAs
|
||||
f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed")
|
||||
}
|
||||
|
||||
func (f *Interface) reloadCertKey(c *Config) {
|
||||
|
|
6
main.go
6
main.go
|
@ -42,13 +42,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
}
|
||||
})
|
||||
|
||||
// trustedCAs is currently a global, so loadCA operates on that global directly
|
||||
trustedCAs, err = loadCAFromConfig(l, config)
|
||||
caPool, err := loadCAFromConfig(l, config)
|
||||
if err != nil {
|
||||
//The errors coming out of loadCA are already nicely formatted
|
||||
return nil, NewContextualError("Failed to load ca from config", nil, err)
|
||||
}
|
||||
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
|
||||
l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
|
||||
|
||||
cs, err := NewCertStateFromConfig(config)
|
||||
if err != nil {
|
||||
|
@ -365,6 +364,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
routines: routines,
|
||||
MessageMetrics: messageMetrics,
|
||||
version: buildVersion,
|
||||
caPool: caPool,
|
||||
|
||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||
l: l,
|
||||
|
|
|
@ -280,7 +280,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||
return
|
||||
}
|
||||
|
||||
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs, localCache)
|
||||
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.caPool, localCache)
|
||||
if dropReason != nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
||||
|
@ -368,7 +368,7 @@ func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *N
|
|||
}
|
||||
*/
|
||||
|
||||
func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte) (*cert.NebulaCertificate, error) {
|
||||
func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.NebulaCAPool) (*cert.NebulaCertificate, error) {
|
||||
pk := h.PeerStatic()
|
||||
|
||||
if pk == nil {
|
||||
|
@ -397,7 +397,7 @@ func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte) (*ce
|
|||
}
|
||||
|
||||
c, _ := cert.UnmarshalNebulaCertificate(recombined)
|
||||
isValid, err := c.Verify(time.Now(), trustedCAs)
|
||||
isValid, err := c.Verify(time.Now(), caPool)
|
||||
if err != nil {
|
||||
return c, fmt.Errorf("certificate validation failed: %s", err)
|
||||
} else if !isValid {
|
||||
|
|
Loading…
Reference in New Issue