Be more like a library to support mobile (#247)
This commit is contained in:
parent
1ea8847085
commit
41578ca971
|
@ -212,10 +212,10 @@ func TestBitsLostCounter(t *testing.T) {
|
|||
func BenchmarkBits(b *testing.B) {
|
||||
z := NewBits(10)
|
||||
for n := 0; n < b.N; n++ {
|
||||
for i, _ := range z.bits {
|
||||
for i := range z.bits {
|
||||
z.bits[i] = true
|
||||
}
|
||||
for i, _ := range z.bits {
|
||||
for i := range z.bits {
|
||||
z.bits[i] = false
|
||||
}
|
||||
|
||||
|
|
|
@ -3,9 +3,9 @@ package main
|
|||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
"os"
|
||||
)
|
||||
|
||||
// A version string that can be set with
|
||||
|
@ -45,5 +45,25 @@ func main() {
|
|||
os.Exit(1)
|
||||
}
|
||||
|
||||
nebula.Main(*configPath, *configTest, Build)
|
||||
config := nebula.NewConfig()
|
||||
err := config.Load(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("failed to load config: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
l := logrus.New()
|
||||
l.Out = os.Stdout
|
||||
err = nebula.Main(config, *configTest, true, Build, l, nil, nil)
|
||||
|
||||
switch v := err.(type) {
|
||||
case nebula.ContextualError:
|
||||
v.Log(l)
|
||||
os.Exit(1)
|
||||
case error:
|
||||
l.WithError(err).Error("Failed to start")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
@ -27,8 +29,15 @@ func (p *program) Start(s service.Service) error {
|
|||
}
|
||||
|
||||
func (p *program) run() error {
|
||||
nebula.Main(*p.configPath, *p.configTest, Build)
|
||||
return nil
|
||||
config := nebula.NewConfig()
|
||||
err := config.Load(*p.configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load config: %s", err)
|
||||
}
|
||||
|
||||
l := logrus.New()
|
||||
l.Out = os.Stdout
|
||||
return nebula.Main(config, *p.configTest, true, Build, l, nil, nil)
|
||||
}
|
||||
|
||||
func (p *program) Stop(s service.Service) error {
|
||||
|
|
|
@ -3,6 +3,7 @@ package main
|
|||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
"os"
|
||||
|
||||
"github.com/slackhq/nebula"
|
||||
|
@ -39,5 +40,25 @@ func main() {
|
|||
os.Exit(1)
|
||||
}
|
||||
|
||||
nebula.Main(*configPath, *configTest, Build)
|
||||
config := nebula.NewConfig()
|
||||
err := config.Load(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("failed to load config: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
l := logrus.New()
|
||||
l.Out = os.Stdout
|
||||
err = nebula.Main(config, *configTest, true, Build, l, nil, nil)
|
||||
|
||||
switch v := err.(type) {
|
||||
case nebula.ContextualError:
|
||||
v.Log(l)
|
||||
os.Exit(1)
|
||||
case error:
|
||||
l.WithError(err).Error("Failed to start")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
|
20
config.go
20
config.go
|
@ -1,6 +1,7 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/imdario/mergo"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -56,6 +57,13 @@ func (c *Config) Load(path string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) LoadString(raw string) error {
|
||||
if raw == "" {
|
||||
return errors.New("Empty configuration")
|
||||
}
|
||||
return c.parseRaw([]byte(raw))
|
||||
}
|
||||
|
||||
// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
|
||||
// here should decide if they need to make a change to the current process before making the change. HasChanged can be
|
||||
// used to help decide if a change is necessary.
|
||||
|
@ -407,6 +415,18 @@ func (c *Config) addFile(path string, direct bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) parseRaw(b []byte) error {
|
||||
var m map[interface{}]interface{}
|
||||
|
||||
err := yaml.Unmarshal(b, &m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Settings = m
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) parse() error {
|
||||
var m map[interface{}]interface{}
|
||||
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type ContextualError struct {
|
||||
RealError error
|
||||
Fields map[string]interface{}
|
||||
Context string
|
||||
}
|
||||
|
||||
func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
|
||||
return ContextualError{Context: msg, Fields: fields, RealError: realError}
|
||||
}
|
||||
|
||||
func (ce ContextualError) Error() string {
|
||||
return ce.RealError.Error()
|
||||
}
|
||||
|
||||
func (ce ContextualError) Unwrap() error {
|
||||
return ce.RealError
|
||||
}
|
||||
|
||||
func (ce *ContextualError) Log(lr *logrus.Logger) {
|
||||
if ce.RealError != nil {
|
||||
lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
|
||||
} else {
|
||||
lr.WithFields(ce.Fields).Error(ce.Context)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type TestLogWriter struct {
|
||||
Logs []string
|
||||
}
|
||||
|
||||
func NewTestLogWriter() *TestLogWriter {
|
||||
return &TestLogWriter{Logs: make([]string, 0)}
|
||||
}
|
||||
|
||||
func (tl *TestLogWriter) Write(p []byte) (n int, err error) {
|
||||
tl.Logs = append(tl.Logs, string(p))
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (tl *TestLogWriter) Reset() {
|
||||
tl.Logs = tl.Logs[:0]
|
||||
}
|
||||
|
||||
func TestContextualError_Log(t *testing.T) {
|
||||
l := logrus.New()
|
||||
l.Formatter = &logrus.TextFormatter{
|
||||
DisableTimestamp: true,
|
||||
DisableColors: true,
|
||||
}
|
||||
|
||||
tl := NewTestLogWriter()
|
||||
l.Out = tl
|
||||
|
||||
// Test a full context line
|
||||
tl.Reset()
|
||||
e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
|
||||
e.Log(l)
|
||||
assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs)
|
||||
|
||||
// Test a line with an error and msg but no fields
|
||||
tl.Reset()
|
||||
e = NewContextualError("test message", nil, errors.New("error"))
|
||||
e.Log(l)
|
||||
assert.Equal(t, []string{"level=error msg=\"test message\" error=error\n"}, tl.Logs)
|
||||
|
||||
// Test just a context and fields
|
||||
tl.Reset()
|
||||
e = NewContextualError("test message", m{"field": "1"}, nil)
|
||||
e.Log(l)
|
||||
assert.Equal(t, []string{"level=error msg=\"test message\" field=1\n"}, tl.Logs)
|
||||
|
||||
// Test just a context
|
||||
tl.Reset()
|
||||
e = NewContextualError("test message", nil, nil)
|
||||
e.Log(l)
|
||||
assert.Equal(t, []string{"level=error msg=\"test message\"\n"}, tl.Logs)
|
||||
|
||||
// Test just an error
|
||||
tl.Reset()
|
||||
e = NewContextualError("", nil, errors.New("error"))
|
||||
e.Log(l)
|
||||
assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs)
|
||||
}
|
154
main.go
154
main.go
|
@ -3,6 +3,9 @@ package nebula
|
|||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/sshd"
|
||||
"gopkg.in/yaml.v2"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
@ -10,42 +13,38 @@ import (
|
|||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/sshd"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
// The caller should provide a real logger, we have one just in case
|
||||
var l = logrus.New()
|
||||
|
||||
type m map[string]interface{}
|
||||
|
||||
func Main(configPath string, configTest bool, buildVersion string) {
|
||||
l.Out = os.Stdout
|
||||
type CommandRequest struct {
|
||||
Command string
|
||||
Callback chan error
|
||||
}
|
||||
|
||||
func Main(config *Config, configTest bool, block bool, buildVersion string, logger *logrus.Logger, tunFd *int, commandChan <-chan CommandRequest) error {
|
||||
l = logger
|
||||
l.Formatter = &logrus.TextFormatter{
|
||||
FullTimestamp: true,
|
||||
}
|
||||
|
||||
config := NewConfig()
|
||||
err := config.Load(configPath)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Failed to load config")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Print the config if in test, the exit comes later
|
||||
if configTest {
|
||||
b, err := yaml.Marshal(config.Settings)
|
||||
if err != nil {
|
||||
l.Println(err)
|
||||
os.Exit(1)
|
||||
return err
|
||||
}
|
||||
|
||||
// Print the final config
|
||||
l.Println(string(b))
|
||||
}
|
||||
|
||||
err = configLogger(config)
|
||||
err := configLogger(config)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Failed to configure the logger")
|
||||
return NewContextualError("Failed to configure the logger", nil, err)
|
||||
}
|
||||
|
||||
config.RegisterReloadCallback(func(c *Config) {
|
||||
|
@ -59,20 +58,20 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||
trustedCAs, err = loadCAFromConfig(config)
|
||||
if err != nil {
|
||||
//The errors coming out of loadCA are already nicely formatted
|
||||
l.WithError(err).Fatal("Failed to load ca from config")
|
||||
return NewContextualError("Failed to load ca from config", nil, err)
|
||||
}
|
||||
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
|
||||
|
||||
cs, err := NewCertStateFromConfig(config)
|
||||
if err != nil {
|
||||
//The errors coming out of NewCertStateFromConfig are already nicely formatted
|
||||
l.WithError(err).Fatal("Failed to load certificate from config")
|
||||
return NewContextualError("Failed to load certificate from config", nil, err)
|
||||
}
|
||||
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
|
||||
|
||||
fw, err := NewFirewallFromConfig(cs.certificate, config)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Error while loading firewall rules")
|
||||
return NewContextualError("Error while loading firewall rules", nil, err)
|
||||
}
|
||||
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
|
||||
|
||||
|
@ -80,11 +79,11 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||
tunCidr := cs.certificate.Details.Ips[0]
|
||||
routes, err := parseRoutes(config, tunCidr)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Could not parse tun.routes")
|
||||
return NewContextualError("Could not parse tun.routes", nil, err)
|
||||
}
|
||||
unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Could not parse tun.unsafe_routes")
|
||||
return NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
||||
}
|
||||
|
||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
||||
|
@ -92,7 +91,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||
if config.GetBool("sshd.enabled", false) {
|
||||
err = configSSH(ssh, config)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Error while configuring the sshd")
|
||||
return NewContextualError("Error while configuring the sshd", nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -105,17 +104,28 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||
if !configTest {
|
||||
config.CatchHUP()
|
||||
|
||||
// set up our tun dev
|
||||
tun, err = newTun(
|
||||
config.GetString("tun.dev", ""),
|
||||
tunCidr,
|
||||
config.GetInt("tun.mtu", DEFAULT_MTU),
|
||||
routes,
|
||||
unsafeRoutes,
|
||||
config.GetInt("tun.tx_queue", 500),
|
||||
)
|
||||
if tunFd != nil {
|
||||
tun, err = newTunFromFd(
|
||||
*tunFd,
|
||||
tunCidr,
|
||||
config.GetInt("tun.mtu", DEFAULT_MTU),
|
||||
routes,
|
||||
unsafeRoutes,
|
||||
config.GetInt("tun.tx_queue", 500),
|
||||
)
|
||||
} else {
|
||||
tun, err = newTun(
|
||||
config.GetString("tun.dev", ""),
|
||||
tunCidr,
|
||||
config.GetInt("tun.mtu", DEFAULT_MTU),
|
||||
routes,
|
||||
unsafeRoutes,
|
||||
config.GetInt("tun.tx_queue", 500),
|
||||
)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Failed to get a tun/tap device")
|
||||
return NewContextualError("Failed to get a tun/tap device", nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -126,11 +136,28 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||
if !configTest {
|
||||
udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Failed to open udp listener")
|
||||
return NewContextualError("Failed to open udp listener", nil, err)
|
||||
}
|
||||
udpServer.reloadConfig(config)
|
||||
}
|
||||
|
||||
sigChan := make(chan os.Signal)
|
||||
killChan := make(chan CommandRequest)
|
||||
if commandChan != nil {
|
||||
go func() {
|
||||
cmd := CommandRequest{}
|
||||
for {
|
||||
cmd = <-commandChan
|
||||
switch cmd.Command {
|
||||
case "rebind":
|
||||
udpServer.Rebind()
|
||||
case "exit":
|
||||
killChan <- cmd
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Set up my internal host map
|
||||
var preferredRanges []*net.IPNet
|
||||
rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{})
|
||||
|
@ -139,7 +166,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||
for _, rawPreferredRange := range rawPreferredRanges {
|
||||
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Failed to parse preferred ranges")
|
||||
return NewContextualError("Failed to parse preferred ranges", nil, err)
|
||||
}
|
||||
preferredRanges = append(preferredRanges, preferredRange)
|
||||
}
|
||||
|
@ -152,7 +179,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||
if rawLocalRange != "" {
|
||||
_, localRange, err := net.ParseCIDR(rawLocalRange)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Failed to parse local range")
|
||||
return NewContextualError("Failed to parse local_range", nil, err)
|
||||
}
|
||||
|
||||
// Check if the entry for local_range was already specified in
|
||||
|
@ -192,7 +219,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||
if port == 0 && !configTest {
|
||||
uPort, err := udpServer.LocalAddr()
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Failed to get listening port")
|
||||
return NewContextualError("Failed to get listening port", nil, err)
|
||||
}
|
||||
port = int(uPort.Port)
|
||||
}
|
||||
|
@ -209,10 +236,10 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||
for i, host := range rawLighthouseHosts {
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
l.WithField("host", host).Fatalf("Unable to parse lighthouse host entry %v", i+1)
|
||||
return NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
|
||||
}
|
||||
if !tunCidr.Contains(ip) {
|
||||
l.WithField("vpnIp", ip).WithField("network", tunCidr.String()).Fatalf("lighthouse host is not in our subnet, invalid")
|
||||
return NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
|
||||
}
|
||||
lighthouseHosts[i] = ip2int(ip)
|
||||
}
|
||||
|
@ -232,13 +259,13 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||
|
||||
remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Invalid lighthouse.remote_allow_list")
|
||||
return NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
|
||||
}
|
||||
lightHouse.SetRemoteAllowList(remoteAllowList)
|
||||
|
||||
localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Invalid lighthouse.local_allow_list")
|
||||
return NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
|
||||
}
|
||||
lightHouse.SetLocalAllowList(localAllowList)
|
||||
|
||||
|
@ -246,7 +273,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||
for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
|
||||
vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
|
||||
if !tunCidr.Contains(vpnIp) {
|
||||
l.WithField("vpnIp", vpnIp).WithField("network", tunCidr.String()).Fatalf("static_host_map key is not in our subnet, invalid")
|
||||
return NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
|
||||
}
|
||||
vals, ok := v.([]interface{})
|
||||
if ok {
|
||||
|
@ -257,7 +284,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||
ip := addr.IP
|
||||
port, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v)
|
||||
return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||
}
|
||||
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
|
||||
}
|
||||
|
@ -270,7 +297,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||
ip := addr.IP
|
||||
port, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v)
|
||||
return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||
}
|
||||
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
|
||||
}
|
||||
|
@ -330,14 +357,14 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||
case "chachapoly":
|
||||
noiseEndianness = binary.LittleEndian
|
||||
default:
|
||||
l.Fatalf("Unknown cipher: %v", ifConfig.Cipher)
|
||||
return fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
|
||||
}
|
||||
|
||||
var ifce *Interface
|
||||
if !configTest {
|
||||
ifce, err = NewInterface(ifConfig)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Failed to initialize interface")
|
||||
return fmt.Errorf("failed to initialize interface: %s", err)
|
||||
}
|
||||
|
||||
ifce.RegisterConfigChangeCallbacks(config)
|
||||
|
@ -348,11 +375,11 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||
|
||||
err = startStats(config, configTest)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Failed to start stats emitter")
|
||||
return NewContextualError("Failed to start stats emitter", nil, err)
|
||||
}
|
||||
|
||||
if configTest {
|
||||
os.Exit(0)
|
||||
return nil
|
||||
}
|
||||
|
||||
//TODO: check if we _should_ be emitting stats
|
||||
|
@ -367,19 +394,33 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||
go dnsMain(hostMap, config)
|
||||
}
|
||||
|
||||
// Just sit here and be friendly, main thread.
|
||||
shutdownBlock(ifce)
|
||||
if block {
|
||||
// Just sit here and be friendly, main thread.
|
||||
shutdownBlock(ifce, sigChan, killChan)
|
||||
} else {
|
||||
// Even though we aren't blocking we still want to shutdown gracefully
|
||||
go shutdownBlock(ifce, sigChan, killChan)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func shutdownBlock(ifce *Interface) {
|
||||
var sigChan = make(chan os.Signal)
|
||||
func shutdownBlock(ifce *Interface, sigChan chan os.Signal, killChan chan CommandRequest) {
|
||||
var cmd CommandRequest
|
||||
var sig string
|
||||
|
||||
signal.Notify(sigChan, syscall.SIGTERM)
|
||||
signal.Notify(sigChan, syscall.SIGINT)
|
||||
|
||||
sig := <-sigChan
|
||||
select {
|
||||
case rawSig := <-sigChan:
|
||||
sig = rawSig.String()
|
||||
case cmd = <-killChan:
|
||||
sig = "controlling app"
|
||||
}
|
||||
|
||||
l.WithField("signal", sig).Info("Caught signal, shutting down")
|
||||
|
||||
//TODO: stop tun and udp routines, the lock on hostMap does effectively does that though
|
||||
//TODO: stop tun and udp routines, the lock on hostMap effectively does that though
|
||||
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
|
||||
ifce.hostMap.Lock()
|
||||
for _, h := range ifce.hostMap.Hosts {
|
||||
|
@ -392,5 +433,8 @@ func shutdownBlock(ifce *Interface) {
|
|||
ifce.hostMap.Unlock()
|
||||
|
||||
l.WithField("signal", sig).Info("Goodbye")
|
||||
os.Exit(0)
|
||||
select {
|
||||
case cmd.Callback <- nil:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
// +build !ios
|
||||
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/songgao/water"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
|
||||
"github.com/songgao/water"
|
||||
)
|
||||
|
||||
type Tun struct {
|
||||
|
@ -20,8 +21,9 @@ type Tun struct {
|
|||
|
||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
if len(routes) > 0 {
|
||||
return nil, fmt.Errorf("Route MTU not supported in Darwin")
|
||||
return nil, fmt.Errorf("route MTU not supported in Darwin")
|
||||
}
|
||||
|
||||
// NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate()
|
||||
return &Tun{
|
||||
Cidr: cidr,
|
||||
|
@ -30,13 +32,17 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
|
|||
}, nil
|
||||
}
|
||||
|
||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
||||
}
|
||||
|
||||
func (c *Tun) Activate() error {
|
||||
var err error
|
||||
c.Interface, err = water.New(water.Config{
|
||||
DeviceType: water.TUN,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("Activate failed: %v", err)
|
||||
return fmt.Errorf("activate failed: %v", err)
|
||||
}
|
||||
|
||||
c.Device = c.Interface.Name()
|
||||
|
|
|
@ -22,6 +22,10 @@ type Tun struct {
|
|||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
||||
}
|
||||
|
||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
if len(routes) > 0 {
|
||||
return nil, fmt.Errorf("Route MTU not supported in FreeBSD")
|
||||
|
|
|
@ -0,0 +1,105 @@
|
|||
// +build ios
|
||||
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
type Tun struct {
|
||||
io.ReadWriteCloser
|
||||
Device string
|
||||
Cidr *net.IPNet
|
||||
}
|
||||
|
||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
return nil, fmt.Errorf("newTun not supported in iOS")
|
||||
}
|
||||
|
||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
if len(routes) > 0 {
|
||||
return nil, fmt.Errorf("route MTU not supported in Darwin")
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
||||
ifce = &Tun{
|
||||
Cidr: cidr,
|
||||
ReadWriteCloser: &tunReadCloser{f: file},
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Tun) Activate() error {
|
||||
c.Device = "iOS"
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Tun) WriteRaw(b []byte) error {
|
||||
_, err := c.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
// The following is hoisted up from water, we do this so we can inject our own fd on iOS
|
||||
type tunReadCloser struct {
|
||||
f io.ReadWriteCloser
|
||||
|
||||
rMu sync.Mutex
|
||||
rBuf []byte
|
||||
|
||||
wMu sync.Mutex
|
||||
wBuf []byte
|
||||
}
|
||||
|
||||
func (t *tunReadCloser) Read(to []byte) (int, error) {
|
||||
t.rMu.Lock()
|
||||
defer t.rMu.Unlock()
|
||||
|
||||
if cap(t.rBuf) < len(to)+4 {
|
||||
t.rBuf = make([]byte, len(to)+4)
|
||||
}
|
||||
t.rBuf = t.rBuf[:len(to)+4]
|
||||
|
||||
n, err := t.f.Read(t.rBuf)
|
||||
copy(to, t.rBuf[4:])
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
func (t *tunReadCloser) Write(from []byte) (int, error) {
|
||||
|
||||
if len(from) == 0 {
|
||||
return 0, syscall.EIO
|
||||
}
|
||||
|
||||
t.wMu.Lock()
|
||||
defer t.wMu.Unlock()
|
||||
|
||||
if cap(t.wBuf) < len(from)+4 {
|
||||
t.wBuf = make([]byte, len(from)+4)
|
||||
}
|
||||
t.wBuf = t.wBuf[:len(from)+4]
|
||||
|
||||
// Determine the IP Family for the NULL L2 Header
|
||||
ipVer := from[0] >> 4
|
||||
if ipVer == 4 {
|
||||
t.wBuf[3] = syscall.AF_INET
|
||||
} else if ipVer == 6 {
|
||||
t.wBuf[3] = syscall.AF_INET6
|
||||
} else {
|
||||
return 0, errors.New("unable to determine IP version from packet")
|
||||
}
|
||||
|
||||
copy(t.wBuf[4:], from)
|
||||
|
||||
n, err := t.f.Write(t.wBuf)
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
func (t *tunReadCloser) Close() error {
|
||||
return t.f.Close()
|
||||
}
|
17
tun_linux.go
17
tun_linux.go
|
@ -75,6 +75,23 @@ type ifreqQLEN struct {
|
|||
pad [8]byte
|
||||
}
|
||||
|
||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
||||
ifce = &Tun{
|
||||
ReadWriteCloser: file,
|
||||
fd: int(file.Fd()),
|
||||
Device: "tun0",
|
||||
Cidr: cidr,
|
||||
DefaultMTU: defaultMTU,
|
||||
TXQueueLen: txQueueLen,
|
||||
Routes: routes,
|
||||
UnsafeRoutes: unsafeRoutes,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
|
|
|
@ -18,9 +18,13 @@ type Tun struct {
|
|||
*water.Interface
|
||||
}
|
||||
|
||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
||||
}
|
||||
|
||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
if len(routes) > 0 {
|
||||
return nil, fmt.Errorf("Route MTU not supported in Windows")
|
||||
return nil, fmt.Errorf("route MTU not supported in Windows")
|
||||
}
|
||||
|
||||
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func NewListenConfig(multi bool) net.ListenConfig {
|
||||
return net.ListenConfig{
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
if multi {
|
||||
var controlErr error
|
||||
err := c.Control(func(fd uintptr) {
|
||||
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
|
||||
controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if controlErr != nil {
|
||||
return controlErr
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (u *udpConn) Rebind() {
|
||||
return
|
||||
}
|
|
@ -32,3 +32,12 @@ func NewListenConfig(multi bool) net.ListenConfig {
|
|||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (u *udpConn) Rebind() error {
|
||||
file, err := u.File()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return syscall.SetsockoptInt(int(file.Fd()), unix.IPPROTO_IP, unix.IP_BOUND_IF, 0)
|
||||
}
|
||||
|
|
|
@ -32,3 +32,7 @@ func NewListenConfig(multi bool) net.ListenConfig {
|
|||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (u *udpConn) Rebind() {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// +build !linux
|
||||
// +build !linux android
|
||||
|
||||
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
|
||||
// means it can be used on platforms like Darwin and Windows.
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
// +build !android
|
||||
|
||||
package nebula
|
||||
|
||||
import (
|
||||
|
@ -85,6 +87,10 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
|
|||
return &udpConn{sysFd: fd}, err
|
||||
}
|
||||
|
||||
func (u *udpConn) Rebind() {
|
||||
return
|
||||
}
|
||||
|
||||
func (u *udpConn) SetRecvBuffer(n int) error {
|
||||
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// +build linux
|
||||
// +build 386 amd64p32 arm mips mipsle
|
||||
// +build !android
|
||||
|
||||
package nebula
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// +build linux
|
||||
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x
|
||||
// +build !android
|
||||
|
||||
package nebula
|
||||
|
||||
|
|
|
@ -20,3 +20,7 @@ func NewListenConfig(multi bool) net.ListenConfig {
|
|||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (u *udpConn) Rebind() {
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue