Be more like a library to support mobile (#247)

This commit is contained in:
Nathan Brown 2020-06-30 13:48:58 -05:00 committed by GitHub
parent 1ea8847085
commit 41578ca971
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 477 additions and 69 deletions

View File

@ -212,10 +212,10 @@ func TestBitsLostCounter(t *testing.T) {
func BenchmarkBits(b *testing.B) { func BenchmarkBits(b *testing.B) {
z := NewBits(10) z := NewBits(10)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
for i, _ := range z.bits { for i := range z.bits {
z.bits[i] = true z.bits[i] = true
} }
for i, _ := range z.bits { for i := range z.bits {
z.bits[i] = false z.bits[i] = false
} }

View File

@ -3,9 +3,9 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"os" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"os"
) )
// A version string that can be set with // A version string that can be set with
@ -45,5 +45,25 @@ func main() {
os.Exit(1) os.Exit(1)
} }
nebula.Main(*configPath, *configTest, Build) config := nebula.NewConfig()
err := config.Load(*configPath)
if err != nil {
fmt.Printf("failed to load config: %s", err)
os.Exit(1)
}
l := logrus.New()
l.Out = os.Stdout
err = nebula.Main(config, *configTest, true, Build, l, nil, nil)
switch v := err.(type) {
case nebula.ContextualError:
v.Log(l)
os.Exit(1)
case error:
l.WithError(err).Error("Failed to start")
os.Exit(1)
}
os.Exit(0)
} }

View File

@ -1,6 +1,8 @@
package main package main
import ( import (
"fmt"
"github.com/sirupsen/logrus"
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
@ -27,8 +29,15 @@ func (p *program) Start(s service.Service) error {
} }
func (p *program) run() error { func (p *program) run() error {
nebula.Main(*p.configPath, *p.configTest, Build) config := nebula.NewConfig()
return nil err := config.Load(*p.configPath)
if err != nil {
return fmt.Errorf("failed to load config: %s", err)
}
l := logrus.New()
l.Out = os.Stdout
return nebula.Main(config, *p.configTest, true, Build, l, nil, nil)
} }
func (p *program) Stop(s service.Service) error { func (p *program) Stop(s service.Service) error {

View File

@ -3,6 +3,7 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"github.com/sirupsen/logrus"
"os" "os"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
@ -39,5 +40,25 @@ func main() {
os.Exit(1) os.Exit(1)
} }
nebula.Main(*configPath, *configTest, Build) config := nebula.NewConfig()
err := config.Load(*configPath)
if err != nil {
fmt.Printf("failed to load config: %s", err)
os.Exit(1)
}
l := logrus.New()
l.Out = os.Stdout
err = nebula.Main(config, *configTest, true, Build, l, nil, nil)
switch v := err.(type) {
case nebula.ContextualError:
v.Log(l)
os.Exit(1)
case error:
l.WithError(err).Error("Failed to start")
os.Exit(1)
}
os.Exit(0)
} }

View File

