Ensure the Nebula device exists before attempting to bind to the Nebula IP (#375)

This commit is contained in:
brad-defined 2021-04-16 11:34:28 -04:00 committed by GitHub
parent ab08be1e3e
commit 17106f83a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 97 additions and 44 deletions

View File

@ -15,8 +15,11 @@ import (
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc // core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
type Control struct { type Control struct {
f *Interface f *Interface
l *logrus.Logger l *logrus.Logger
sshStart func()
statsStart func()
dnsStart func()
} }
type ControlHostInfo struct { type ControlHostInfo struct {
@ -32,6 +35,21 @@ type ControlHostInfo struct {
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
func (c *Control) Start() { func (c *Control) Start() {
// Activate the interface
c.f.activate()
// Call all the delayed funcs that waited patiently for the interface to be created.
if c.sshStart != nil {
go c.sshStart()
}
if c.statsStart != nil {
go c.statsStart()
}
if c.dnsStart != nil {
go c.dnsStart()
}
// Start reading packets.
c.f.run() c.f.run()
} }

View File

@ -109,7 +109,7 @@ func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
w.WriteMsg(m) w.WriteMsg(m)
} }
func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) { func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() {
dnsR = newDnsRecords(hostMap) dnsR = newDnsRecords(hostMap)
// attach request handler func // attach request handler func
@ -120,7 +120,10 @@ func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) {
c.RegisterReloadCallback(func(c *Config) { c.RegisterReloadCallback(func(c *Config) {
reloadDns(l, c) reloadDns(l, c)
}) })
startDns(l, c)
return func() {
startDns(l, c)
}
} }
func getDnsServerAddr(c *Config) string { func getDnsServerAddr(c *Config) string {

View File

@ -130,7 +130,10 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
return ifce, nil return ifce, nil
} }
func (f *Interface) run() { // activate creates the interface on the host. After the interface is created, any
// other services that want to bind listeners to its IP may do so successfully. However,
// the interface isn't going to process anything until run() is called.
func (f *Interface) activate() {
// actually turn on tun dev // actually turn on tun dev
addr, err := f.outside.LocalAddr() addr, err := f.outside.LocalAddr()
@ -159,7 +162,9 @@ func (f *Interface) run() {
if err := f.inside.Activate(); err != nil { if err := f.inside.Activate(); err != nil {
f.l.Fatal(err) f.l.Fatal(err)
} }
}
func (f *Interface) run() {
// Launch n queues to read packets from udp // Launch n queues to read packets from udp
for i := 0; i < f.routines; i++ { for i := 0; i < f.routines; i++ {
go f.listenOut(i) go f.listenOut(i)

10
main.go
View File

@ -75,8 +75,9 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
wireSSHReload(l, ssh, config) wireSSHReload(l, ssh, config)
var sshStart func()
if config.GetBool("sshd.enabled", false) { if config.GetBool("sshd.enabled", false) {
err = configSSH(l, ssh, config) sshStart, err = configSSH(l, ssh, config)
if err != nil { if err != nil {
return nil, NewContextualError("Error while configuring the sshd", nil, err) return nil, NewContextualError("Error while configuring the sshd", nil, err)
} }
@ -393,7 +394,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
go lightHouse.LhUpdateWorker(ifce) go lightHouse.LhUpdateWorker(ifce)
} }
err = startStats(l, config, buildVersion, configTest) statsStart, err := startStats(l, config, buildVersion, configTest)
if err != nil { if err != nil {
return nil, NewContextualError("Failed to start stats emitter", nil, err) return nil, NewContextualError("Failed to start stats emitter", nil, err)
} }
@ -408,10 +409,11 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce) attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host // Start DNS server last to allow using the nebula IP as lighthouse.dns.host
var dnsStart func()
if amLighthouse && serveDns { if amLighthouse && serveDns {
l.Debugln("Starting dns server") l.Debugln("Starting dns server")
go dnsMain(l, hostMap, config) dnsStart = dnsMain(l, hostMap, config)
} }
return &Control{ifce, l}, nil return &Control{ifce, l, sshStart, statsStart, dnsStart}, nil
} }

32
ssh.go
View File

@ -47,48 +47,55 @@ type sshCreateTunnelFlags struct {
func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) { func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) {
c.RegisterReloadCallback(func(c *Config) { c.RegisterReloadCallback(func(c *Config) {
if c.GetBool("sshd.enabled", false) { if c.GetBool("sshd.enabled", false) {
err := configSSH(l, ssh, c) sshRun, err := configSSH(l, ssh, c)
if err != nil { if err != nil {
l.WithError(err).Error("Failed to reconfigure the sshd") l.WithError(err).Error("Failed to reconfigure the sshd")
ssh.Stop() ssh.Stop()
} }
if sshRun != nil {
go sshRun()
}
} else { } else {
ssh.Stop() ssh.Stop()
} }
}) })
} }
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) error { // configSSH reads the ssh info out of the passed-in Config and
// updates the passed-in SSHServer. On success, it returns a function
// that callers may invoke to run the configured ssh server. On
// failure, it returns nil, error.
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) (func(), error) {
//TODO conntrack list //TODO conntrack list
//TODO print firewall rules or hash? //TODO print firewall rules or hash?
listen := c.GetString("sshd.listen", "") listen := c.GetString("sshd.listen", "")
if listen == "" { if listen == "" {
return fmt.Errorf("sshd.listen must be provided") return nil, fmt.Errorf("sshd.listen must be provided")
} }
_, port, err := net.SplitHostPort(listen) _, port, err := net.SplitHostPort(listen)
if err != nil { if err != nil {
return fmt.Errorf("invalid sshd.listen address: %s", err) return nil, fmt.Errorf("invalid sshd.listen address: %s", err)
} }
if port == "22" { if port == "22" {
return fmt.Errorf("sshd.listen can not use port 22") return nil, fmt.Errorf("sshd.listen can not use port 22")
} }
//TODO: no good way to reload this right now //TODO: no good way to reload this right now
hostKeyFile := c.GetString("sshd.host_key", "") hostKeyFile := c.GetString("sshd.host_key", "")
if hostKeyFile == "" { if hostKeyFile == "" {
return fmt.Errorf("sshd.host_key must be provided") return nil, fmt.Errorf("sshd.host_key must be provided")
} }
hostKeyBytes, err := ioutil.ReadFile(hostKeyFile) hostKeyBytes, err := ioutil.ReadFile(hostKeyFile)
if err != nil { if err != nil {
return fmt.Errorf("error while loading sshd.host_key file: %s", err) return nil, fmt.Errorf("error while loading sshd.host_key file: %s", err)
} }
err = ssh.SetHostKey(hostKeyBytes) err = ssh.SetHostKey(hostKeyBytes)
if err != nil { if err != nil {
return fmt.Errorf("error while adding sshd.host_key: %s", err) return nil, fmt.Errorf("error while adding sshd.host_key: %s", err)
} }
rawKeys := c.Get("sshd.authorized_users") rawKeys := c.Get("sshd.authorized_users")
@ -139,14 +146,19 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) error {
l.Info("no ssh users to authorize") l.Info("no ssh users to authorize")
} }
var runner func()
if c.GetBool("sshd.enabled", false) { if c.GetBool("sshd.enabled", false) {
ssh.Stop() ssh.Stop()
go ssh.Run(listen) runner = func() {
if err := ssh.Run(listen); err != nil {
l.WithField("err", err).Warn("Failed to run the SSH server")
}
}
} else { } else {
ssh.Stop() ssh.Stop()
} }
return nil return runner, nil
} }
func attachCommands(l *logrus.Logger, ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) { func attachCommands(l *logrus.Logger, ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) {

View File

@ -141,21 +141,22 @@ func (s *SSHServer) Run(addr string) error {
} }
func (s *SSHServer) Stop() { func (s *SSHServer) Stop() {
// Close the listener first, to prevent any new connections being accepted.
if s.listener != nil {
if err := s.listener.Close(); err != nil {
s.l.WithError(err).Warn("Failed to close the sshd listener")
} else {
s.l.Info("SSH server stopped listening")
}
}
// Force close all existing connections.
// TODO I believe this has a slight race if the listener has just accepted
// a connection. Can fix by moving this to the goroutine that's accepting.
for _, c := range s.conns { for _, c := range s.conns {
c.Close() c.Close()
} }
if s.listener == nil {
return
}
err := s.listener.Close()
if err != nil {
s.l.WithError(err).Warn("Failed to close the sshd listener")
return
}
s.l.Info("SSH server stopped listening")
return return
} }

View File

@ -17,24 +17,35 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest bool) error { // startStats initializes stats from config. On success, if any further work
// is needed to serve stats, it returns a func to handle that work. If no
// work is needed, it'll return nil. On failure, it returns nil, error.
func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest bool) (func(), error) {
mType := c.GetString("stats.type", "") mType := c.GetString("stats.type", "")
if mType == "" || mType == "none" { if mType == "" || mType == "none" {
return nil return nil, nil
} }
interval := c.GetDuration("stats.interval", 0) interval := c.GetDuration("stats.interval", 0)
if interval == 0 { if interval == 0 {
return fmt.Errorf("stats.interval was an invalid duration: %s", c.GetString("stats.interval", "")) return nil, fmt.Errorf("stats.interval was an invalid duration: %s", c.GetString("stats.interval", ""))
} }
var startFn func()
switch mType { switch mType {
case "graphite": case "graphite":
startGraphiteStats(l, interval, c, configTest) err := startGraphiteStats(l, interval, c, configTest)
if err != nil {
return nil, err
}
case "prometheus": case "prometheus":
startPrometheusStats(l, interval, c, buildVersion, configTest) var err error
startFn, err = startPrometheusStats(l, interval, c, buildVersion, configTest)
if err != nil {
return nil, err
}
default: default:
return fmt.Errorf("stats.type was not understood: %s", mType) return nil, fmt.Errorf("stats.type was not understood: %s", mType)
} }
metrics.RegisterDebugGCStats(metrics.DefaultRegistry) metrics.RegisterDebugGCStats(metrics.DefaultRegistry)
@ -43,7 +54,7 @@ func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest boo
go metrics.CaptureDebugGCStats(metrics.DefaultRegistry, interval) go metrics.CaptureDebugGCStats(metrics.DefaultRegistry, interval)
go metrics.CaptureRuntimeMemStats(metrics.DefaultRegistry, interval) go metrics.CaptureRuntimeMemStats(metrics.DefaultRegistry, interval)
return nil return startFn, nil
} }
func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error { func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error {
@ -59,25 +70,25 @@ func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest
return fmt.Errorf("error while setting up graphite sink: %s", err) return fmt.Errorf("error while setting up graphite sink: %s", err)
} }
l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr)
if !configTest { if !configTest {
l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr)
go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr) go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr)
} }
return nil return nil
} }
func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVersion string, configTest bool) error { func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVersion string, configTest bool) (func(), error) {
namespace := c.GetString("stats.namespace", "") namespace := c.GetString("stats.namespace", "")
subsystem := c.GetString("stats.subsystem", "") subsystem := c.GetString("stats.subsystem", "")
listen := c.GetString("stats.listen", "") listen := c.GetString("stats.listen", "")
if listen == "" { if listen == "" {
return fmt.Errorf("stats.listen should not be empty") return nil, fmt.Errorf("stats.listen should not be empty")
} }
path := c.GetString("stats.path", "") path := c.GetString("stats.path", "")
if path == "" { if path == "" {
return fmt.Errorf("stats.path should not be empty") return nil, fmt.Errorf("stats.path should not be empty")
} }
pr := prometheus.NewRegistry() pr := prometheus.NewRegistry()
@ -98,13 +109,14 @@ func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVer
pr.MustRegister(g) pr.MustRegister(g)
g.Set(1) g.Set(1)
var startFn func()
if !configTest { if !configTest {
go func() { startFn = func() {
l.Infof("Prometheus stats listening on %s at %s", listen, path) l.Infof("Prometheus stats listening on %s at %s", listen, path)
http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l})) http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l}))
log.Fatal(http.ListenAndServe(listen, nil)) log.Fatal(http.ListenAndServe(listen, nil))
}() }
} }
return nil return startFn, nil
} }