diff --git a/control.go b/control.go index d7a1c1f..c00a958 100644 --- a/control.go +++ b/control.go @@ -15,8 +15,11 @@ import ( // core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc type Control struct { - f *Interface - l *logrus.Logger + f *Interface + l *logrus.Logger + sshStart func() + statsStart func() + dnsStart func() } type ControlHostInfo struct { @@ -32,6 +35,21 @@ type ControlHostInfo struct { // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() func (c *Control) Start() { + // Activate the interface + c.f.activate() + + // Call all the delayed funcs that waited patiently for the interface to be created. + if c.sshStart != nil { + go c.sshStart() + } + if c.statsStart != nil { + go c.statsStart() + } + if c.dnsStart != nil { + go c.dnsStart() + } + + // Start reading packets. c.f.run() } diff --git a/dns_server.go b/dns_server.go index a4e1f13..881d06b 100644 --- a/dns_server.go +++ b/dns_server.go @@ -109,7 +109,7 @@ func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) { w.WriteMsg(m) } -func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) { +func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() { dnsR = newDnsRecords(hostMap) // attach request handler func @@ -120,7 +120,10 @@ func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) { c.RegisterReloadCallback(func(c *Config) { reloadDns(l, c) }) - startDns(l, c) + + return func() { + startDns(l, c) + } } func getDnsServerAddr(c *Config) string { diff --git a/interface.go b/interface.go index 3ded8db..6ad2d84 100644 --- a/interface.go +++ b/interface.go @@ -130,7 +130,10 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) { return ifce, nil } -func (f *Interface) run() { +// activate creates the interface on the host. After the interface is created, any +// other services that want to bind listeners to its IP may do so successfully. However, +// the interface isn't going to process anything until run() is called. +func (f *Interface) activate() { // actually turn on tun dev addr, err := f.outside.LocalAddr() @@ -159,7 +162,9 @@ func (f *Interface) run() { if err := f.inside.Activate(); err != nil { f.l.Fatal(err) } +} +func (f *Interface) run() { // Launch n queues to read packets from udp for i := 0; i < f.routines; i++ { go f.listenOut(i) diff --git a/main.go b/main.go index 6abe5b6..fae490f 100644 --- a/main.go +++ b/main.go @@ -75,8 +75,9 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) wireSSHReload(l, ssh, config) + var sshStart func() if config.GetBool("sshd.enabled", false) { - err = configSSH(l, ssh, config) + sshStart, err = configSSH(l, ssh, config) if err != nil { return nil, NewContextualError("Error while configuring the sshd", nil, err) } @@ -393,7 +394,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L go lightHouse.LhUpdateWorker(ifce) } - err = startStats(l, config, buildVersion, configTest) + statsStart, err := startStats(l, config, buildVersion, configTest) if err != nil { return nil, NewContextualError("Failed to start stats emitter", nil, err) } @@ -408,10 +409,11 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce) // Start DNS server last to allow using the nebula IP as lighthouse.dns.host + var dnsStart func() if amLighthouse && serveDns { l.Debugln("Starting dns server") - go dnsMain(l, hostMap, config) + dnsStart = dnsMain(l, hostMap, config) } - return &Control{ifce, l}, nil + return &Control{ifce, l, sshStart, statsStart, dnsStart}, nil } diff --git a/ssh.go b/ssh.go index c8c8e40..3714671 100644 --- a/ssh.go +++ b/ssh.go @@ -47,48 +47,55 @@ type sshCreateTunnelFlags struct { func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) { c.RegisterReloadCallback(func(c *Config) { if c.GetBool("sshd.enabled", false) { - err := configSSH(l, ssh, c) + sshRun, err := configSSH(l, ssh, c) if err != nil { l.WithError(err).Error("Failed to reconfigure the sshd") ssh.Stop() } + if sshRun != nil { + go sshRun() + } } else { ssh.Stop() } }) } -func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) error { +// configSSH reads the ssh info out of the passed-in Config and +// updates the passed-in SSHServer. On success, it returns a function +// that callers may invoke to run the configured ssh server. On +// failure, it returns nil, error. +func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) (func(), error) { //TODO conntrack list //TODO print firewall rules or hash? listen := c.GetString("sshd.listen", "") if listen == "" { - return fmt.Errorf("sshd.listen must be provided") + return nil, fmt.Errorf("sshd.listen must be provided") } _, port, err := net.SplitHostPort(listen) if err != nil { - return fmt.Errorf("invalid sshd.listen address: %s", err) + return nil, fmt.Errorf("invalid sshd.listen address: %s", err) } if port == "22" { - return fmt.Errorf("sshd.listen can not use port 22") + return nil, fmt.Errorf("sshd.listen can not use port 22") } //TODO: no good way to reload this right now hostKeyFile := c.GetString("sshd.host_key", "") if hostKeyFile == "" { - return fmt.Errorf("sshd.host_key must be provided") + return nil, fmt.Errorf("sshd.host_key must be provided") } hostKeyBytes, err := ioutil.ReadFile(hostKeyFile) if err != nil { - return fmt.Errorf("error while loading sshd.host_key file: %s", err) + return nil, fmt.Errorf("error while loading sshd.host_key file: %s", err) } err = ssh.SetHostKey(hostKeyBytes) if err != nil { - return fmt.Errorf("error while adding sshd.host_key: %s", err) + return nil, fmt.Errorf("error while adding sshd.host_key: %s", err) } rawKeys := c.Get("sshd.authorized_users") @@ -139,14 +146,19 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) error { l.Info("no ssh users to authorize") } + var runner func() if c.GetBool("sshd.enabled", false) { ssh.Stop() - go ssh.Run(listen) + runner = func() { + if err := ssh.Run(listen); err != nil { + l.WithField("err", err).Warn("Failed to run the SSH server") + } + } } else { ssh.Stop() } - return nil + return runner, nil } func attachCommands(l *logrus.Logger, ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) { diff --git a/sshd/server.go b/sshd/server.go index 7f6da3b..1ff32eb 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -141,21 +141,22 @@ func (s *SSHServer) Run(addr string) error { } func (s *SSHServer) Stop() { + // Close the listener first, to prevent any new connections being accepted. + if s.listener != nil { + if err := s.listener.Close(); err != nil { + s.l.WithError(err).Warn("Failed to close the sshd listener") + } else { + s.l.Info("SSH server stopped listening") + } + } + + // Force close all existing connections. + // TODO I believe this has a slight race if the listener has just accepted + // a connection. Can fix by moving this to the goroutine that's accepting. for _, c := range s.conns { c.Close() } - if s.listener == nil { - return - } - - err := s.listener.Close() - if err != nil { - s.l.WithError(err).Warn("Failed to close the sshd listener") - return - } - - s.l.Info("SSH server stopped listening") return } diff --git a/stats.go b/stats.go index 76be8ce..205a89a 100644 --- a/stats.go +++ b/stats.go @@ -17,24 +17,35 @@ import ( "github.com/sirupsen/logrus" ) -func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest bool) error { +// startStats initializes stats from config. On success, if any futher work +// is needed to serve stats, it returns a func to handle that work. If no +// work is needed, it'll return nil. On failure, it returns nil, error. +func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest bool) (func(), error) { mType := c.GetString("stats.type", "") if mType == "" || mType == "none" { - return nil + return nil, nil } interval := c.GetDuration("stats.interval", 0) if interval == 0 { - return fmt.Errorf("stats.interval was an invalid duration: %s", c.GetString("stats.interval", "")) + return nil, fmt.Errorf("stats.interval was an invalid duration: %s", c.GetString("stats.interval", "")) } + var startFn func() switch mType { case "graphite": - startGraphiteStats(l, interval, c, configTest) + err := startGraphiteStats(l, interval, c, configTest) + if err != nil { + return nil, err + } case "prometheus": - startPrometheusStats(l, interval, c, buildVersion, configTest) + var err error + startFn, err = startPrometheusStats(l, interval, c, buildVersion, configTest) + if err != nil { + return nil, err + } default: - return fmt.Errorf("stats.type was not understood: %s", mType) + return nil, fmt.Errorf("stats.type was not understood: %s", mType) } metrics.RegisterDebugGCStats(metrics.DefaultRegistry) @@ -43,7 +54,7 @@ func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest boo go metrics.CaptureDebugGCStats(metrics.DefaultRegistry, interval) go metrics.CaptureRuntimeMemStats(metrics.DefaultRegistry, interval) - return nil + return startFn, nil } func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error { @@ -59,25 +70,25 @@ func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest return fmt.Errorf("error while setting up graphite sink: %s", err) } - l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr) if !configTest { + l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr) go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr) } return nil } -func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVersion string, configTest bool) error { +func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVersion string, configTest bool) (func(), error) { namespace := c.GetString("stats.namespace", "") subsystem := c.GetString("stats.subsystem", "") listen := c.GetString("stats.listen", "") if listen == "" { - return fmt.Errorf("stats.listen should not be empty") + return nil, fmt.Errorf("stats.listen should not be empty") } path := c.GetString("stats.path", "") if path == "" { - return fmt.Errorf("stats.path should not be empty") + return nil, fmt.Errorf("stats.path should not be empty") } pr := prometheus.NewRegistry() @@ -98,13 +109,14 @@ func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVer pr.MustRegister(g) g.Set(1) + var startFn func() if !configTest { - go func() { + startFn = func() { l.Infof("Prometheus stats listening on %s at %s", listen, path) http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l})) log.Fatal(http.ListenAndServe(listen, nil)) - }() + } } - return nil + return startFn, nil }