Don't use a global ca pool (#426)
This commit is contained in:
parent
4603b5b2dd
commit
883e09a392
2
cert.go
2
cert.go
|
@ -11,8 +11,6 @@ import (
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
)
|
)
|
||||||
|
|
||||||
var trustedCAs *cert.NebulaCAPool
|
|
||||||
|
|
||||||
type CertState struct {
|
type CertState struct {
|
||||||
certificate *cert.NebulaCertificate
|
certificate *cert.NebulaCertificate
|
||||||
rawCertificate []byte
|
rawCertificate []byte
|
||||||
|
|
|
@ -96,7 +96,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
|
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
||||||
|
@ -318,7 +318,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
|
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||||
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
|
|
|
@ -52,7 +52,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
|
||||||
ci.queueLock.Unlock()
|
ci.queueLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs, localCache)
|
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
|
||||||
if dropReason == nil {
|
if dropReason == nil {
|
||||||
mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
|
mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
|
||||||
if f.lightHouse != nil && mc%5000 == 0 {
|
if f.lightHouse != nil && mc%5000 == 0 {
|
||||||
|
@ -140,7 +140,7 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if packet is in outbound fw rules
|
// check if packet is in outbound fw rules
|
||||||
dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs, nil)
|
dropReason := f.firewall.Drop(p, *fp, false, hostInfo, f.caPool, nil)
|
||||||
if dropReason != nil {
|
if dropReason != nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithField("fwPacket", fp).
|
f.l.WithField("fwPacket", fp).
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
)
|
)
|
||||||
|
|
||||||
const mtu = 9001
|
const mtu = 9001
|
||||||
|
@ -41,6 +42,7 @@ type InterfaceConfig struct {
|
||||||
routines int
|
routines int
|
||||||
MessageMetrics *MessageMetrics
|
MessageMetrics *MessageMetrics
|
||||||
version string
|
version string
|
||||||
|
caPool *cert.NebulaCAPool
|
||||||
|
|
||||||
ConntrackCacheTimeout time.Duration
|
ConntrackCacheTimeout time.Duration
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
@ -63,6 +65,7 @@ type Interface struct {
|
||||||
dropMulticast bool
|
dropMulticast bool
|
||||||
udpBatchSize int
|
udpBatchSize int
|
||||||
routines int
|
routines int
|
||||||
|
caPool *cert.NebulaCAPool
|
||||||
|
|
||||||
// rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
|
// rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
|
||||||
rebindCount int8
|
rebindCount int8
|
||||||
|
@ -111,6 +114,7 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
||||||
version: c.version,
|
version: c.version,
|
||||||
writers: make([]*udpConn, c.routines),
|
writers: make([]*udpConn, c.routines),
|
||||||
readers: make([]io.ReadWriteCloser, c.routines),
|
readers: make([]io.ReadWriteCloser, c.routines),
|
||||||
|
caPool: c.caPool,
|
||||||
|
|
||||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||||
|
|
||||||
|
@ -218,8 +222,8 @@ func (f *Interface) reloadCA(c *Config) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
trustedCAs = newCAs
|
f.caPool = newCAs
|
||||||
f.l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed")
|
f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) reloadCertKey(c *Config) {
|
func (f *Interface) reloadCertKey(c *Config) {
|
||||||
|
|
6
main.go
6
main.go
|
@ -42,13 +42,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// trustedCAs is currently a global, so loadCA operates on that global directly
|
caPool, err := loadCAFromConfig(l, config)
|
||||||
trustedCAs, err = loadCAFromConfig(l, config)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//The errors coming out of loadCA are already nicely formatted
|
//The errors coming out of loadCA are already nicely formatted
|
||||||
return nil, NewContextualError("Failed to load ca from config", nil, err)
|
return nil, NewContextualError("Failed to load ca from config", nil, err)
|
||||||
}
|
}
|
||||||
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
|
l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
|
||||||
|
|
||||||
cs, err := NewCertStateFromConfig(config)
|
cs, err := NewCertStateFromConfig(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -365,6 +364,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
routines: routines,
|
routines: routines,
|
||||||
MessageMetrics: messageMetrics,
|
MessageMetrics: messageMetrics,
|
||||||
version: buildVersion,
|
version: buildVersion,
|
||||||
|
caPool: caPool,
|
||||||
|
|
||||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||||
l: l,
|
l: l,
|
||||||
|
|
|
@ -280,7 +280,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs, localCache)
|
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.caPool, localCache)
|
||||||
if dropReason != nil {
|
if dropReason != nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
||||||
|
@ -368,7 +368,7 @@ func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *N
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte) (*cert.NebulaCertificate, error) {
|
func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.NebulaCAPool) (*cert.NebulaCertificate, error) {
|
||||||
pk := h.PeerStatic()
|
pk := h.PeerStatic()
|
||||||
|
|
||||||
if pk == nil {
|
if pk == nil {
|
||||||
|
@ -397,7 +397,7 @@ func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte) (*ce
|
||||||
}
|
}
|
||||||
|
|
||||||
c, _ := cert.UnmarshalNebulaCertificate(recombined)
|
c, _ := cert.UnmarshalNebulaCertificate(recombined)
|
||||||
isValid, err := c.Verify(time.Now(), trustedCAs)
|
isValid, err := c.Verify(time.Now(), caPool)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c, fmt.Errorf("certificate validation failed: %s", err)
|
return c, fmt.Errorf("certificate validation failed: %s", err)
|
||||||
} else if !isValid {
|
} else if !isValid {
|
||||||
|
|
Loading…
Reference in New Issue