Be more like a library to support mobile (#247)
This commit is contained in:
		| @@ -212,10 +212,10 @@ func TestBitsLostCounter(t *testing.T) { | ||||
| func BenchmarkBits(b *testing.B) { | ||||
| 	z := NewBits(10) | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		for i, _ := range z.bits { | ||||
| 		for i := range z.bits { | ||||
| 			z.bits[i] = true | ||||
| 		} | ||||
| 		for i, _ := range z.bits { | ||||
| 		for i := range z.bits { | ||||
| 			z.bits[i] = false | ||||
| 		} | ||||
|  | ||||
|   | ||||
| @@ -3,9 +3,9 @@ package main | ||||
| import ( | ||||
| 	"flag" | ||||
| 	"fmt" | ||||
| 	"os" | ||||
|  | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"github.com/slackhq/nebula" | ||||
| 	"os" | ||||
| ) | ||||
|  | ||||
| // A version string that can be set with | ||||
| @@ -45,5 +45,25 @@ func main() { | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
|  | ||||
| 	nebula.Main(*configPath, *configTest, Build) | ||||
| 	config := nebula.NewConfig() | ||||
| 	err := config.Load(*configPath) | ||||
| 	if err != nil { | ||||
| 		fmt.Printf("failed to load config: %s", err) | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
|  | ||||
| 	l := logrus.New() | ||||
| 	l.Out = os.Stdout | ||||
| 	err = nebula.Main(config, *configTest, true, Build, l, nil, nil) | ||||
|  | ||||
| 	switch v := err.(type) { | ||||
| 	case nebula.ContextualError: | ||||
| 		v.Log(l) | ||||
| 		os.Exit(1) | ||||
| 	case error: | ||||
| 		l.WithError(err).Error("Failed to start") | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
|  | ||||
| 	os.Exit(0) | ||||
| } | ||||
|   | ||||
| @@ -1,6 +1,8 @@ | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| @@ -27,8 +29,15 @@ func (p *program) Start(s service.Service) error { | ||||
| } | ||||
|  | ||||
| func (p *program) run() error { | ||||
| 	nebula.Main(*p.configPath, *p.configTest, Build) | ||||
| 	return nil | ||||
| 	config := nebula.NewConfig() | ||||
| 	err := config.Load(*p.configPath) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to load config: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	l := logrus.New() | ||||
| 	l.Out = os.Stdout | ||||
| 	return nebula.Main(config, *p.configTest, true, Build, l, nil, nil) | ||||
| } | ||||
|  | ||||
| func (p *program) Stop(s service.Service) error { | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package main | ||||
| import ( | ||||
| 	"flag" | ||||
| 	"fmt" | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"os" | ||||
|  | ||||
| 	"github.com/slackhq/nebula" | ||||
| @@ -39,5 +40,25 @@ func main() { | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
|  | ||||
| 	nebula.Main(*configPath, *configTest, Build) | ||||
| 	config := nebula.NewConfig() | ||||
| 	err := config.Load(*configPath) | ||||
| 	if err != nil { | ||||
| 		fmt.Printf("failed to load config: %s", err) | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
|  | ||||
| 	l := logrus.New() | ||||
| 	l.Out = os.Stdout | ||||
| 	err = nebula.Main(config, *configTest, true, Build, l, nil, nil) | ||||
|  | ||||
| 	switch v := err.(type) { | ||||
| 	case nebula.ContextualError: | ||||
| 		v.Log(l) | ||||
| 		os.Exit(1) | ||||
| 	case error: | ||||
| 		l.WithError(err).Error("Failed to start") | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
|  | ||||
| 	os.Exit(0) | ||||
| } | ||||
|   | ||||
							
								
								
									
										20
									
								
								config.go
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								config.go
									
									
									
									
									
								
							| @@ -1,6 +1,7 @@ | ||||
| package nebula | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/imdario/mergo" | ||||
| 	"github.com/sirupsen/logrus" | ||||
| @@ -56,6 +57,13 @@ func (c *Config) Load(path string) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *Config) LoadString(raw string) error { | ||||
| 	if raw == "" { | ||||
| 		return errors.New("Empty configuration") | ||||
| 	} | ||||
| 	return c.parseRaw([]byte(raw)) | ||||
| } | ||||
|  | ||||
| // RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered | ||||
| // here should decide if they need to make a change to the current process before making the change. HasChanged can be | ||||
| // used to help decide if a change is necessary. | ||||
| @@ -407,6 +415,18 @@ func (c *Config) addFile(path string, direct bool) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *Config) parseRaw(b []byte) error { | ||||
| 	var m map[interface{}]interface{} | ||||
|  | ||||
| 	err := yaml.Unmarshal(b, &m) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	c.Settings = m | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *Config) parse() error { | ||||
| 	var m map[interface{}]interface{} | ||||
|  | ||||
|   | ||||
							
								
								
									
										31
									
								
								logger.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								logger.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,31 @@ | ||||
| package nebula | ||||
|  | ||||
| import ( | ||||
| 	"github.com/sirupsen/logrus" | ||||
| ) | ||||
|  | ||||
| type ContextualError struct { | ||||
| 	RealError error | ||||
| 	Fields    map[string]interface{} | ||||
| 	Context   string | ||||
| } | ||||
|  | ||||
| func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError { | ||||
| 	return ContextualError{Context: msg, Fields: fields, RealError: realError} | ||||
| } | ||||
|  | ||||
| func (ce ContextualError) Error() string { | ||||
| 	return ce.RealError.Error() | ||||
| } | ||||
|  | ||||
| func (ce ContextualError) Unwrap() error { | ||||
| 	return ce.RealError | ||||
| } | ||||
|  | ||||
| func (ce *ContextualError) Log(lr *logrus.Logger) { | ||||
| 	if ce.RealError != nil { | ||||
| 		lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context) | ||||
| 	} else { | ||||
| 		lr.WithFields(ce.Fields).Error(ce.Context) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										66
									
								
								logger_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								logger_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,66 @@ | ||||
| package nebula | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| type TestLogWriter struct { | ||||
| 	Logs []string | ||||
| } | ||||
|  | ||||
| func NewTestLogWriter() *TestLogWriter { | ||||
| 	return &TestLogWriter{Logs: make([]string, 0)} | ||||
| } | ||||
|  | ||||
| func (tl *TestLogWriter) Write(p []byte) (n int, err error) { | ||||
| 	tl.Logs = append(tl.Logs, string(p)) | ||||
| 	return len(p), nil | ||||
| } | ||||
|  | ||||
| func (tl *TestLogWriter) Reset() { | ||||
| 	tl.Logs = tl.Logs[:0] | ||||
| } | ||||
|  | ||||
| func TestContextualError_Log(t *testing.T) { | ||||
| 	l := logrus.New() | ||||
| 	l.Formatter = &logrus.TextFormatter{ | ||||
| 		DisableTimestamp: true, | ||||
| 		DisableColors:    true, | ||||
| 	} | ||||
|  | ||||
| 	tl := NewTestLogWriter() | ||||
| 	l.Out = tl | ||||
|  | ||||
| 	// Test a full context line | ||||
| 	tl.Reset() | ||||
| 	e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) | ||||
| 	e.Log(l) | ||||
| 	assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs) | ||||
|  | ||||
| 	// Test a line with an error and msg but no fields | ||||
| 	tl.Reset() | ||||
| 	e = NewContextualError("test message", nil, errors.New("error")) | ||||
| 	e.Log(l) | ||||
| 	assert.Equal(t, []string{"level=error msg=\"test message\" error=error\n"}, tl.Logs) | ||||
|  | ||||
| 	// Test just a context and fields | ||||
| 	tl.Reset() | ||||
| 	e = NewContextualError("test message", m{"field": "1"}, nil) | ||||
| 	e.Log(l) | ||||
| 	assert.Equal(t, []string{"level=error msg=\"test message\" field=1\n"}, tl.Logs) | ||||
|  | ||||
| 	// Test just a context | ||||
| 	tl.Reset() | ||||
| 	e = NewContextualError("test message", nil, nil) | ||||
| 	e.Log(l) | ||||
| 	assert.Equal(t, []string{"level=error msg=\"test message\"\n"}, tl.Logs) | ||||
|  | ||||
| 	// Test just an error | ||||
| 	tl.Reset() | ||||
| 	e = NewContextualError("", nil, errors.New("error")) | ||||
| 	e.Log(l) | ||||
| 	assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs) | ||||
| } | ||||
							
								
								
									
										136
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										136
									
								
								main.go
									
									
									
									
									
								
							| @@ -3,6 +3,9 @@ package nebula | ||||
| import ( | ||||
| 	"encoding/binary" | ||||
| 	"fmt" | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"github.com/slackhq/nebula/sshd" | ||||
| 	"gopkg.in/yaml.v2" | ||||
| 	"net" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| @@ -10,42 +13,38 @@ import ( | ||||
| 	"strings" | ||||
| 	"syscall" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"github.com/slackhq/nebula/sshd" | ||||
| 	"gopkg.in/yaml.v2" | ||||
| ) | ||||
|  | ||||
| // The caller should provide a real logger, we have one just in case | ||||
| var l = logrus.New() | ||||
|  | ||||
| type m map[string]interface{} | ||||
|  | ||||
| func Main(configPath string, configTest bool, buildVersion string) { | ||||
| 	l.Out = os.Stdout | ||||
| type CommandRequest struct { | ||||
| 	Command  string | ||||
| 	Callback chan error | ||||
| } | ||||
|  | ||||
| func Main(config *Config, configTest bool, block bool, buildVersion string, logger *logrus.Logger, tunFd *int, commandChan <-chan CommandRequest) error { | ||||
| 	l = logger | ||||
| 	l.Formatter = &logrus.TextFormatter{ | ||||
| 		FullTimestamp: true, | ||||
| 	} | ||||
|  | ||||
| 	config := NewConfig() | ||||
| 	err := config.Load(configPath) | ||||
| 	if err != nil { | ||||
| 		l.WithError(err).Error("Failed to load config") | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
|  | ||||
| 	// Print the config if in test, the exit comes later | ||||
| 	if configTest { | ||||
| 		b, err := yaml.Marshal(config.Settings) | ||||
| 		if err != nil { | ||||
| 			l.Println(err) | ||||
| 			os.Exit(1) | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		// Print the final config | ||||
| 		l.Println(string(b)) | ||||
| 	} | ||||
|  | ||||
| 	err = configLogger(config) | ||||
| 	err := configLogger(config) | ||||
| 	if err != nil { | ||||
| 		l.WithError(err).Error("Failed to configure the logger") | ||||
| 		return NewContextualError("Failed to configure the logger", nil, err) | ||||
| 	} | ||||
|  | ||||
| 	config.RegisterReloadCallback(func(c *Config) { | ||||
| @@ -59,20 +58,20 @@ func Main(configPath string, configTest bool, buildVersion string) { | ||||
| 	trustedCAs, err = loadCAFromConfig(config) | ||||
| 	if err != nil { | ||||
| 		//The errors coming out of loadCA are already nicely formatted | ||||
| 		l.WithError(err).Fatal("Failed to load ca from config") | ||||
| 		return NewContextualError("Failed to load ca from config", nil, err) | ||||
| 	} | ||||
| 	l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints") | ||||
|  | ||||
| 	cs, err := NewCertStateFromConfig(config) | ||||
| 	if err != nil { | ||||
| 		//The errors coming out of NewCertStateFromConfig are already nicely formatted | ||||
| 		l.WithError(err).Fatal("Failed to load certificate from config") | ||||
| 		return NewContextualError("Failed to load certificate from config", nil, err) | ||||
| 	} | ||||
| 	l.WithField("cert", cs.certificate).Debug("Client nebula certificate") | ||||
|  | ||||
| 	fw, err := NewFirewallFromConfig(cs.certificate, config) | ||||
| 	if err != nil { | ||||
| 		l.WithError(err).Fatal("Error while loading firewall rules") | ||||
| 		return NewContextualError("Error while loading firewall rules", nil, err) | ||||
| 	} | ||||
| 	l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started") | ||||
|  | ||||
| @@ -80,11 +79,11 @@ func Main(configPath string, configTest bool, buildVersion string) { | ||||
| 	tunCidr := cs.certificate.Details.Ips[0] | ||||
| 	routes, err := parseRoutes(config, tunCidr) | ||||
| 	if err != nil { | ||||
| 		l.WithError(err).Fatal("Could not parse tun.routes") | ||||
| 		return NewContextualError("Could not parse tun.routes", nil, err) | ||||
| 	} | ||||
| 	unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr) | ||||
| 	if err != nil { | ||||
| 		l.WithError(err).Fatal("Could not parse tun.unsafe_routes") | ||||
| 		return NewContextualError("Could not parse tun.unsafe_routes", nil, err) | ||||
| 	} | ||||
|  | ||||
| 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) | ||||
| @@ -92,7 +91,7 @@ func Main(configPath string, configTest bool, buildVersion string) { | ||||
| 	if config.GetBool("sshd.enabled", false) { | ||||
| 		err = configSSH(ssh, config) | ||||
| 		if err != nil { | ||||
| 			l.WithError(err).Fatal("Error while configuring the sshd") | ||||
| 			return NewContextualError("Error while configuring the sshd", nil, err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @@ -105,7 +104,16 @@ func Main(configPath string, configTest bool, buildVersion string) { | ||||
| 	if !configTest { | ||||
| 		config.CatchHUP() | ||||
|  | ||||
| 		// set up our tun dev | ||||
| 		if tunFd != nil { | ||||
| 			tun, err = newTunFromFd( | ||||
| 				*tunFd, | ||||
| 				tunCidr, | ||||
| 				config.GetInt("tun.mtu", DEFAULT_MTU), | ||||
| 				routes, | ||||
| 				unsafeRoutes, | ||||
| 				config.GetInt("tun.tx_queue", 500), | ||||
| 			) | ||||
| 		} else { | ||||
| 			tun, err = newTun( | ||||
| 				config.GetString("tun.dev", ""), | ||||
| 				tunCidr, | ||||
| @@ -114,8 +122,10 @@ func Main(configPath string, configTest bool, buildVersion string) { | ||||
| 				unsafeRoutes, | ||||
| 				config.GetInt("tun.tx_queue", 500), | ||||
| 			) | ||||
| 		} | ||||
|  | ||||
| 		if err != nil { | ||||
| 			l.WithError(err).Fatal("Failed to get a tun/tap device") | ||||
| 			return NewContextualError("Failed to get a tun/tap device", nil, err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @@ -126,11 +136,28 @@ func Main(configPath string, configTest bool, buildVersion string) { | ||||
| 	if !configTest { | ||||
| 		udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1) | ||||
| 		if err != nil { | ||||
| 			l.WithError(err).Fatal("Failed to open udp listener") | ||||
| 			return NewContextualError("Failed to open udp listener", nil, err) | ||||
| 		} | ||||
| 		udpServer.reloadConfig(config) | ||||
| 	} | ||||
|  | ||||
| 	sigChan := make(chan os.Signal) | ||||
| 	killChan := make(chan CommandRequest) | ||||
| 	if commandChan != nil { | ||||
| 		go func() { | ||||
| 			cmd := CommandRequest{} | ||||
| 			for { | ||||
| 				cmd = <-commandChan | ||||
| 				switch cmd.Command { | ||||
| 				case "rebind": | ||||
| 					udpServer.Rebind() | ||||
| 				case "exit": | ||||
| 					killChan <- cmd | ||||
| 				} | ||||
| 			} | ||||
| 		}() | ||||
| 	} | ||||
|  | ||||
| 	// Set up my internal host map | ||||
| 	var preferredRanges []*net.IPNet | ||||
| 	rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{}) | ||||
| @@ -139,7 +166,7 @@ func Main(configPath string, configTest bool, buildVersion string) { | ||||
| 		for _, rawPreferredRange := range rawPreferredRanges { | ||||
| 			_, preferredRange, err := net.ParseCIDR(rawPreferredRange) | ||||
| 			if err != nil { | ||||
| 				l.WithError(err).Fatal("Failed to parse preferred ranges") | ||||
| 				return NewContextualError("Failed to parse preferred ranges", nil, err) | ||||
| 			} | ||||
| 			preferredRanges = append(preferredRanges, preferredRange) | ||||
| 		} | ||||
| @@ -152,7 +179,7 @@ func Main(configPath string, configTest bool, buildVersion string) { | ||||
| 	if rawLocalRange != "" { | ||||
| 		_, localRange, err := net.ParseCIDR(rawLocalRange) | ||||
| 		if err != nil { | ||||
| 			l.WithError(err).Fatal("Failed to parse local range") | ||||
| 			return NewContextualError("Failed to parse local_range", nil, err) | ||||
| 		} | ||||
|  | ||||
| 		// Check if the entry for local_range was already specified in | ||||
| @@ -192,7 +219,7 @@ func Main(configPath string, configTest bool, buildVersion string) { | ||||
| 	if port == 0 && !configTest { | ||||
| 		uPort, err := udpServer.LocalAddr() | ||||
| 		if err != nil { | ||||
| 			l.WithError(err).Fatal("Failed to get listening port") | ||||
| 			return NewContextualError("Failed to get listening port", nil, err) | ||||
| 		} | ||||
| 		port = int(uPort.Port) | ||||
| 	} | ||||
| @@ -209,10 +236,10 @@ func Main(configPath string, configTest bool, buildVersion string) { | ||||
| 	for i, host := range rawLighthouseHosts { | ||||
| 		ip := net.ParseIP(host) | ||||
| 		if ip == nil { | ||||
| 			l.WithField("host", host).Fatalf("Unable to parse lighthouse host entry %v", i+1) | ||||
| 			return NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil) | ||||
| 		} | ||||
| 		if !tunCidr.Contains(ip) { | ||||
| 			l.WithField("vpnIp", ip).WithField("network", tunCidr.String()).Fatalf("lighthouse host is not in our subnet, invalid") | ||||
| 			return NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) | ||||
| 		} | ||||
| 		lighthouseHosts[i] = ip2int(ip) | ||||
| 	} | ||||
| @@ -232,13 +259,13 @@ func Main(configPath string, configTest bool, buildVersion string) { | ||||
|  | ||||
| 	remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false) | ||||
| 	if err != nil { | ||||
| 		l.WithError(err).Fatal("Invalid lighthouse.remote_allow_list") | ||||
| 		return NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) | ||||
| 	} | ||||
| 	lightHouse.SetRemoteAllowList(remoteAllowList) | ||||
|  | ||||
| 	localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true) | ||||
| 	if err != nil { | ||||
| 		l.WithError(err).Fatal("Invalid lighthouse.local_allow_list") | ||||
| 		return NewContextualError("Invalid lighthouse.local_allow_list", nil, err) | ||||
| 	} | ||||
| 	lightHouse.SetLocalAllowList(localAllowList) | ||||
|  | ||||
| @@ -246,7 +273,7 @@ func Main(configPath string, configTest bool, buildVersion string) { | ||||
| 	for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) { | ||||
| 		vpnIp := net.ParseIP(fmt.Sprintf("%v", k)) | ||||
| 		if !tunCidr.Contains(vpnIp) { | ||||
| 			l.WithField("vpnIp", vpnIp).WithField("network", tunCidr.String()).Fatalf("static_host_map key is not in our subnet, invalid") | ||||
| 			return NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil) | ||||
| 		} | ||||
| 		vals, ok := v.([]interface{}) | ||||
| 		if ok { | ||||
| @@ -257,7 +284,7 @@ func Main(configPath string, configTest bool, buildVersion string) { | ||||
| 					ip := addr.IP | ||||
| 					port, err := strconv.Atoi(parts[1]) | ||||
| 					if err != nil { | ||||
| 						l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v) | ||||
| 						return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) | ||||
| 					} | ||||
| 					lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true) | ||||
| 				} | ||||
| @@ -270,7 +297,7 @@ func Main(configPath string, configTest bool, buildVersion string) { | ||||
| 				ip := addr.IP | ||||
| 				port, err := strconv.Atoi(parts[1]) | ||||
| 				if err != nil { | ||||
| 					l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v) | ||||
| 					return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) | ||||
| 				} | ||||
| 				lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true) | ||||
| 			} | ||||
| @@ -330,14 +357,14 @@ func Main(configPath string, configTest bool, buildVersion string) { | ||||
| 	case "chachapoly": | ||||
| 		noiseEndianness = binary.LittleEndian | ||||
| 	default: | ||||
| 		l.Fatalf("Unknown cipher: %v", ifConfig.Cipher) | ||||
| 		return fmt.Errorf("unknown cipher: %v", ifConfig.Cipher) | ||||
| 	} | ||||
|  | ||||
| 	var ifce *Interface | ||||
| 	if !configTest { | ||||
| 		ifce, err = NewInterface(ifConfig) | ||||
| 		if err != nil { | ||||
| 			l.WithError(err).Fatal("Failed to initialize interface") | ||||
| 			return fmt.Errorf("failed to initialize interface: %s", err) | ||||
| 		} | ||||
|  | ||||
| 		ifce.RegisterConfigChangeCallbacks(config) | ||||
| @@ -348,11 +375,11 @@ func Main(configPath string, configTest bool, buildVersion string) { | ||||
|  | ||||
| 	err = startStats(config, configTest) | ||||
| 	if err != nil { | ||||
| 		l.WithError(err).Fatal("Failed to start stats emitter") | ||||
| 		return NewContextualError("Failed to start stats emitter", nil, err) | ||||
| 	} | ||||
|  | ||||
| 	if configTest { | ||||
| 		os.Exit(0) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	//TODO: check if we _should_ be emitting stats | ||||
| @@ -367,19 +394,33 @@ func Main(configPath string, configTest bool, buildVersion string) { | ||||
| 		go dnsMain(hostMap, config) | ||||
| 	} | ||||
|  | ||||
| 	if block { | ||||
| 		// Just sit here and be friendly, main thread. | ||||
| 	shutdownBlock(ifce) | ||||
| 		shutdownBlock(ifce, sigChan, killChan) | ||||
| 	} else { | ||||
| 		// Even though we aren't blocking we still want to shutdown gracefully | ||||
| 		go shutdownBlock(ifce, sigChan, killChan) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func shutdownBlock(ifce *Interface) { | ||||
| 	var sigChan = make(chan os.Signal) | ||||
| func shutdownBlock(ifce *Interface, sigChan chan os.Signal, killChan chan CommandRequest) { | ||||
| 	var cmd CommandRequest | ||||
| 	var sig string | ||||
|  | ||||
| 	signal.Notify(sigChan, syscall.SIGTERM) | ||||
| 	signal.Notify(sigChan, syscall.SIGINT) | ||||
|  | ||||
| 	sig := <-sigChan | ||||
| 	select { | ||||
| 	case rawSig := <-sigChan: | ||||
| 		sig = rawSig.String() | ||||
| 	case cmd = <-killChan: | ||||
| 		sig = "controlling app" | ||||
| 	} | ||||
|  | ||||
| 	l.WithField("signal", sig).Info("Caught signal, shutting down") | ||||
|  | ||||
| 	//TODO: stop tun and udp routines, the lock on hostMap does effectively does that though | ||||
| 	//TODO: stop tun and udp routines, the lock on hostMap effectively does that though | ||||
| 	//TODO: this is probably better as a function in ConnectionManager or HostMap directly | ||||
| 	ifce.hostMap.Lock() | ||||
| 	for _, h := range ifce.hostMap.Hosts { | ||||
| @@ -392,5 +433,8 @@ func shutdownBlock(ifce *Interface) { | ||||
| 	ifce.hostMap.Unlock() | ||||
|  | ||||
| 	l.WithField("signal", sig).Info("Goodbye") | ||||
| 	os.Exit(0) | ||||
| 	select { | ||||
| 	case cmd.Callback <- nil: | ||||
| 	default: | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -1,12 +1,13 @@ | ||||
| // +build !ios | ||||
|  | ||||
| package nebula | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/songgao/water" | ||||
| 	"net" | ||||
| 	"os/exec" | ||||
| 	"strconv" | ||||
|  | ||||
| 	"github.com/songgao/water" | ||||
| ) | ||||
|  | ||||
| type Tun struct { | ||||
| @@ -20,8 +21,9 @@ type Tun struct { | ||||
|  | ||||
| func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { | ||||
| 	if len(routes) > 0 { | ||||
| 		return nil, fmt.Errorf("Route MTU not supported in Darwin") | ||||
| 		return nil, fmt.Errorf("route MTU not supported in Darwin") | ||||
| 	} | ||||
|  | ||||
| 	// NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate() | ||||
| 	return &Tun{ | ||||
| 		Cidr:         cidr, | ||||
| @@ -30,13 +32,17 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { | ||||
| 	return nil, fmt.Errorf("newTunFromFd not supported in Darwin") | ||||
| } | ||||
|  | ||||
| func (c *Tun) Activate() error { | ||||
| 	var err error | ||||
| 	c.Interface, err = water.New(water.Config{ | ||||
| 		DeviceType: water.TUN, | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("Activate failed: %v", err) | ||||
| 		return fmt.Errorf("activate failed: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	c.Device = c.Interface.Name() | ||||
|   | ||||
| @@ -22,6 +22,10 @@ type Tun struct { | ||||
| 	io.ReadWriteCloser | ||||
| } | ||||
|  | ||||
| func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { | ||||
| 	return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") | ||||
| } | ||||
|  | ||||
| func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { | ||||
| 	if len(routes) > 0 { | ||||
| 		return nil, fmt.Errorf("Route MTU not supported in FreeBSD") | ||||
|   | ||||
							
								
								
									
										105
									
								
								tun_ios.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										105
									
								
								tun_ios.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,105 @@ | ||||
| // +build ios | ||||
|  | ||||
| package nebula | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"os" | ||||
| 	"sync" | ||||
| 	"syscall" | ||||
| ) | ||||
|  | ||||
| type Tun struct { | ||||
| 	io.ReadWriteCloser | ||||
| 	Device string | ||||
| 	Cidr   *net.IPNet | ||||
| } | ||||
|  | ||||
| func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { | ||||
| 	return nil, fmt.Errorf("newTun not supported in iOS") | ||||
| } | ||||
|  | ||||
| func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { | ||||
| 	if len(routes) > 0 { | ||||
| 		return nil, fmt.Errorf("route MTU not supported in Darwin") | ||||
| 	} | ||||
|  | ||||
| 	file := os.NewFile(uintptr(deviceFd), "/dev/tun") | ||||
| 	ifce = &Tun{ | ||||
| 		Cidr:            cidr, | ||||
| 		ReadWriteCloser: &tunReadCloser{f: file}, | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (c *Tun) Activate() error { | ||||
| 	c.Device = "iOS" | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *Tun) WriteRaw(b []byte) error { | ||||
| 	_, err := c.Write(b) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // The following is hoisted up from water, we do this so we can inject our own fd on iOS | ||||
| type tunReadCloser struct { | ||||
| 	f io.ReadWriteCloser | ||||
|  | ||||
| 	rMu  sync.Mutex | ||||
| 	rBuf []byte | ||||
|  | ||||
| 	wMu  sync.Mutex | ||||
| 	wBuf []byte | ||||
| } | ||||
|  | ||||
| func (t *tunReadCloser) Read(to []byte) (int, error) { | ||||
| 	t.rMu.Lock() | ||||
| 	defer t.rMu.Unlock() | ||||
|  | ||||
| 	if cap(t.rBuf) < len(to)+4 { | ||||
| 		t.rBuf = make([]byte, len(to)+4) | ||||
| 	} | ||||
| 	t.rBuf = t.rBuf[:len(to)+4] | ||||
|  | ||||
| 	n, err := t.f.Read(t.rBuf) | ||||
| 	copy(to, t.rBuf[4:]) | ||||
| 	return n - 4, err | ||||
| } | ||||
|  | ||||
| func (t *tunReadCloser) Write(from []byte) (int, error) { | ||||
|  | ||||
| 	if len(from) == 0 { | ||||
| 		return 0, syscall.EIO | ||||
| 	} | ||||
|  | ||||
| 	t.wMu.Lock() | ||||
| 	defer t.wMu.Unlock() | ||||
|  | ||||
| 	if cap(t.wBuf) < len(from)+4 { | ||||
| 		t.wBuf = make([]byte, len(from)+4) | ||||
| 	} | ||||
| 	t.wBuf = t.wBuf[:len(from)+4] | ||||
|  | ||||
| 	// Determine the IP Family for the NULL L2 Header | ||||
| 	ipVer := from[0] >> 4 | ||||
| 	if ipVer == 4 { | ||||
| 		t.wBuf[3] = syscall.AF_INET | ||||
| 	} else if ipVer == 6 { | ||||
| 		t.wBuf[3] = syscall.AF_INET6 | ||||
| 	} else { | ||||
| 		return 0, errors.New("unable to determine IP version from packet") | ||||
| 	} | ||||
|  | ||||
| 	copy(t.wBuf[4:], from) | ||||
|  | ||||
| 	n, err := t.f.Write(t.wBuf) | ||||
| 	return n - 4, err | ||||
| } | ||||
|  | ||||
| func (t *tunReadCloser) Close() error { | ||||
| 	return t.f.Close() | ||||
| } | ||||
							
								
								
									
										17
									
								
								tun_linux.go
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								tun_linux.go
									
									
									
									
									
								
							| @@ -75,6 +75,23 @@ type ifreqQLEN struct { | ||||
| 	pad   [8]byte | ||||
| } | ||||
|  | ||||
| func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { | ||||
|  | ||||
| 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") | ||||
|  | ||||
| 	ifce = &Tun{ | ||||
| 		ReadWriteCloser: file, | ||||
| 		fd:              int(file.Fd()), | ||||
| 		Device:          "tun0", | ||||
| 		Cidr:            cidr, | ||||
| 		DefaultMTU:      defaultMTU, | ||||
| 		TXQueueLen:      txQueueLen, | ||||
| 		Routes:          routes, | ||||
| 		UnsafeRoutes:    unsafeRoutes, | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { | ||||
| 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) | ||||
| 	if err != nil { | ||||
|   | ||||
| @@ -18,9 +18,13 @@ type Tun struct { | ||||
| 	*water.Interface | ||||
| } | ||||
|  | ||||
| func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { | ||||
| 	return nil, fmt.Errorf("newTunFromFd not supported in Windows") | ||||
| } | ||||
|  | ||||
| func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { | ||||
| 	if len(routes) > 0 { | ||||
| 		return nil, fmt.Errorf("Route MTU not supported in Windows") | ||||
| 		return nil, fmt.Errorf("route MTU not supported in Windows") | ||||
| 	} | ||||
|  | ||||
| 	// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() | ||||
|   | ||||
							
								
								
									
										36
									
								
								udp_android.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								udp_android.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,36 @@ | ||||
| package nebula | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"syscall" | ||||
|  | ||||
| 	"golang.org/x/sys/unix" | ||||
| ) | ||||
|  | ||||
| func NewListenConfig(multi bool) net.ListenConfig { | ||||
| 	return net.ListenConfig{ | ||||
| 		Control: func(network, address string, c syscall.RawConn) error { | ||||
| 			if multi { | ||||
| 				var controlErr error | ||||
| 				err := c.Control(func(fd uintptr) { | ||||
| 					if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { | ||||
| 						controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err) | ||||
| 						return | ||||
| 					} | ||||
| 				}) | ||||
| 				if err != nil { | ||||
| 					return err | ||||
| 				} | ||||
| 				if controlErr != nil { | ||||
| 					return controlErr | ||||
| 				} | ||||
| 			} | ||||
| 			return nil | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (u *udpConn) Rebind() { | ||||
| 	return | ||||
| } | ||||
| @@ -32,3 +32,12 @@ func NewListenConfig(multi bool) net.ListenConfig { | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (u *udpConn) Rebind() error { | ||||
| 	file, err := u.File() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return syscall.SetsockoptInt(int(file.Fd()), unix.IPPROTO_IP, unix.IP_BOUND_IF, 0) | ||||
| } | ||||
|   | ||||
| @@ -32,3 +32,7 @@ func NewListenConfig(multi bool) net.ListenConfig { | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (u *udpConn) Rebind() { | ||||
| 	return | ||||
| } | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| // +build !linux | ||||
| // +build !linux android | ||||
|  | ||||
| // udp_generic implements the nebula UDP interface in pure Go stdlib. This | ||||
| // means it can be used on platforms like Darwin and Windows. | ||||
|   | ||||
| @@ -1,3 +1,5 @@ | ||||
| // +build !android | ||||
|  | ||||
| package nebula | ||||
|  | ||||
| import ( | ||||
| @@ -85,6 +87,10 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) { | ||||
| 	return &udpConn{sysFd: fd}, err | ||||
| } | ||||
|  | ||||
| func (u *udpConn) Rebind() { | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (u *udpConn) SetRecvBuffer(n int) error { | ||||
| 	return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) | ||||
| } | ||||
|   | ||||
| @@ -1,5 +1,6 @@ | ||||
| // +build linux | ||||
| // +build 386 amd64p32 arm mips mipsle | ||||
| // +build !android | ||||
|  | ||||
| package nebula | ||||
|  | ||||
|   | ||||
| @@ -1,5 +1,6 @@ | ||||
| // +build linux | ||||
| // +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x | ||||
| // +build !android | ||||
|  | ||||
| package nebula | ||||
|  | ||||
|   | ||||
| @@ -20,3 +20,7 @@ func NewListenConfig(multi bool) net.ListenConfig { | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (u *udpConn) Rebind() { | ||||
| 	return | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Nathan Brown
					Nathan Brown