Add a context object in nebula.Main to clean up on error (#550)
This commit is contained in:
parent
32cd9a93f1
commit
6ae8ba26f7
16
config.go
16
config.go
|
@ -1,6 +1,7 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
|
@ -114,14 +115,21 @@ func (c *Config) HasChanged(k string) bool {
|
|||
|
||||
// CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
|
||||
// original path provided to Load. The old settings are shallow copied for change detection after the reload.
|
||||
func (c *Config) CatchHUP() {
|
||||
func (c *Config) CatchHUP(ctx context.Context) {
|
||||
ch := make(chan os.Signal, 1)
|
||||
signal.Notify(ch, syscall.SIGHUP)
|
||||
|
||||
go func() {
|
||||
for range ch {
|
||||
c.l.Info("Caught HUP, reloading config")
|
||||
c.ReloadConfig()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
signal.Stop(ch)
|
||||
close(ch)
|
||||
return
|
||||
case <-ch:
|
||||
c.l.Info("Caught HUP, reloading config")
|
||||
c.ReloadConfig()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -32,7 +33,7 @@ type connectionManager struct {
|
|||
// I wanted to call one matLock
|
||||
}
|
||||
|
||||
func newConnectionManager(l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
|
||||
func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
|
||||
nc := &connectionManager{
|
||||
hostMap: intf.hostMap,
|
||||
in: make(map[uint32]struct{}),
|
||||
|
@ -50,7 +51,7 @@ func newConnectionManager(l *logrus.Logger, intf *Interface, checkInterval, pend
|
|||
pendingDeletionInterval: pendingDeletionInterval,
|
||||
l: l,
|
||||
}
|
||||
nc.Start()
|
||||
nc.Start(ctx)
|
||||
return nc
|
||||
}
|
||||
|
||||
|
@ -137,19 +138,26 @@ func (n *connectionManager) AddTrafficWatch(vpnIP uint32, seconds int) {
|
|||
n.TrafficTimer.Add(vpnIP, time.Second*time.Duration(seconds))
|
||||
}
|
||||
|
||||
func (n *connectionManager) Start() {
|
||||
go n.Run()
|
||||
func (n *connectionManager) Start(ctx context.Context) {
|
||||
go n.Run(ctx)
|
||||
}
|
||||
|
||||
func (n *connectionManager) Run() {
|
||||
clockSource := time.Tick(500 * time.Millisecond)
|
||||
func (n *connectionManager) Run(ctx context.Context) {
|
||||
clockSource := time.NewTicker(500 * time.Millisecond)
|
||||
defer clockSource.Stop()
|
||||
|
||||
p := []byte("")
|
||||
nb := make([]byte, 12, 12)
|
||||
out := make([]byte, mtu)
|
||||
|
||||
for now := range clockSource {
|
||||
n.HandleMonitorTick(now, p, nb, out)
|
||||
n.HandleDeletionTick(now)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case now := <-clockSource.C:
|
||||
n.HandleMonitorTick(now, p, nb, out)
|
||||
n.HandleDeletionTick(now)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"net"
|
||||
|
@ -45,7 +46,9 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||
now := time.Now()
|
||||
|
||||
// Create manager
|
||||
nc := newConnectionManager(l, ifce, 5, 10)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
nc := newConnectionManager(ctx, l, ifce, 5, 10)
|
||||
p := []byte("")
|
||||
nb := make([]byte, 12, 12)
|
||||
out := make([]byte, mtu)
|
||||
|
@ -112,7 +115,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||
now := time.Now()
|
||||
|
||||
// Create manager
|
||||
nc := newConnectionManager(l, ifce, 5, 10)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
nc := newConnectionManager(ctx, l, ifce, 5, 10)
|
||||
p := []byte("")
|
||||
nb := make([]byte, 12, 12)
|
||||
out := make([]byte, mtu)
|
||||
|
@ -220,7 +225,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||
}
|
||||
|
||||
// Create manager
|
||||
nc := newConnectionManager(l, ifce, 5, 10)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
nc := newConnectionManager(ctx, l, ifce, 5, 10)
|
||||
ifce.connectionManager = nc
|
||||
hostinfo := nc.hostMap.AddVpnIP(vpnIP)
|
||||
hostinfo.ConnectionState = &ConnectionState{
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
@ -17,6 +18,7 @@ import (
|
|||
type Control struct {
|
||||
f *Interface
|
||||
l *logrus.Logger
|
||||
cancel context.CancelFunc
|
||||
sshStart func()
|
||||
statsStart func()
|
||||
dnsStart func()
|
||||
|
@ -57,6 +59,7 @@ func (c *Control) Start() {
|
|||
func (c *Control) Stop() {
|
||||
//TODO: stop tun and udp routines, the lock on hostMap effectively does that though
|
||||
c.CloseAllTunnels(false)
|
||||
c.cancel()
|
||||
c.l.Info("Goodbye")
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package nebula
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
@ -66,14 +67,18 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
|
|||
}
|
||||
}
|
||||
|
||||
func (c *HandshakeManager) Run(f EncWriter) {
|
||||
clockSource := time.Tick(c.config.tryInterval)
|
||||
func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
|
||||
clockSource := time.NewTicker(c.config.tryInterval)
|
||||
defer clockSource.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case vpnIP := <-c.trigger:
|
||||
c.l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
|
||||
c.handleOutbound(vpnIP, f, true)
|
||||
case now := <-clockSource:
|
||||
case now := <-clockSource.C:
|
||||
c.NextOutboundHandshakeTimerTick(now, f)
|
||||
}
|
||||
}
|
||||
|
|
15
hostmap.go
15
hostmap.go
|
@ -1,6 +1,7 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -369,7 +370,7 @@ func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList {
|
|||
}
|
||||
|
||||
// Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them
|
||||
func (hm *HostMap) Punchy(conn *udpConn) {
|
||||
func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) {
|
||||
var metricsTxPunchy metrics.Counter
|
||||
if hm.metricsEnabled {
|
||||
metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
|
||||
|
@ -379,6 +380,10 @@ func (hm *HostMap) Punchy(conn *udpConn) {
|
|||
|
||||
var remotes []*RemoteList
|
||||
b := []byte{1}
|
||||
|
||||
clockSource := time.NewTicker(time.Second * 10)
|
||||
defer clockSource.Stop()
|
||||
|
||||
for {
|
||||
remotes = hm.punchList(remotes[:0])
|
||||
for _, rl := range remotes {
|
||||
|
@ -388,7 +393,13 @@ func (hm *HostMap) Punchy(conn *udpConn) {
|
|||
conn.WriteTo(b, addr)
|
||||
}
|
||||
}
|
||||
time.Sleep(time.Second * 10)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-clockSource.C:
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
22
interface.go
22
interface.go
|
@ -1,6 +1,7 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
@ -86,7 +87,7 @@ type Interface struct {
|
|||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
||||
func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
if c.Outside == nil {
|
||||
return nil, errors.New("no outside connection")
|
||||
}
|
||||
|
@ -135,7 +136,7 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
|||
l: c.l,
|
||||
}
|
||||
|
||||
ifce.connectionManager = newConnectionManager(c.l, ifce, c.checkInterval, c.pendingDeletionInterval)
|
||||
ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval)
|
||||
|
||||
return ifce, nil
|
||||
}
|
||||
|
@ -302,15 +303,20 @@ func (f *Interface) reloadFirewall(c *Config) {
|
|||
Info("New firewall has been installed")
|
||||
}
|
||||
|
||||
func (f *Interface) emitStats(i time.Duration) {
|
||||
func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
||||
ticker := time.NewTicker(i)
|
||||
defer ticker.Stop()
|
||||
|
||||
udpStats := NewUDPStatsEmitter(f.writers)
|
||||
|
||||
for range ticker.C {
|
||||
f.firewall.EmitStats()
|
||||
f.handshakeManager.EmitStats()
|
||||
|
||||
udpStats()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
f.firewall.EmitStats()
|
||||
f.handshakeManager.EmitStats()
|
||||
udpStats()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -328,14 +329,23 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udpAddr {
|
|||
return NewUDPAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
|
||||
}
|
||||
|
||||
func (lh *LightHouse) LhUpdateWorker(f EncWriter) {
|
||||
func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
|
||||
if lh.amLighthouse || lh.interval == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
clockSource := time.NewTicker(time.Second * time.Duration(lh.interval))
|
||||
defer clockSource.Stop()
|
||||
|
||||
for {
|
||||
lh.SendUpdate(f)
|
||||
time.Sleep(time.Second * time.Duration(lh.interval))
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-clockSource.C:
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
34
main.go
34
main.go
|
@ -1,6 +1,7 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -13,7 +14,16 @@ import (
|
|||
|
||||
type m map[string]interface{}
|
||||
|
||||
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) {
|
||||
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
|
||||
defer func() {
|
||||
if reterr != nil {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
l := logger
|
||||
l.Formatter = &logrus.TextFormatter{
|
||||
FullTimestamp: true,
|
||||
|
@ -126,7 +136,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
|
||||
var tun Inside
|
||||
if !configTest {
|
||||
config.CatchHUP()
|
||||
config.CatchHUP(ctx)
|
||||
|
||||
switch {
|
||||
case config.GetBool("tun.disabled", false):
|
||||
|
@ -159,6 +169,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
}
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if reterr != nil {
|
||||
tun.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// set up our UDP listener
|
||||
udpConns := make([]*udpConn, routines)
|
||||
port := config.GetInt("listen.port", 0)
|
||||
|
@ -236,7 +252,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
punchy := NewPunchyFromConfig(config)
|
||||
if punchy.Punch && !configTest {
|
||||
l.Info("UDP hole punching enabled")
|
||||
go hostMap.Punchy(udpConns[0])
|
||||
go hostMap.Punchy(ctx, udpConns[0])
|
||||
}
|
||||
|
||||
amLighthouse := config.GetBool("lighthouse.am_lighthouse", false)
|
||||
|
@ -388,7 +404,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
|
||||
var ifce *Interface
|
||||
if !configTest {
|
||||
ifce, err = NewInterface(ifConfig)
|
||||
ifce, err = NewInterface(ctx, ifConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize interface: %s", err)
|
||||
}
|
||||
|
@ -399,10 +415,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
|
||||
ifce.RegisterConfigChangeCallbacks(config)
|
||||
|
||||
go handshakeManager.Run(ifce)
|
||||
go lightHouse.LhUpdateWorker(ifce)
|
||||
go handshakeManager.Run(ctx, ifce)
|
||||
go lightHouse.LhUpdateWorker(ctx, ifce)
|
||||
}
|
||||
|
||||
// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
|
||||
// a context so that they can exit when the context is Done.
|
||||
statsStart, err := startStats(l, config, buildVersion, configTest)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Failed to start stats emitter", nil, err)
|
||||
|
@ -413,7 +431,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
}
|
||||
|
||||
//TODO: check if we _should_ be emitting stats
|
||||
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
|
||||
go ifce.emitStats(ctx, config.GetDuration("stats.interval", time.Second*10))
|
||||
|
||||
attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
|
||||
|
||||
|
@ -424,5 +442,5 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
dnsStart = dnsMain(l, hostMap, config)
|
||||
}
|
||||
|
||||
return &Control{ifce, l, sshStart, statsStart, dnsStart}, nil
|
||||
return &Control{ifce, l, cancel, sshStart, statsStart, dnsStart}, nil
|
||||
}
|
||||
|
|
4
stats.go
4
stats.go
|
@ -93,7 +93,9 @@ func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVer
|
|||
|
||||
pr := prometheus.NewRegistry()
|
||||
pClient := mp.NewPrometheusProvider(metrics.DefaultRegistry, namespace, subsystem, pr, i)
|
||||
go pClient.UpdatePrometheusMetrics()
|
||||
if !configTest {
|
||||
go pClient.UpdatePrometheusMetrics()
|
||||
}
|
||||
|
||||
// Export our version information as labels on a static gauge
|
||||
g := prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
|
|
|
@ -41,6 +41,13 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in
|
|||
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
||||
}
|
||||
|
||||
func (c *Tun) Close() error {
|
||||
if c.Interface != nil {
|
||||
return c.Interface.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Tun) Activate() error {
|
||||
var err error
|
||||
c.Interface, err = water.New(water.Config{
|
||||
|
|
|
@ -28,6 +28,13 @@ type Tun struct {
|
|||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func (c *Tun) Close() error {
|
||||
if c.ReadWriteCloser != nil {
|
||||
return c.ReadWriteCloser.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
||||
}
|
||||
|
|
|
@ -24,6 +24,13 @@ type Tun struct {
|
|||
*water.Interface
|
||||
}
|
||||
|
||||
func (c *Tun) Close() error {
|
||||
if c.Interface != nil {
|
||||
return c.Interface.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue