Remove WriteRaw, cidrTree -> routeTree to better describe its purpose, remove redundancy from field names (#582)

This commit is contained in:
Nate Brown 2021-11-12 12:47:09 -06:00 committed by GitHub
parent 467e605d5e
commit 78d0d46bae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 137 additions and 204 deletions

View File

@ -92,7 +92,7 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
Version: 4, Version: 4,
TTL: 64, TTL: 64,
Protocol: layers.IPProtocolUDP, Protocol: layers.IPProtocolUDP,
SrcIP: c.f.inside.CidrNet().IP, SrcIP: c.f.inside.Cidr().IP,
DstIP: toIp, DstIP: toIp,
} }

View File

@ -147,7 +147,7 @@ func (f *Interface) activate() {
f.l.WithError(err).Error("Failed to get udp listen address") f.l.WithError(err).Error("Failed to get udp listen address")
} }
f.l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()). f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()).
WithField("build", f.version).WithField("udpAddr", addr). WithField("build", f.version).WithField("udpAddr", addr).
Info("Nebula interface is active") Info("Nebula interface is active")

View File

@ -10,9 +10,8 @@ import (
type Device interface { type Device interface {
io.ReadWriteCloser io.ReadWriteCloser
Activate() error Activate() error
CidrNet() *net.IPNet Cidr() *net.IPNet
DeviceName() string Name() string
WriteRaw([]byte) error
RouteFor(iputil.VpnIp) iputil.VpnIp RouteFor(iputil.VpnIp) iputil.VpnIp
NewMultiQueueReader() (io.ReadWriteCloser, error) NewMultiQueueReader() (io.ReadWriteCloser, error)
} }

View File

@ -4,8 +4,10 @@ import (
"fmt" "fmt"
"math" "math"
"net" "net"
"runtime"
"strconv" "strconv"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
) )
@ -16,6 +18,20 @@ type Route struct {
Via *net.IP Via *net.IP
} }
func makeRouteTree(routes []Route, allowMTU bool) (*cidr.Tree4, error) {
routeTree := cidr.NewTree4()
for _, r := range routes {
if !allowMTU && r.MTU > 0 {
return nil, fmt.Errorf("route MTU is not supported in %s", runtime.GOOS)
}
if r.Via != nil {
routeTree.AddCIDR(r.Cidr, r.Via)
}
}
return routeTree, nil
}
func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
var err error var err error

View File

@ -1,12 +1,9 @@
package overlay package overlay
import ( import (
"fmt"
"net" "net"
"runtime"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
@ -52,17 +49,3 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *
) )
} }
} }
func makeCidrTree(routes []Route, allowMTU bool) (*cidr.Tree4, error) {
cidrTree := cidr.NewTree4()
for _, r := range routes {
if !allowMTU && r.MTU > 0 {
return nil, fmt.Errorf("route MTU is not supported in %s", runtime.GOOS)
}
if r.Via != nil {
cidrTree.AddCIDR(r.Cidr, r.Via)
}
}
return cidrTree, nil
}

View File

@ -12,13 +12,12 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"golang.org/x/sys/unix"
) )
type tun struct { type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
fd int fd int
Cidr *net.IPNet cidr *net.IPNet
l *logrus.Logger l *logrus.Logger
} }
@ -32,7 +31,7 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes
return &tun{ return &tun{
ReadWriteCloser: file, ReadWriteCloser: file,
fd: int(file.Fd()), fd: int(file.Fd()),
Cidr: cidr, cidr: cidr,
l: l, l: l,
}, nil }, nil
} }
@ -45,37 +44,15 @@ func (t *tun) RouteFor(iputil.VpnIp) iputil.VpnIp {
return 0 return 0
} }
func (t *tun) WriteRaw(b []byte) error {
var nn int
for {
max := len(b)
n, err := unix.Write(t.fd, b[nn:max])
if n > 0 {
nn += n
}
if nn == len(b) {
return err
}
if err != nil {
return err
}
if n == 0 {
return io.ErrUnexpectedEOF
}
}
}
func (t tun) Activate() error { func (t tun) Activate() error {
return nil return nil
} }
func (t *tun) CidrNet() *net.IPNet { func (t *tun) Cidr() *net.IPNet {
return t.Cidr return t.cidr
} }
func (t *tun) DeviceName() string { func (t *tun) Name() string {
return "android" return "android"
} }

View File

@ -21,10 +21,10 @@ import (
type tun struct { type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
Device string Device string
Cidr *net.IPNet cidr *net.IPNet
DefaultMTU int DefaultMTU int
Routes []Route Routes []Route
cidrTree *cidr.Tree4 routeTree *cidr.Tree4
l *logrus.Logger l *logrus.Logger
// cache out buffer since we need to prepend 4 bytes for tun metadata // cache out buffer since we need to prepend 4 bytes for tun metadata
@ -77,7 +77,7 @@ type ifreqMTU struct {
} }
func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) { func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
cidrTree, err := makeCidrTree(routes, false) routeTree, err := makeRouteTree(routes, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -152,10 +152,10 @@ func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, rout
tun := &tun{ tun := &tun{
ReadWriteCloser: file, ReadWriteCloser: file,
Device: name, Device: name,
Cidr: cidr, cidr: cidr,
DefaultMTU: defaultMTU, DefaultMTU: defaultMTU,
Routes: routes, Routes: routes,
cidrTree: cidrTree, routeTree: routeTree,
l: l, l: l,
} }
@ -185,8 +185,8 @@ func (t *tun) Activate() error {
var addr, mask [4]byte var addr, mask [4]byte
copy(addr[:], t.Cidr.IP.To4()) copy(addr[:], t.cidr.IP.To4())
copy(mask[:], t.Cidr.Mask) copy(mask[:], t.cidr.Mask)
s, err := unix.Socket( s, err := unix.Socket(
unix.AF_INET, unix.AF_INET,
@ -303,7 +303,7 @@ func (t *tun) Activate() error {
} }
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.cidrTree.MostSpecificContains(ip) r := t.routeTree.MostSpecificContains(ip)
if r != nil { if r != nil {
return r.(iputil.VpnIp) return r.(iputil.VpnIp)
} }
@ -403,19 +403,14 @@ func (t *tun) Write(from []byte) (int, error) {
return n - 4, err return n - 4, err
} }
func (t *tun) CidrNet() *net.IPNet { func (t *tun) Cidr() *net.IPNet {
return t.Cidr return t.cidr
} }
func (t *tun) DeviceName() string { func (t *tun) Name() string {
return t.Device return t.Device
} }
func (t *tun) WriteRaw(b []byte) error {
_, err := t.Write(b)
return err
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
} }

View File

@ -48,11 +48,11 @@ func (*disabledTun) RouteFor(iputil.VpnIp) iputil.VpnIp {
return 0 return 0
} }
func (t *disabledTun) CidrNet() *net.IPNet { func (t *disabledTun) Cidr() *net.IPNet {
return t.cidr return t.cidr
} }
func (*disabledTun) DeviceName() string { func (*disabledTun) Name() string {
return "disabled" return "disabled"
} }
@ -128,11 +128,6 @@ func (t *disabledTun) Write(b []byte) (int, error) {
return len(b), nil return len(b), nil
} }
func (t *disabledTun) WriteRaw(b []byte) error {
_, err := t.Write(b)
return err
}
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return t, nil return t, nil
} }

View File

@ -21,12 +21,12 @@ import (
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
type tun struct { type tun struct {
Device string Device string
Cidr *net.IPNet cidr *net.IPNet
MTU int MTU int
Routes []Route Routes []Route
cidrTree *cidr.Tree4 routeTree *cidr.Tree4
l *logrus.Logger l *logrus.Logger
io.ReadWriteCloser io.ReadWriteCloser
} }
@ -43,7 +43,7 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int
} }
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) { func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
cidrTree, err := makeCidrTree(routes, false) routeTree, err := makeRouteTree(routes, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -55,12 +55,12 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
return nil, fmt.Errorf("tun.dev must match `tun[0-9]+`") return nil, fmt.Errorf("tun.dev must match `tun[0-9]+`")
} }
return &tun{ return &tun{
Device: deviceName, Device: deviceName,
Cidr: cidr, cidr: cidr,
MTU: defaultMTU, MTU: defaultMTU,
Routes: routes, Routes: routes,
cidrTree: cidrTree, routeTree: routeTree,
l: l, l: l,
}, nil }, nil
} }
@ -72,12 +72,12 @@ func (t *tun) Activate() error {
} }
// TODO use syscalls instead of exec.Command // TODO use syscalls instead of exec.Command
t.l.Debug("command: ifconfig", t.Device, t.Cidr.String(), t.Cidr.IP.String()) t.l.Debug("command: ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
if err = exec.Command("/sbin/ifconfig", t.Device, t.Cidr.String(), t.Cidr.IP.String()).Run(); err != nil { if err = exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
} }
t.l.Debug("command: route", "-n", "add", "-net", t.Cidr.String(), "-interface", t.Device) t.l.Debug("command: route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", t.Cidr.String(), "-interface", t.Device).Run(); err != nil { if err = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err) return fmt.Errorf("failed to run 'route add': %s", err)
} }
t.l.Debug("command: ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)) t.l.Debug("command: ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
@ -101,7 +101,7 @@ func (t *tun) Activate() error {
} }
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.cidrTree.MostSpecificContains(ip) r := t.routeTree.MostSpecificContains(ip)
if r != nil { if r != nil {
return r.(iputil.VpnIp) return r.(iputil.VpnIp)
} }
@ -109,19 +109,14 @@ func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
return 0 return 0
} }
func (t *tun) CidrNet() *net.IPNet { func (t *tun) Cidr() *net.IPNet {
return t.Cidr return t.cidr
} }
func (t *tun) DeviceName() string { func (t *tun) Name() string {
return t.Device return t.Device
} }
func (t *tun) WriteRaw(b []byte) error {
_, err := t.Write(b)
return err
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
} }

View File