@ -1,6 +1,7 @@
package nebula package nebula
import ( import (
"errors"
"fmt" "fmt"
"github.com/imdario/mergo" "github.com/imdario/mergo"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -56,6 +57,13 @@ func (c *Config) Load(path string) error {
return nil return nil
} }
func (c *Config) LoadString(raw string) error {
if raw == "" {
return errors.New("Empty configuration")
}
return c.parseRaw([]byte(raw))
}
// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered // RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
// here should decide if they need to make a change to the current process before making the change. HasChanged can be // here should decide if they need to make a change to the current process before making the change. HasChanged can be
// used to help decide if a change is necessary. // used to help decide if a change is necessary.
@ -407,6 +415,18 @@ func (c *Config) addFile(path string, direct bool) error {
return nil return nil
} }
func (c *Config) parseRaw(b []byte) error {
var m map[interface{}]interface{}
err := yaml.Unmarshal(b, &m)
if err != nil {
return err
}
c.Settings = m
return nil
}
func (c *Config) parse() error { func (c *Config) parse() error {
var m map[interface{}]interface{} var m map[interface{}]interface{}

31
logger.go Normal file
View File

@ -0,0 +1,31 @@
package nebula
import (
"github.com/sirupsen/logrus"
)
type ContextualError struct {
RealError error
Fields map[string]interface{}
Context string
}
func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
return ContextualError{Context: msg, Fields: fields, RealError: realError}
}
func (ce ContextualError) Error() string {
return ce.RealError.Error()
}
func (ce ContextualError) Unwrap() error {
return ce.RealError
}
func (ce *ContextualError) Log(lr *logrus.Logger) {
if ce.RealError != nil {
lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
} else {
lr.WithFields(ce.Fields).Error(ce.Context)
}
}

66
logger_test.go Normal file
View File

@ -0,0 +1,66 @@
package nebula
import (
"errors"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"testing"
)
type TestLogWriter struct {
Logs []string
}
func NewTestLogWriter() *TestLogWriter {
return &TestLogWriter{Logs: make([]string, 0)}
}
func (tl *TestLogWriter) Write(p []byte) (n int, err error) {
tl.Logs = append(tl.Logs, string(p))
return len(p), nil
}
func (tl *TestLogWriter) Reset() {
tl.Logs = tl.Logs[:0]
}
func TestContextualError_Log(t *testing.T) {
l := logrus.New()
l.Formatter = &logrus.TextFormatter{
DisableTimestamp: true,
DisableColors: true,
}
tl := NewTestLogWriter()
l.Out = tl
// Test a full context line
tl.Reset()
e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs)
// Test a line with an error and msg but no fields
tl.Reset()
e = NewContextualError("test message", nil, errors.New("error"))
e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\" error=error\n"}, tl.Logs)
// Test just a context and fields
tl.Reset()
e = NewContextualError("test message", m{"field": "1"}, nil)
e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\" field=1\n"}, tl.Logs)
// Test just a context
tl.Reset()
e = NewContextualError("test message", nil, nil)
e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\"\n"}, tl.Logs)
// Test just an error
tl.Reset()
e = NewContextualError("", nil, errors.New("error"))
e.Log(l)
assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs)
}

136
main.go
View File

