Merge pull request #79 from slackhq/subnet_support

Subnet and routing support.
This commit is contained in:
Ryan Huber 2019-12-12 12:36:56 -06:00 committed by GitHub
commit 5217f28264
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 249 additions and 67 deletions

View File

@ -1,9 +1,10 @@
package nebula
import (
"github.com/sirupsen/logrus"
"sync"
"time"
"github.com/sirupsen/logrus"
)
// TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet

View File

@ -10,12 +10,13 @@ import (
"github.com/stretchr/testify/assert"
)
var vpnIP uint32 = uint32(12341234)
var vpnIP uint32
func Test_NewConnectionManagerTest(t *testing.T) {
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
vpnIP = ip2int(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
// Very incomplete mock objects

View File

@ -100,6 +100,13 @@ tun:
routes:
#- mtu: 8800
# route: 10.0.0.0/16
# Unsafe routes allows you to route traffic over nebula to non-nebula nodes
# Unsafe routes should be avoided unless you have hosts/services that cannot run nebula
# NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate
unsafe_routes:
- route: 172.16.1.0/24
via: 192.168.100.99
# TODO
# Configure logging level

View File

@ -343,12 +343,17 @@ func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterfa
return nil
}
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool {
// Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(packet, fp, incoming) {
return false
}
// Make sure remote address matches nebula certificate
if h.remoteCidr.Contains(fp.RemoteIP) == nil {
return true
}
// Make sure we are supposed to be handling this local ip address
if f.localIps.Contains(fp.LocalIP) == nil {
return true
@ -360,7 +365,7 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, c *cert
}
// We now know which firewall table to check against
if !table.match(fp, incoming, c, caPool) {
if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) {
return true
}

View File

@ -3,13 +3,14 @@ package nebula
import (
"encoding/binary"
"errors"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
"math"
"net"
"testing"
"time"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
)
func TestNewFirewall(t *testing.T) {
@ -134,7 +135,7 @@ func TestFirewall_AddRule(t *testing.T) {
func TestFirewall_Drop(t *testing.T) {
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
101,
ip2int(net.IPv4(1, 2, 3, 4)),
10,
90,
fwProtoUDP,
@ -154,39 +155,51 @@ func TestFirewall_Drop(t *testing.T) {
Issuer: "signer-shasum",
},
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c,
},
}
h.CreateRemoteCIDR(&c)
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.True(t, fw.Drop([]byte{}, p, false, &c, cp))
assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
// Allow inbound
assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
// Allow outbound because conntrack
assert.False(t, fw.Drop([]byte{}, p, false, &c, cp))
assert.False(t, fw.Drop([]byte{}, p, false, &h, cp))
// test remote mismatch
oldRemote := p.RemoteIP
p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
p.RemoteIP = oldRemote
// test caSha assertions true
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum"))
assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
// test caSha assertions false
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum-nope"))
assert.True(t, fw.Drop([]byte{}, p, true, &c, cp))
assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
// test caName true
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-good", ""))
assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
// test caName false
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-bad", ""))
assert.True(t, fw.Drop([]byte{}, p, true, &c, cp))
assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
}
func BenchmarkFirewallTable_match(b *testing.B) {
@ -286,7 +299,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
func TestFirewall_Drop2(t *testing.T) {
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
101,
ip2int(net.IPv4(1, 2, 3, 4)),
10,
90,
fwProtoUDP,
@ -305,6 +318,12 @@ func TestFirewall_Drop2(t *testing.T) {
InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
},
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c,
},
}
h.CreateRemoteCIDR(&c)
c1 := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
@ -313,15 +332,21 @@ func TestFirewall_Drop2(t *testing.T) {
InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
},
}
h1 := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c1,
},
}
h1.CreateRemoteCIDR(&c1)
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
cp := cert.NewCAPool()
// c1 lacks the proper groups
assert.True(t, fw.Drop([]byte{}, p, true, &c1, cp))
// h1/c1 lacks the proper groups
assert.True(t, fw.Drop([]byte{}, p, true, &h1, cp))
// c has the proper groups
assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
}
func BenchmarkLookup(b *testing.B) {

View File

@ -205,6 +205,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
//hostinfo.ClearRemotes()
hostinfo.AddRemote(*addr)
hostinfo.CreateRemoteCIDR(remoteCert)
f.lightHouse.AddRemoteAndReset(ip, addr)
if f.serveDns {
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
@ -314,6 +315,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
//hostinfo.ClearRemotes()
f.hostMap.AddRemote(ip, addr)
hostinfo.CreateRemoteCIDR(remoteCert)
f.lightHouse.AddRemoteAndReset(ip, addr)
if f.serveDns {
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())

View File

@ -11,12 +11,13 @@ import (
var indexes []uint32 = []uint32{1000, 2000, 3000, 4000}
//var ips []uint32 = []uint32{9000, 9999999, 3, 292394923}
var ips []uint32 = []uint32{9000}
var ips []uint32
func Test_NewHandshakeManagerIndex(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
@ -54,9 +55,10 @@ func Test_NewHandshakeManagerIndex(t *testing.T) {
}
func Test_NewHandshakeManagerVpnIP(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
@ -102,9 +104,10 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
}
func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
vpnIP = ip2int(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
@ -114,7 +117,7 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
hostinfo := blah.AddVpnIP(101010)
hostinfo := blah.AddVpnIP(vpnIP)
// Pretned we have an index too
blah.AddIndexHostInfo(12341234, hostinfo)
assert.Contains(t, blah.pendingHostMap.Indexes, uint32(12341234))
@ -147,12 +150,12 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
l.Infoln(cumulative, next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick)
*/
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010))
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(vpnIP))
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234))
}
func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange}

View File

@ -29,6 +29,7 @@ type HostMap struct {
preferredRanges []*net.IPNet
vpnCIDR *net.IPNet
defaultRoute uint32
unsafeRoutes *CIDRTree
}
type HostInfo struct {
@ -46,6 +47,7 @@ type HostInfo struct {
localIndexId uint32
hostId uint32
recvError int
remoteCidr *CIDRTree
lastRoam time.Time
lastRoamRemote *udpAddr
@ -82,6 +84,7 @@ func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *
preferredRanges: preferredRanges,
vpnCIDR: vpnCIDR,
defaultRoute: 0,
unsafeRoutes: NewCIDRTree(),
}
return &m
}
@ -286,13 +289,6 @@ func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostIn
}
func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) {
if hm.vpnCIDR.Contains(int2ip(vpnIp)) == false && hm.defaultRoute != 0 {
// FIXME: this shouldn't ship
d := hm.Hosts[hm.defaultRoute]
if d != nil {
return hm.Hosts[hm.defaultRoute], nil
}
}
hm.RLock()
if h, ok := hm.Hosts[vpnIp]; ok {
if promoteIfce != nil {
@ -314,6 +310,15 @@ func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo,
}
}
func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 {
r := hm.unsafeRoutes.MostSpecificContains(ip)
if r != nil {
return r.(uint32)
} else {
return 0
}
}
func (hm *HostMap) CheckHandshakeCompleteIP(vpnIP uint32) bool {
hm.RLock()
if i, ok := hm.Hosts[vpnIP]; ok {
@ -387,6 +392,13 @@ func (hm *HostMap) Punchy(conn *udpConn) {
}
}
func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
for _, r := range *routes {
l.WithField("route", r.route).WithField("via", r.via).Error("Adding UNSAFE Route")
hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via))
}
}
func (i *HostInfo) MarshalJSON() ([]byte, error) {
return json.Marshal(m{
"remote": i.remote,
@ -610,6 +622,18 @@ func (i *HostInfo) RecvErrorExceeded() bool {
return true
}
func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
remoteCidr := NewCIDRTree()
for _, ip := range c.Details.Ips {
remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
}
for _, n := range c.Details.Subnets {
remoteCidr.AddCIDR(n, struct{}{})
}
i.remoteCidr = remoteCidr
}
//########################
func NewHostInfoDest(addr *udpAddr) *HostInfoDest {

View File

@ -74,26 +74,26 @@ func TestHostmap(t *testing.T) {
a := NewUDPAddrFromString("10.127.0.3:11111")
b := NewUDPAddrFromString("1.0.0.1:22222")
y := NewUDPAddrFromString("10.128.0.3:11111")
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), a)
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), b)
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
info, _ := m.QueryVpnIP(ip2int(net.ParseIP("127.0.0.1")))
info, _ := m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1")))
// There should be three remotes in the host map
assert.Equal(t, 3, len(info.Remotes))
// Adding an identical remote should not change the count
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
assert.Equal(t, 3, len(info.Remotes))
// Adding a fresh remote should add one
y = NewUDPAddrFromString("10.18.0.3:11111")
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
assert.Equal(t, 4, len(info.Remotes))
// Query and reference remote should get the first one (and not nil)
info, _ = m.QueryVpnIP(ip2int(net.ParseIP("127.0.0.1")))
info, _ = m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1")))
assert.NotNil(t, info.remote)
// Promotion should ensure that the best remote is chosen (y)
@ -111,9 +111,9 @@ func TestHostmapdebug(t *testing.T) {
a := NewUDPAddrFromString("10.127.0.3:11111")
b := NewUDPAddrFromString("1.0.0.1:22222")
y := NewUDPAddrFromString("10.128.0.3:11111")
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), a)
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), b)
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
//t.Errorf("%s", m.DebugRemotes(1))
}
@ -157,9 +157,9 @@ func BenchmarkHostmappromote2(b *testing.B) {
y := NewUDPAddrFromString("10.128.0.3:11111")
a := NewUDPAddrFromString("10.127.0.3:11111")
g := NewUDPAddrFromString("1.0.0.1:22222")
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), a)
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), g)
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), g)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
}
b.Errorf("hi")

View File

@ -39,7 +39,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
ci.queueLock.Unlock()
}
if !f.firewall.Drop(packet, *fwPacket, false, ci.peerCert, trustedCAs) {
if !f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs) {
f.send(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out)
if f.lightHouse != nil && *ci.messageCounter%5000 == 0 {
f.lightHouse.Query(fwPacket.RemoteIP, f)
@ -52,6 +52,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
}
func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
if f.hostMap.vpnCIDR.Contains(int2ip(vpnIp)) == false {
vpnIp = f.hostMap.queryUnsafeRoute(vpnIp)
}
hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f)
//if err != nil || hostinfo.ConnectionState == nil {
@ -97,7 +100,7 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
}
// check if packet is in outbound fw rules
if f.firewall.Drop(p, *fp, false, hostInfo.ConnectionState.peerCert, trustedCAs) {
if f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs) {
l.WithField("fwPacket", fp).Debugln("dropping cached packet")
return
}

View File

@ -82,6 +82,10 @@ func Main(configPath string, configTest bool, buildVersion string) {
if err != nil {
l.WithError(err).Fatal("Could not parse tun.routes")
}
unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
if err != nil {
l.WithError(err).Fatal("Could not parse tun.unsafe_routes")
}
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
wireSSHReload(ssh, config)
@ -107,8 +111,9 @@ func Main(configPath string, configTest bool, buildVersion string) {
tun, err := newTun(
config.GetString("tun.dev", ""),
tunCidr,
config.GetInt("tun.mtu", 1300),
config.GetInt("tun.mtu", DEFAULT_MTU),
routes,
unsafeRoutes,
config.GetInt("tun.tx_queue", 500),
)
if err != nil {
@ -163,6 +168,8 @@ func Main(configPath string, configTest bool, buildVersion string) {
hostMap := NewHostMap("main", tunCidr, preferredRanges)
hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
hostMap.addUnsafeRoutes(&unsafeRoutes)
l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
/*

View File

@ -255,13 +255,6 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
var err error
// TODO: This breaks subnet routing and needs to also check range of ip subnet
/*
if len(res) > 16 && binary.BigEndian.Uint32(res[12:16]) != ip2int(ci.peerCert.Details.Ips[0].IP) {
l.Debugf("Host %s tried to spoof packet as %s.", ci.peerCert.Details.Ips[0].IP, IntIp(binary.BigEndian.Uint32(res[12:16])))
}
*/
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).Error("Failed to decrypt packet")
@ -283,7 +276,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return
}
if f.firewall.Drop(out, *fwPacket, true, hostinfo.ConnectionState.peerCert, trustedCAs) {
if f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs) {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket).
Debugln("dropping inbound packet")
return

View File

@ -6,9 +6,12 @@ import (
"strconv"
)
const DEFAULT_MTU = 1300
type route struct {
mtu int
route *net.IPNet
via *net.IP
}
func parseRoutes(config *Config, network *net.IPNet) ([]route, error) {
@ -81,6 +84,91 @@ func parseRoutes(config *Config, network *net.IPNet) ([]route, error) {
return routes, nil
}
func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) {
var err error
r := config.Get("tun.unsafe_routes")
if r == nil {
return []route{}, nil
}
rawRoutes, ok := r.([]interface{})
if !ok {
return nil, fmt.Errorf("tun.unsafe_routes is not an array")
}
if len(rawRoutes) < 1 {
return []route{}, nil
}
routes := make([]route, len(rawRoutes))
for i, r := range rawRoutes {
m, ok := r.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1)
}
rMtu, ok := m["mtu"]
if !ok {
rMtu = config.GetInt("tun.mtu", DEFAULT_MTU)
}
mtu, ok := rMtu.(int)
if !ok {
mtu, err = strconv.Atoi(rMtu.(string))
if err != nil {
return nil, fmt.Errorf("entry %v.mtu in tun.unsafe_routes is not an integer: %v", i+1, err)
}
}
if mtu < 500 {
return nil, fmt.Errorf("entry %v.mtu in tun.unsafe_routes is below 500: %v", i+1, mtu)
}
rVia, ok := m["via"]
if !ok {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1)
}
via, ok := rVia.(string)
if !ok {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: %v", i+1, err)
}
nVia := net.ParseIP(via)
if nVia == nil {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, via)
}
rRoute, ok := m["route"]
if !ok {
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1)
}
r := route{
via: &nVia,
}
_, r.route, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
if err != nil {
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
}
if ipWithin(network, r.route) {
return nil, fmt.Errorf(
"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v",
i+1,
r.route.String(),
network.String(),
)
}
routes[i] = r
}
return routes, nil
}
func ipWithin(o *net.IPNet, i *net.IPNet) bool {
// Make sure o contains the lowest form of i
if !o.Contains(i.IP.Mask(i.Mask)) {

View File

@ -17,10 +17,13 @@ type Tun struct {
*water.Interface
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, txQueueLen int) (ifce *Tun, err error) {
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in Darwin")
}
if len(unsafeRoutes) > 0 {
return nil, fmt.Errorf("unsafeRoutes not supported in Darwin")
}
// NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate()
return &Tun{
Cidr: cidr,

View File

@ -14,13 +14,14 @@ import (
type Tun struct {
io.ReadWriteCloser
fd int
Device string
Cidr *net.IPNet
MaxMTU int
DefaultMTU int
TXQueueLen int
Routes []route
fd int
Device string
Cidr *net.IPNet
MaxMTU int
DefaultMTU int
TXQueueLen int
Routes []route
UnsafeRoutes []route
}
type ifReq struct {
@ -74,7 +75,7 @@ type ifreqQLEN struct {
pad [8]byte
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, txQueueLen int) (ifce *Tun, err error) {
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
@ -106,6 +107,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen,
Routes: routes,
UnsafeRoutes: unsafeRoutes,
}
return
}
@ -238,6 +240,21 @@ func (c Tun) Activate() error {
}
}
// Unsafe path routes
for _, r := range c.UnsafeRoutes {
nr := netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: r.route,
MTU: r.mtu,
Scope: unix.RT_SCOPE_LINK,
}
err = netlink.RouteAdd(&nr)
if err != nil {
return fmt.Errorf("failed to set mtu %v on route %v; %v", r.mtu, r.route, err)
}
}
// Run the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {

View File

@ -16,10 +16,13 @@ type Tun struct {
*water.Interface
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, txQueueLen int) (ifce *Tun, err error) {
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in Windows")
}
if len(unsafeRoutes) > 0 {
return nil, fmt.Errorf("unsafeRoutes not supported in Windows")
}
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
return &Tun{
Cidr: cidr,