Teardown tunnel automatically if peer's certificate expired (#370)
This commit is contained in:
parent
e8b08e49e6
commit
32e2619323
|
@ -166,7 +166,23 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
|
||||||
// Check for traffic coming back in from this host.
|
// Check for traffic coming back in from this host.
|
||||||
traf := n.CheckIn(vpnIP)
|
traf := n.CheckIn(vpnIP)
|
||||||
|
|
||||||
// If we saw incoming packets from this ip, just return
|
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
|
||||||
|
if err != nil {
|
||||||
|
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
|
||||||
|
|
||||||
|
if !n.intf.disconnectInvalid {
|
||||||
|
n.ClearIP(vpnIP)
|
||||||
|
n.ClearPendingDeletion(vpnIP)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if n.handleInvalidCertificate(now, vpnIP, hostinfo) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we saw an incoming packets from this ip and peer's certificate is not
|
||||||
|
// expired, just ignore.
|
||||||
if traf {
|
if traf {
|
||||||
if n.l.Level >= logrus.DebugLevel {
|
if n.l.Level >= logrus.DebugLevel {
|
||||||
n.l.WithField("vpnIp", IntIp(vpnIP)).
|
n.l.WithField("vpnIp", IntIp(vpnIP)).
|
||||||
|
@ -178,15 +194,6 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we didn't we may need to probe or destroy the conn
|
|
||||||
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
|
|
||||||
if err != nil {
|
|
||||||
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
|
|
||||||
n.ClearIP(vpnIP)
|
|
||||||
n.ClearPendingDeletion(vpnIP)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
hostinfo.logger(n.l).
|
hostinfo.logger(n.l).
|
||||||
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
|
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
|
||||||
Debug("Tunnel status")
|
Debug("Tunnel status")
|
||||||
|
@ -213,22 +220,31 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
|
||||||
|
|
||||||
vpnIP := ep.(uint32)
|
vpnIP := ep.(uint32)
|
||||||
|
|
||||||
// If we saw incoming packets from this ip, just return
|
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
|
||||||
|
if err != nil {
|
||||||
|
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
|
||||||
|
|
||||||
|
if !n.intf.disconnectInvalid {
|
||||||
|
n.ClearIP(vpnIP)
|
||||||
|
n.ClearPendingDeletion(vpnIP)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if n.handleInvalidCertificate(now, vpnIP, hostinfo) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we saw an incoming packets from this ip and peer's certificate is not
|
||||||
|
// expired, just ignore.
|
||||||
traf := n.CheckIn(vpnIP)
|
traf := n.CheckIn(vpnIP)
|
||||||
if traf {
|
if traf {
|
||||||
n.l.WithField("vpnIp", IntIp(vpnIP)).
|
n.l.WithField("vpnIp", IntIp(vpnIP)).
|
||||||
WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
|
WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
|
||||||
Debug("Tunnel status")
|
Debug("Tunnel status")
|
||||||
n.ClearIP(vpnIP)
|
|
||||||
n.ClearPendingDeletion(vpnIP)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
|
|
||||||
if err != nil {
|
|
||||||
n.ClearIP(vpnIP)
|
n.ClearIP(vpnIP)
|
||||||
n.ClearPendingDeletion(vpnIP)
|
n.ClearPendingDeletion(vpnIP)
|
||||||
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -256,3 +272,34 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
|
||||||
|
func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32, hostinfo *HostInfo) bool {
|
||||||
|
if !n.intf.disconnectInvalid {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteCert := hostinfo.GetCert()
|
||||||
|
if remoteCert == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
valid, err := remoteCert.Verify(now, n.intf.caPool)
|
||||||
|
if valid {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
fingerprint, _ := remoteCert.Sha256Sum()
|
||||||
|
n.l.WithField("vpnIp", IntIp(vpnIP)).WithError(err).
|
||||||
|
WithField("certName", remoteCert.Details.Name).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
|
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||||
|
|
||||||
|
// Inform the remote and close the tunnel locally
|
||||||
|
n.intf.sendCloseTunnel(hostinfo)
|
||||||
|
n.intf.closeTunnel(hostinfo, false)
|
||||||
|
|
||||||
|
n.ClearIP(vpnIP)
|
||||||
|
n.ClearPendingDeletion(vpnIP)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -148,3 +150,96 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||||
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
|
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if we can disconnect the peer.
|
||||||
|
// Validate if the peer's certificate is invalid (expired, etc.)
|
||||||
|
// Disconnect only if disconnectInvalid: true is set.
|
||||||
|
func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
l := NewTestLogger()
|
||||||
|
ipNet := net.IPNet{
|
||||||
|
IP: net.IPv4(172, 1, 1, 2),
|
||||||
|
Mask: net.IPMask{255, 255, 255, 0},
|
||||||
|
}
|
||||||
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
|
preferredRanges := []*net.IPNet{localrange}
|
||||||
|
hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
|
||||||
|
|
||||||
|
// Generate keys for CA and peer's cert.
|
||||||
|
pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
|
caCert := cert.NebulaCertificate{
|
||||||
|
Details: cert.NebulaCertificateDetails{
|
||||||
|
Name: "ca",
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(1 * time.Hour),
|
||||||
|
IsCA: true,
|
||||||
|
PublicKey: pubCA,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
caCert.Sign(privCA)
|
||||||
|
ncp := &cert.NebulaCAPool{
|
||||||
|
CAs: cert.NewCAPool().CAs,
|
||||||
|
}
|
||||||
|
ncp.CAs["ca"] = &caCert
|
||||||
|
|
||||||
|
pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
|
peerCert := cert.NebulaCertificate{
|
||||||
|
Details: cert.NebulaCertificateDetails{
|
||||||
|
Name: "host",
|
||||||
|
Ips: []*net.IPNet{&ipNet},
|
||||||
|
Subnets: []*net.IPNet{},
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(60 * time.Second),
|
||||||
|
PublicKey: pubCrt,
|
||||||
|
IsCA: false,
|
||||||
|
Issuer: "ca",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
peerCert.Sign(privCA)
|
||||||
|
|
||||||
|
cs := &CertState{
|
||||||
|
rawCertificate: []byte{},
|
||||||
|
privateKey: []byte{},
|
||||||
|
certificate: &cert.NebulaCertificate{},
|
||||||
|
rawCertificateNoKey: []byte{},
|
||||||
|
}
|
||||||
|
|
||||||
|
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
|
||||||
|
ifce := &Interface{
|
||||||
|
hostMap: hostMap,
|
||||||
|
inside: &Tun{},
|
||||||
|
outside: &udpConn{},
|
||||||
|
certState: cs,
|
||||||
|
firewall: &Firewall{},
|
||||||
|
lightHouse: lh,
|
||||||
|
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
|
||||||
|
l: l,
|
||||||
|
disconnectInvalid: true,
|
||||||
|
caPool: ncp,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create manager
|
||||||
|
nc := newConnectionManager(l, ifce, 5, 10)
|
||||||
|
ifce.connectionManager = nc
|
||||||
|
hostinfo := nc.hostMap.AddVpnIP(vpnIP)
|
||||||
|
hostinfo.ConnectionState = &ConnectionState{
|
||||||
|
certState: cs,
|
||||||
|
peerCert: &peerCert,
|
||||||
|
H: &noise.HandshakeState{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move ahead 45s.
|
||||||
|
// Check if to disconnect with invalid certificate.
|
||||||
|
// Should be alive.
|
||||||
|
nextTick := now.Add(45 * time.Second)
|
||||||
|
destroyed := nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo)
|
||||||
|
assert.False(t, destroyed)
|
||||||
|
|
||||||
|
// Move ahead 61s.
|
||||||
|
// Check if to disconnect with invalid certificate.
|
||||||
|
// Should be disconnected.
|
||||||
|
nextTick = now.Add(61 * time.Second)
|
||||||
|
destroyed = nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo)
|
||||||
|
assert.True(t, destroyed)
|
||||||
|
}
|
||||||
|
|
|
@ -10,6 +10,8 @@ pki:
|
||||||
# blocklist is a list of certificate fingerprints that we will refuse to talk to
|
# blocklist is a list of certificate fingerprints that we will refuse to talk to
|
||||||
#blocklist:
|
#blocklist:
|
||||||
# - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72
|
# - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72
|
||||||
|
# disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
|
||||||
|
#disconnect_invalid: false
|
||||||
|
|
||||||
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
|
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
|
||||||
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
|
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
|
||||||
|
|
|
@ -43,6 +43,7 @@ type InterfaceConfig struct {
|
||||||
MessageMetrics *MessageMetrics
|
MessageMetrics *MessageMetrics
|
||||||
version string
|
version string
|
||||||
caPool *cert.NebulaCAPool
|
caPool *cert.NebulaCAPool
|
||||||
|
disconnectInvalid bool
|
||||||
|
|
||||||
ConntrackCacheTimeout time.Duration
|
ConntrackCacheTimeout time.Duration
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
@ -67,6 +68,7 @@ type Interface struct {
|
||||||
udpBatchSize int
|
udpBatchSize int
|
||||||
routines int
|
routines int
|
||||||
caPool *cert.NebulaCAPool
|
caPool *cert.NebulaCAPool
|
||||||
|
disconnectInvalid bool
|
||||||
|
|
||||||
// rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
|
// rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
|
||||||
rebindCount int8
|
rebindCount int8
|
||||||
|
@ -118,6 +120,7 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
||||||
writers: make([]*udpConn, c.routines),
|
writers: make([]*udpConn, c.routines),
|
||||||
readers: make([]io.ReadWriteCloser, c.routines),
|
readers: make([]io.ReadWriteCloser, c.routines),
|
||||||
caPool: c.caPool,
|
caPool: c.caPool,
|
||||||
|
disconnectInvalid: c.disconnectInvalid,
|
||||||
myVpnIp: ip2int(c.certState.certificate.Details.Ips[0].IP),
|
myVpnIp: ip2int(c.certState.certificate.Details.Ips[0].IP),
|
||||||
|
|
||||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||||
|
|
1
main.go
1
main.go
|
@ -371,6 +371,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
MessageMetrics: messageMetrics,
|
MessageMetrics: messageMetrics,
|
||||||
version: buildVersion,
|
version: buildVersion,
|
||||||
caPool: caPool,
|
caPool: caPool,
|
||||||
|
disconnectInvalid: config.GetBool("pki.disconnect_invalid", false),
|
||||||
|
|
||||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||||
l: l,
|
l: l,
|
||||||
|
|
Loading…
Reference in New Issue