@ -3,6 +3,9 @@ package nebula
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/sshd"
"gopkg.in/yaml.v2"
"net" "net"
"os" "os"
"os/signal" "os/signal"
@ -10,42 +13,38 @@ import (
"strings" "strings"
"syscall" "syscall"
"time" "time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/sshd"
"gopkg.in/yaml.v2"
) )
// The caller should provide a real logger, we have one just in case
var l = logrus.New() var l = logrus.New()
type m map[string]interface{} type m map[string]interface{}
func Main(configPath string, configTest bool, buildVersion string) { type CommandRequest struct {
l.Out = os.Stdout Command string
Callback chan error
}
func Main(config *Config, configTest bool, block bool, buildVersion string, logger *logrus.Logger, tunFd *int, commandChan <-chan CommandRequest) error {
l = logger
l.Formatter = &logrus.TextFormatter{ l.Formatter = &logrus.TextFormatter{
FullTimestamp: true, FullTimestamp: true,
} }
config := NewConfig()
err := config.Load(configPath)
if err != nil {
l.WithError(err).Error("Failed to load config")
os.Exit(1)
}
// Print the config if in test, the exit comes later // Print the config if in test, the exit comes later
if configTest { if configTest {
b, err := yaml.Marshal(config.Settings) b, err := yaml.Marshal(config.Settings)
if err != nil { if err != nil {
l.Println(err) return err
os.Exit(1)
} }
// Print the final config
l.Println(string(b)) l.Println(string(b))
} }
err = configLogger(config) err := configLogger(config)
if err != nil { if err != nil {
l.WithError(err).Error("Failed to configure the logger") return NewContextualError("Failed to configure the logger", nil, err)
} }
config.RegisterReloadCallback(func(c *Config) { config.RegisterReloadCallback(func(c *Config) {
@ -59,20 +58,20 @@ func Main(configPath string, configTest bool, buildVersion string) {
trustedCAs, err = loadCAFromConfig(config) trustedCAs, err = loadCAFromConfig(config)
if err != nil { if err != nil {
//The errors coming out of loadCA are already nicely formatted //The errors coming out of loadCA are already nicely formatted
l.WithError(err).Fatal("Failed to load ca from config") return NewContextualError("Failed to load ca from config", nil, err)
} }
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints") l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
cs, err := NewCertStateFromConfig(config) cs, err := NewCertStateFromConfig(config)
if err != nil { if err != nil {
//The errors coming out of NewCertStateFromConfig are already nicely formatted //The errors coming out of NewCertStateFromConfig are already nicely formatted
l.WithError(err).Fatal("Failed to load certificate from config") return NewContextualError("Failed to load certificate from config", nil, err)
} }
l.WithField("cert", cs.certificate).Debug("Client nebula certificate") l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
fw, err := NewFirewallFromConfig(cs.certificate, config) fw, err := NewFirewallFromConfig(cs.certificate, config)
if err != nil { if err != nil {
l.WithError(err).Fatal("Error while loading firewall rules") return NewContextualError("Error while loading firewall rules", nil, err)
} }
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started") l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
@ -80,11 +79,11 @@ func Main(configPath string, configTest bool, buildVersion string) {
tunCidr := cs.certificate.Details.Ips[0] tunCidr := cs.certificate.Details.Ips[0]
routes, err := parseRoutes(config, tunCidr) routes, err := parseRoutes(config, tunCidr)
if err != nil { if err != nil {
l.WithError(err).Fatal("Could not parse tun.routes") return NewContextualError("Could not parse tun.routes", nil, err)
} }
unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr) unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
if err != nil { if err != nil {
l.WithError(err).Fatal("Could not parse tun.unsafe_routes") return NewContextualError("Could not parse tun.unsafe_routes", nil, err)
} }
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
@ -92,7 +91,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
if config.GetBool("sshd.enabled", false) { if config.GetBool("sshd.enabled", false) {
err = configSSH(ssh, config) err = configSSH(ssh, config)
if err != nil { if err != nil {
l.WithError(err).Fatal("Error while configuring the sshd") return NewContextualError("Error while configuring the sshd", nil, err)
} }
} }
@ -105,7 +104,16 @@ func Main(configPath string, configTest bool, buildVersion string) {
if !configTest { if !configTest {
config.CatchHUP() config.CatchHUP()
// set up our tun dev if tunFd != nil {
tun, err = newTunFromFd(
*tunFd,
tunCidr,
config.GetInt("tun.mtu", DEFAULT_MTU),
routes,
unsafeRoutes,
config.GetInt("tun.tx_queue", 500),
)
} else {
tun, err = newTun( tun, err = newTun(
config.GetString("tun.dev", ""), config.GetString("tun.dev", ""),
tunCidr, tunCidr,
@ -114,8 +122,10 @@ func Main(configPath string, configTest bool, buildVersion string) {
unsafeRoutes, unsafeRoutes,
config.GetInt("tun.tx_queue", 500), config.GetInt("tun.tx_queue", 500),
) )
}
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to get a tun/tap device") return NewContextualError("Failed to get a tun/tap device", nil, err)
} }
} }
@ -126,11 +136,28 @@ func Main(configPath string, configTest bool, buildVersion string) {
if !configTest { if !configTest {
udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1) udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1)
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to open udp listener") return NewContextualError("Failed to open udp listener", nil, err)
} }
udpServer.reloadConfig(config) udpServer.reloadConfig(config)
} }
sigChan := make(chan os.Signal)
killChan := make(chan CommandRequest)
if commandChan != nil {
go func() {
cmd := CommandRequest{}
for {
cmd = <-commandChan
switch cmd.Command {
case "rebind":
udpServer.Rebind()
case "exit":
killChan <- cmd
}
}
}()
}
// Set up my internal host map // Set up my internal host map
var preferredRanges []*net.IPNet var preferredRanges []*net.IPNet
rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{}) rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{})
@ -139,7 +166,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
for _, rawPreferredRange := range rawPreferredRanges { for _, rawPreferredRange := range rawPreferredRanges {
_, preferredRange, err := net.ParseCIDR(rawPreferredRange) _, preferredRange, err := net.ParseCIDR(rawPreferredRange)
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to parse preferred ranges") return NewContextualError("Failed to parse preferred ranges", nil, err)
} }
preferredRanges = append(preferredRanges, preferredRange) preferredRanges = append(preferredRanges, preferredRange)
} }
@ -152,7 +179,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
if rawLocalRange != "" { if rawLocalRange != "" {
_, localRange, err := net.ParseCIDR(rawLocalRange) _, localRange, err := net.ParseCIDR(rawLocalRange)
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to parse local range") return NewContextualError("Failed to parse local_range", nil, err)
} }
// Check if the entry for local_range was already specified in // Check if the entry for local_range was already specified in
@ -192,7 +219,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
if port == 0 && !configTest { if port == 0 && !configTest {
uPort, err := udpServer.LocalAddr() uPort, err := udpServer.LocalAddr()
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to get listening port") return NewContextualError("Failed to get listening port", nil, err)
} }
port = int(uPort.Port) port = int(uPort.Port)
} }
@ -209,10 +236,10 @@ func Main(configPath string, configTest bool, buildVersion string) {
for i, host := range rawLighthouseHosts { for i, host := range rawLighthouseHosts {
ip := net.ParseIP(host) ip := net.ParseIP(host)
if ip == nil { if ip == nil {
l.WithField("host", host).Fatalf("Unable to parse lighthouse host entry %v", i+1) return NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
} }
if !tunCidr.Contains(ip) { if !tunCidr.Contains(ip) {
l.WithField("vpnIp", ip).WithField("network", tunCidr.String()).Fatalf("lighthouse host is not in our subnet, invalid") return NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
} }
lighthouseHosts[i] = ip2int(ip) lighthouseHosts[i] = ip2int(ip)
} }
@ -232,13 +259,13 @@ func Main(configPath string, configTest bool, buildVersion string) {
remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false) remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false)
if err != nil { if err != nil {
l.WithError(err).Fatal("Invalid lighthouse.remote_allow_list") return NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
} }
lightHouse.SetRemoteAllowList(remoteAllowList) lightHouse.SetRemoteAllowList(remoteAllowList)
localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true) localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true)
if err != nil { if err != nil {
l.WithError(err).Fatal("Invalid lighthouse.local_allow_list") return NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
} }
lightHouse.SetLocalAllowList(localAllowList) lightHouse.SetLocalAllowList(localAllowList)
@ -246,7 +273,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) { for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
vpnIp := net.ParseIP(fmt.Sprintf("%v", k)) vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
if !tunCidr.Contains(vpnIp) { if !tunCidr.Contains(vpnIp) {
l.WithField("vpnIp", vpnIp).WithField("network", tunCidr.String()).Fatalf("static_host_map key is not in our subnet, invalid") return NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
} }
vals, ok := v.([]interface{}) vals, ok := v.([]interface{})
if ok { if ok {
@ -257,7 +284,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
ip := addr.IP ip := addr.IP
port, err := strconv.Atoi(parts[1]) port, err := strconv.Atoi(parts[1])
if err != nil { if err != nil {
l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v) return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
} }
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true) lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
} }
@ -270,7 +297,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
ip := addr.IP ip := addr.IP
port, err := strconv.Atoi(parts[1]) port, err := strconv.Atoi(parts[1])
if err != nil { if err != nil {
l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v) return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
} }
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true) lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
} }
@ -330,14 +357,14 @@ func Main(configPath string, configTest bool, buildVersion string) {
case "chachapoly": case "chachapoly":
noiseEndianness = binary.LittleEndian noiseEndianness = binary.LittleEndian
default: default:
l.Fatalf("Unknown cipher: %v", ifConfig.Cipher) return fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
} }
var ifce *Interface var ifce *Interface
if !configTest { if !configTest {
ifce, err = NewInterface(ifConfig) ifce, err = NewInterface(ifConfig)
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to initialize interface") return fmt.Errorf("failed to initialize interface: %s", err)
} }
ifce.RegisterConfigChangeCallbacks(config) ifce.RegisterConfigChangeCallbacks(config)
@ -348,11 +375,11 @@ func Main(configPath string, configTest bool, buildVersion string) {
err = startStats(config, configTest) err = startStats(config, configTest)
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to start stats emitter") return NewContextualError("Failed to start stats emitter", nil, err)
} }
if configTest { if configTest {
os.Exit(0) return nil
} }
//TODO: check if we _should_ be emitting stats //TODO: check if we _should_ be emitting stats
@ -367,19 +394,33 @@ func Main(configPath string, configTest bool, buildVersion string) {
go dnsMain(hostMap, config) go dnsMain(hostMap, config)
} }
if block {
// Just sit here and be friendly, main thread. // Just sit here and be friendly, main thread.
shutdownBlock(ifce) shutdownBlock(ifce, sigChan, killChan)
} else {
// Even though we aren't blocking we still want to shutdown gracefully
go shutdownBlock(ifce, sigChan, killChan)
}
return nil
} }
func shutdownBlock(ifce *Interface) { func shutdownBlock(ifce *Interface, sigChan chan os.Signal, killChan chan CommandRequest) {
var sigChan = make(chan os.Signal) var cmd CommandRequest
var sig string
signal.Notify(sigChan, syscall.SIGTERM) signal.Notify(sigChan, syscall.SIGTERM)
signal.Notify(sigChan, syscall.SIGINT) signal.Notify(sigChan, syscall.SIGINT)
sig := <-sigChan select {
case rawSig := <-sigChan:
sig = rawSig.String()
case cmd = <-killChan:
sig = "controlling app"
}
l.WithField("signal", sig).Info("Caught signal, shutting down") l.WithField("signal", sig).Info("Caught signal, shutting down")
//TODO: stop tun and udp routines, the lock on hostMap does effectively does that though //TODO: stop tun and udp routines, the lock on hostMap effectively does that though
//TODO: this is probably better as a function in ConnectionManager or HostMap directly //TODO: this is probably better as a function in ConnectionManager or HostMap directly
ifce.hostMap.Lock() ifce.hostMap.Lock()
for _, h := range ifce.hostMap.Hosts { for _, h := range ifce.hostMap.Hosts {
@ -392,5 +433,8 @@ func shutdownBlock(ifce *Interface) {
ifce.hostMap.Unlock() ifce.hostMap.Unlock()
l.WithField("signal", sig).Info("Goodbye") l.WithField("signal", sig).Info("Goodbye")
os.Exit(0) select {
case cmd.Callback <- nil:
default:
}
} }

View File

@ -1,12 +1,13 @@
// +build !ios
package nebula package nebula
import ( import (
"fmt" "fmt"
"github.com/songgao/water"
"net" "net"
"os/exec" "os/exec"
"strconv" "strconv"
"github.com/songgao/water"
) )
type Tun struct { type Tun struct {
@ -20,8 +21,9 @@ type Tun struct {
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 { if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in Darwin") return nil, fmt.Errorf("route MTU not supported in Darwin")
} }
// NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate() // NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate()
return &Tun{ return &Tun{
Cidr: cidr, Cidr: cidr,
@ -30,13 +32,17 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
}, nil }, nil
} }
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}
func (c *Tun) Activate() error { func (c *Tun) Activate() error {
var err error var err error
c.Interface, err = water.New(water.Config{ c.Interface, err = water.New(water.Config{
DeviceType: water.TUN, DeviceType: water.TUN,
}) })
if err != nil { if err != nil {
return fmt.Errorf("Activate failed: %v", err) return fmt.Errorf("activate failed: %v", err)
} }
c.Device = c.Interface.Name() c.Device = c.Interface.Name()

View File

@ -22,6 +22,10 @@ type Tun struct {
io.ReadWriteCloser io.ReadWriteCloser
} }
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 { if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in FreeBSD") return nil, fmt.Errorf("Route MTU not supported in FreeBSD")

105
tun_ios.go Normal file
View File

@ -0,0 +1,105 @@
// +build ios
package nebula
import (
"errors"
"fmt"
"io"
"net"
"os"
"sync"
"syscall"
)
type Tun struct {
io.ReadWriteCloser
Device string
Cidr *net.IPNet
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTun not supported in iOS")
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("route MTU not supported in Darwin")
}
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
ifce = &Tun{
Cidr: cidr,
ReadWriteCloser: &tunReadCloser{f: file},
}
return
}
func (c *Tun) Activate() error {
c.Device = "iOS"
return nil
}
func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b)
return err
}
// The following is hoisted up from water, we do this so we can inject our own fd on iOS
type tunReadCloser struct {
f io.ReadWriteCloser
rMu sync.Mutex
rBuf []byte
wMu sync.Mutex
wBuf []byte
}
func (t *tunReadCloser) Read(to []byte) (int, error) {
t.rMu.Lock()
defer t.rMu.Unlock()
if cap(t.rBuf) < len(to)+4 {
t.rBuf = make([]byte, len(to)+4)
}
t.rBuf = t.rBuf[:len(to)+4]
n, err := t.f.Read(t.rBuf)
copy(to, t.rBuf[4:])
return n - 4, err
}
func (t *tunReadCloser) Write(from []byte) (int, error) {
if len(from) == 0 {
return 0, syscall.EIO
}
t.wMu.Lock()
defer t.wMu.Unlock()
if cap(t.wBuf) < len(from)+4 {
t.wBuf = make([]byte, len(from)+4)
}
t.wBuf = t.wBuf[:len(from)+4]
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
t.wBuf[3] = syscall.AF_INET
} else if ipVer == 6 {
t.wBuf[3] = syscall.AF_INET6
} else {
return 0, errors.New("unable to determine IP version from packet")
}
copy(t.wBuf[4:], from)
n, err := t.f.Write(t.wBuf)
return n - 4, err
}
func (t *tunReadCloser) Close() error {
return t.f.Close()
}