@ -19,7 +19,7 @@ import (
type tun struct { type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
Cidr *net.IPNet cidr *net.IPNet
} }
func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
@ -33,8 +33,7 @@ func newTunFromFd(_ *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes
file := os.NewFile(uintptr(deviceFd), "/dev/tun") file := os.NewFile(uintptr(deviceFd), "/dev/tun")
return &tun{ return &tun{
Cidr: cidr, cidr: cidr,
Device: "iOS",
ReadWriteCloser: &tunReadCloser{f: file}, ReadWriteCloser: &tunReadCloser{f: file},
}, nil }, nil
} }
@ -47,11 +46,6 @@ func (t *tun) RouteFor(iputil.VpnIp) iputil.VpnIp {
return 0 return 0
} }
func (t *tun) WriteRaw(b []byte) error {
_, err := t.Write(b)
return err
}
// The following is hoisted up from water, we do this so we can inject our own fd on iOS // The following is hoisted up from water, we do this so we can inject our own fd on iOS
type tunReadCloser struct { type tunReadCloser struct {
f io.ReadWriteCloser f io.ReadWriteCloser
@ -110,11 +104,11 @@ func (tr *tunReadCloser) Close() error {
return tr.f.Close() return tr.f.Close()
} }
func (t *tun) CidrNet() *net.IPNet { func (t *tun) Cidr() *net.IPNet {
return t.Cidr return t.cidr
} }
func (t *tun) DeviceName() string { func (t *tun) Name() string {
return "iOS" return "iOS"
} }

View File

@ -22,12 +22,12 @@ type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
fd int fd int
Device string Device string
Cidr *net.IPNet cidr *net.IPNet
MaxMTU int MaxMTU int
DefaultMTU int DefaultMTU int
TXQueueLen int TXQueueLen int
Routes []Route Routes []Route
cidrTree *cidr.Tree4 routeTree *cidr.Tree4
l *logrus.Logger l *logrus.Logger
} }
@ -64,7 +64,7 @@ type ifreqQLEN struct {
} }
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int) (*tun, error) { func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int) (*tun, error) {
cidrTree, err := makeCidrTree(routes, true) routeTree, err := makeRouteTree(routes, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -75,11 +75,11 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in
ReadWriteCloser: file, ReadWriteCloser: file,
fd: int(file.Fd()), fd: int(file.Fd()),
Device: "tun0", Device: "tun0",
Cidr: cidr, cidr: cidr,
DefaultMTU: defaultMTU, DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen, TXQueueLen: txQueueLen,
Routes: routes, Routes: routes,
cidrTree: cidrTree, routeTree: routeTree,
l: l, l: l,
}, nil }, nil
} }
@ -110,7 +110,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
} }
} }
cidrTree, err := makeCidrTree(routes, true) routeTree, err := makeRouteTree(routes, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -119,12 +119,12 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
ReadWriteCloser: file, ReadWriteCloser: file,
fd: int(file.Fd()), fd: int(file.Fd()),
Device: name, Device: name,
Cidr: cidr, cidr: cidr,
MaxMTU: maxMTU, MaxMTU: maxMTU,
DefaultMTU: defaultMTU, DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen, TXQueueLen: txQueueLen,
Routes: routes, Routes: routes,
cidrTree: cidrTree, routeTree: routeTree,
l: l, l: l,
}, nil }, nil
} }
@ -148,7 +148,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
} }
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.cidrTree.MostSpecificContains(ip) r := t.routeTree.MostSpecificContains(ip)
if r != nil { if r != nil {
return r.(iputil.VpnIp) return r.(iputil.VpnIp)
} }
@ -156,32 +156,29 @@ func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
return 0 return 0
} }
func (t *tun) WriteRaw(b []byte) error { func (t *tun) Write(b []byte) (int, error) {
var nn int var nn int
max := len(b)
for { for {
max := len(b)
n, err := unix.Write(t.fd, b[nn:max]) n, err := unix.Write(t.fd, b[nn:max])
if n > 0 { if n > 0 {
nn += n nn += n
} }
if nn == len(b) { if nn == len(b) {
return err return nn, err
} }
if err != nil { if err != nil {
return err return nn, err
} }
if n == 0 { if n == 0 {
return io.ErrUnexpectedEOF return nn, io.ErrUnexpectedEOF
} }
} }
} }
func (t *tun) Write(b []byte) (int, error) {
return len(b), t.WriteRaw(b)
}
func (t tun) deviceBytes() (o [16]byte) { func (t tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device { for i, c := range t.Device {
o[i] = byte(c) o[i] = byte(c)
@ -194,8 +191,8 @@ func (t tun) Activate() error {
var addr, mask [4]byte var addr, mask [4]byte
copy(addr[:], t.Cidr.IP.To4()) copy(addr[:], t.cidr.IP.To4())
copy(mask[:], t.Cidr.Mask) copy(mask[:], t.cidr.Mask)
s, err := unix.Socket( s, err := unix.Socket(
unix.AF_INET, unix.AF_INET,
@ -259,14 +256,14 @@ func (t tun) Activate() error {
} }
// Default route // Default route
dr := &net.IPNet{IP: t.Cidr.IP.Mask(t.Cidr.Mask), Mask: t.Cidr.Mask} dr := &net.IPNet{IP: t.cidr.IP.Mask(t.cidr.Mask), Mask: t.cidr.Mask}
nr := netlink.Route{ nr := netlink.Route{
LinkIndex: link.Attrs().Index, LinkIndex: link.Attrs().Index,
Dst: dr, Dst: dr,
MTU: t.DefaultMTU, MTU: t.DefaultMTU,
AdvMSS: t.advMSS(Route{}), AdvMSS: t.advMSS(Route{}),
Scope: unix.RT_SCOPE_LINK, Scope: unix.RT_SCOPE_LINK,
Src: t.Cidr.IP, Src: t.cidr.IP,
Protocol: unix.RTPROT_KERNEL, Protocol: unix.RTPROT_KERNEL,
Table: unix.RT_TABLE_MAIN, Table: unix.RT_TABLE_MAIN,
Type: unix.RTN_UNICAST, Type: unix.RTN_UNICAST,
@ -305,11 +302,11 @@ func (t tun) Activate() error {
return nil return nil
} }
func (t *tun) CidrNet() *net.IPNet { func (t *tun) Cidr() *net.IPNet {
return t.Cidr return t.cidr
} }
func (t *tun) DeviceName() string { func (t *tun) Name() string {
return t.Device return t.Device
} }

View File

@ -14,27 +14,27 @@ import (
) )
type TestTun struct { type TestTun struct {
Device string Device string
Cidr *net.IPNet cidr *net.IPNet
Routes []Route Routes []Route
cidrTree *cidr.Tree4 routeTree *cidr.Tree4
l *logrus.Logger l *logrus.Logger
rxPackets chan []byte // Packets to receive into nebula rxPackets chan []byte // Packets to receive into nebula
TxPackets chan []byte // Packets transmitted outside by nebula TxPackets chan []byte // Packets transmitted outside by nebula
} }
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*TestTun, error) { func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*TestTun, error) {
cidrTree, err := makeCidrTree(routes, false) routeTree, err := makeRouteTree(routes, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &TestTun{ return &TestTun{
Device: deviceName, Device: deviceName,
Cidr: cidr, cidr: cidr,
Routes: routes, Routes: routes,
cidrTree: cidrTree, routeTree: routeTree,
l: l, l: l,
rxPackets: make(chan []byte, 1), rxPackets: make(chan []byte, 1),
TxPackets: make(chan []byte, 1), TxPackets: make(chan []byte, 1),
@ -74,7 +74,7 @@ func (t *TestTun) Get(block bool) []byte {
//********************************************************************************************************************// //********************************************************************************************************************//
func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.cidrTree.MostSpecificContains(ip) r := t.routeTree.MostSpecificContains(ip)
if r != nil { if r != nil {
return r.(iputil.VpnIp) return r.(iputil.VpnIp)
} }
@ -86,16 +86,19 @@ func (t *TestTun) Activate() error {
return nil return nil
} }
func (t *TestTun) CidrNet() *net.IPNet { func (t *TestTun) Cidr() *net.IPNet {
return t.Cidr return t.cidr
} }
func (t *TestTun) DeviceName() string { func (t *TestTun) Name() string {
return t.Device return t.Device
} }
func (t *TestTun) Write(b []byte) (n int, err error) { func (t *TestTun) Write(b []byte) (n int, err error) {
return len(b), t.WriteRaw(b) packet := make([]byte, len(b), len(b))
copy(packet, b)
t.TxPackets <- packet
return len(b), nil
} }
func (t *TestTun) Close() error { func (t *TestTun) Close() error {
@ -103,13 +106,6 @@ func (t *TestTun) Close() error {
return nil return nil
} }
func (t *TestTun) WriteRaw(b []byte) error {
packet := make([]byte, len(b), len(b))
copy(packet, b)
t.TxPackets <- packet
return nil
}
func (t *TestTun) Read(b []byte) (int, error) { func (t *TestTun) Read(b []byte) (int, error) {
p := <-t.rxPackets p := <-t.rxPackets
copy(b, p) copy(b, p)

View File

@ -13,27 +13,27 @@ import (
) )
type waterTun struct { type waterTun struct {
Device string Device string
Cidr *net.IPNet cidr *net.IPNet
MTU int MTU int
Routes []Route Routes []Route
cidrTree *cidr.Tree4 routeTree *cidr.Tree4
*water.Interface *water.Interface
} }
func newWaterTun(cidr *net.IPNet, defaultMTU int, routes []Route) (*waterTun, error) { func newWaterTun(cidr *net.IPNet, defaultMTU int, routes []Route) (*waterTun, error) {
cidrTree, err := makeCidrTree(routes, false) routeTree, err := makeRouteTree(routes, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
return &waterTun{ return &waterTun{
Cidr: cidr, cidr: cidr,
MTU: defaultMTU, MTU: defaultMTU,
Routes: routes, Routes: routes,
cidrTree: cidrTree, routeTree: routeTree,
}, nil }, nil
} }
@ -43,7 +43,7 @@ func (t *waterTun) Activate() error {
DeviceType: water.TUN, DeviceType: water.TUN,
PlatformSpecificParams: water.PlatformSpecificParams{ PlatformSpecificParams: water.PlatformSpecificParams{
ComponentID: "tap0901", ComponentID: "tap0901",
Network: t.Cidr.String(), Network: t.cidr.String(),
}, },
}) })
if err != nil { if err != nil {
@ -57,8 +57,8 @@ func (t *waterTun) Activate() error {
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address", `C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address",
fmt.Sprintf("name=%s", t.Device), fmt.Sprintf("name=%s", t.Device),
"source=static", "source=static",
fmt.Sprintf("addr=%s", t.Cidr.IP), fmt.Sprintf("addr=%s", t.cidr.IP),
fmt.Sprintf("mask=%s", net.IP(t.Cidr.Mask)), fmt.Sprintf("mask=%s", net.IP(t.cidr.Mask)),
"gateway=none", "gateway=none",
).Run() ).Run()
if err != nil { if err != nil {
@ -96,7 +96,7 @@ func (t *waterTun) Activate() error {
} }
func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.cidrTree.MostSpecificContains(ip) r := t.routeTree.MostSpecificContains(ip)
if r != nil { if r != nil {
return r.(iputil.VpnIp) return r.(iputil.VpnIp)
} }
@ -104,19 +104,14 @@ func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
return 0 return 0
} }
func (t *waterTun) CidrNet() *net.IPNet { func (t *waterTun) Cidr() *net.IPNet {
return t.Cidr return t.cidr
} }
func (t *waterTun) DeviceName() string { func (t *waterTun) Name() string {
return t.Device return t.Device
} }
func (t *waterTun) WriteRaw(b []byte) error {
_, err := t.Write(b)
return err
}
func (t *waterTun) Close() error { func (t *waterTun) Close() error {
if t.Interface == nil { if t.Interface == nil {
return nil return nil

View File

@ -17,11 +17,11 @@ import (
const tunGUIDLabel = "Fixed Nebula Windows GUID v1" const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
type winTun struct { type winTun struct {
Device string Device string
Cidr *net.IPNet cidr *net.IPNet
MTU int MTU int
Routes []Route Routes []Route
cidrTree *cidr.Tree4 routeTree *cidr.Tree4
tun *wintun.NativeTun tun *wintun.NativeTun
} }
@ -56,17 +56,17 @@ func newWinTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []Rout
return nil, fmt.Errorf("create TUN device failed: %w", err) return nil, fmt.Errorf("create TUN device failed: %w", err)
} }
cidrTree, err := makeCidrTree(routes, false) routeTree, err := makeRouteTree(routes, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &winTun{ return &winTun{
Device: deviceName, Device: deviceName,
Cidr: cidr, cidr: cidr,
MTU: defaultMTU, MTU: defaultMTU,
Routes: routes, Routes: routes,
cidrTree: cidrTree, routeTree: routeTree,
tun: tunDevice.(*wintun.NativeTun), tun: tunDevice.(*wintun.NativeTun),
}, nil }, nil
@ -75,7 +75,7 @@ func newWinTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []Rout
func (t *winTun) Activate() error { func (t *winTun) Activate() error {
luid := winipcfg.LUID(t.tun.LUID()) luid := winipcfg.LUID(t.tun.LUID())
if err := luid.SetIPAddresses([]net.IPNet{*t.Cidr}); err != nil { if err := luid.SetIPAddresses([]net.IPNet{*t.cidr}); err != nil {
return fmt.Errorf("failed to set address: %w", err) return fmt.Errorf("failed to set address: %w", err)
} }
@ -125,7 +125,7 @@ func (t *winTun) Activate() error {
} }
func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.cidrTree.MostSpecificContains(ip) r := t.routeTree.MostSpecificContains(ip)
if r != nil { if r != nil {
return r.(iputil.VpnIp) return r.(iputil.VpnIp)
} }
@ -133,11 +133,11 @@ func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
return 0 return 0
} }
func (t *winTun) CidrNet() *net.IPNet { func (t *winTun) Cidr() *net.IPNet {
return t.Cidr return t.cidr
} }
func (t *winTun) DeviceName() string { func (t *winTun) Name() string {
return t.Device return t.Device
} }
@ -149,11 +149,6 @@ func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0) return t.tun.Write(b, 0)
} }
func (t *winTun) WriteRaw(b []byte) error {
_, err := t.Write(b)
return err
}
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
} }

View File

@ -18,11 +18,11 @@ func (NoopTun) Activate() error {
return nil return nil
} }
func (NoopTun) CidrNet() *net.IPNet { func (NoopTun) Cidr() *net.IPNet {
return nil return nil
} }
func (NoopTun) DeviceName() string { func (NoopTun) Name() string {
return "noop" return "noop"
} }
@ -34,10 +34,6 @@ func (NoopTun) Write([]byte) (int, error) {
return 0, nil return 0, nil
} }
func (NoopTun) WriteRaw([]byte) error {
return nil
}
func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, errors.New("unsupported") return nil, errors.New("unsupported")
} }