View File

@ -75,6 +75,23 @@ type ifreqQLEN struct {
pad [8]byte pad [8]byte
} }
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
ifce = &Tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
Device: "tun0",
Cidr: cidr,
DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen,
Routes: routes,
UnsafeRoutes: unsafeRoutes,
}
return
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {

View File

@ -18,9 +18,13 @@ type Tun struct {
*water.Interface *water.Interface
} }
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 { if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in Windows") return nil, fmt.Errorf("route MTU not supported in Windows")
} }
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()

36
udp_android.go Normal file
View File

@ -0,0 +1,36 @@
package nebula
import (
"fmt"
"net"
"syscall"
"golang.org/x/sys/unix"
)
func NewListenConfig(multi bool) net.ListenConfig {
return net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
if multi {
var controlErr error
err := c.Control(func(fd uintptr) {
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err)
return
}
})
if err != nil {
return err
}
if controlErr != nil {
return controlErr
}
}
return nil
},
}
}
func (u *udpConn) Rebind() {
return
}

View File

@ -32,3 +32,12 @@ func NewListenConfig(multi bool) net.ListenConfig {
}, },
} }
} }
func (u *udpConn) Rebind() error {
file, err := u.File()
if err != nil {
return err
}
return syscall.SetsockoptInt(int(file.Fd()), unix.IPPROTO_IP, unix.IP_BOUND_IF, 0)
}

View File

@ -32,3 +32,7 @@ func NewListenConfig(multi bool) net.ListenConfig {
}, },
} }
} }
func (u *udpConn) Rebind() {
return
}

View File

@ -1,4 +1,4 @@
// +build !linux // +build !linux android
// udp_generic implements the nebula UDP interface in pure Go stdlib. This // udp_generic implements the nebula UDP interface in pure Go stdlib. This
// means it can be used on platforms like Darwin and Windows. // means it can be used on platforms like Darwin and Windows.

View File

@ -1,3 +1,5 @@
// +build !android
package nebula package nebula
import ( import (
@ -85,6 +87,10 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
return &udpConn{sysFd: fd}, err return &udpConn{sysFd: fd}, err
} }
func (u *udpConn) Rebind() {
return
}
func (u *udpConn) SetRecvBuffer(n int) error { func (u *udpConn) SetRecvBuffer(n int) error {
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
} }

View File

@ -1,5 +1,6 @@
// +build linux // +build linux
// +build 386 amd64p32 arm mips mipsle // +build 386 amd64p32 arm mips mipsle
// +build !android
package nebula package nebula

View File

@ -1,5 +1,6 @@
// +build linux // +build linux
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x // +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x
// +build !android
package nebula package nebula

View File

@ -20,3 +20,7 @@ func NewListenConfig(multi bool) net.ListenConfig {
}, },
} }
} }
func (u *udpConn) Rebind() {
return
}