Merge pull request #79 from slackhq/subnet_support
Subnet and routing support.
This commit is contained in:
commit
5217f28264
|
@ -1,9 +1,10 @@
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet
|
// TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet
|
||||||
|
|
|
@ -10,12 +10,13 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
var vpnIP uint32 = uint32(12341234)
|
var vpnIP uint32
|
||||||
|
|
||||||
func Test_NewConnectionManagerTest(t *testing.T) {
|
func Test_NewConnectionManagerTest(t *testing.T) {
|
||||||
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
|
vpnIP = ip2int(net.ParseIP("172.1.1.2"))
|
||||||
preferredRanges := []*net.IPNet{localrange}
|
preferredRanges := []*net.IPNet{localrange}
|
||||||
|
|
||||||
// Very incomplete mock objects
|
// Very incomplete mock objects
|
||||||
|
|
|
@ -100,6 +100,13 @@ tun:
|
||||||
routes:
|
routes:
|
||||||
#- mtu: 8800
|
#- mtu: 8800
|
||||||
# route: 10.0.0.0/16
|
# route: 10.0.0.0/16
|
||||||
|
# Unsafe routes allows you to route traffic over nebula to non-nebula nodes
|
||||||
|
# Unsafe routes should be avoided unless you have hosts/services that cannot run nebula
|
||||||
|
# NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate
|
||||||
|
unsafe_routes:
|
||||||
|
- route: 172.16.1.0/24
|
||||||
|
via: 192.168.100.99
|
||||||
|
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# Configure logging level
|
# Configure logging level
|
||||||
|
|
|
@ -343,12 +343,17 @@ func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterfa
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool {
|
||||||
// Check if we spoke to this tuple, if we did then allow this packet
|
// Check if we spoke to this tuple, if we did then allow this packet
|
||||||
if f.inConns(packet, fp, incoming) {
|
if f.inConns(packet, fp, incoming) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Make sure remote address matches nebula certificate
|
||||||
|
if h.remoteCidr.Contains(fp.RemoteIP) == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// Make sure we are supposed to be handling this local ip address
|
// Make sure we are supposed to be handling this local ip address
|
||||||
if f.localIps.Contains(fp.LocalIP) == nil {
|
if f.localIps.Contains(fp.LocalIP) == nil {
|
||||||
return true
|
return true
|
||||||
|
@ -360,7 +365,7 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, c *cert
|
||||||
}
|
}
|
||||||
|
|
||||||
// We now know which firewall table to check against
|
// We now know which firewall table to check against
|
||||||
if !table.match(fp, incoming, c, caPool) {
|
if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,13 +3,14 @@ package nebula
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/rcrowley/go-metrics"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/rcrowley/go-metrics"
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewFirewall(t *testing.T) {
|
func TestNewFirewall(t *testing.T) {
|
||||||
|
@ -134,7 +135,7 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||||
func TestFirewall_Drop(t *testing.T) {
|
func TestFirewall_Drop(t *testing.T) {
|
||||||
p := FirewallPacket{
|
p := FirewallPacket{
|
||||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||||
101,
|
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||||
10,
|
10,
|
||||||
90,
|
90,
|
||||||
fwProtoUDP,
|
fwProtoUDP,
|
||||||
|
@ -154,39 +155,51 @@ func TestFirewall_Drop(t *testing.T) {
|
||||||
Issuer: "signer-shasum",
|
Issuer: "signer-shasum",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
h := HostInfo{
|
||||||
|
ConnectionState: &ConnectionState{
|
||||||
|
peerCert: &c,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h.CreateRemoteCIDR(&c)
|
||||||
|
|
||||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
assert.True(t, fw.Drop([]byte{}, p, false, &c, cp))
|
assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
|
||||||
// Allow inbound
|
// Allow inbound
|
||||||
assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
|
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||||
// Allow outbound because conntrack
|
// Allow outbound because conntrack
|
||||||
assert.False(t, fw.Drop([]byte{}, p, false, &c, cp))
|
assert.False(t, fw.Drop([]byte{}, p, false, &h, cp))
|
||||||
|
|
||||||
|
// test remote mismatch
|
||||||
|
oldRemote := p.RemoteIP
|
||||||
|
p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
|
||||||
|
assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
|
||||||
|
p.RemoteIP = oldRemote
|
||||||
|
|
||||||
// test caSha assertions true
|
// test caSha assertions true
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum"))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum"))
|
||||||
assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
|
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||||
|
|
||||||
// test caSha assertions false
|
// test caSha assertions false
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum-nope"))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum-nope"))
|
||||||
assert.True(t, fw.Drop([]byte{}, p, true, &c, cp))
|
assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||||
|
|
||||||
// test caName true
|
// test caName true
|
||||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-good", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-good", ""))
|
||||||
assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
|
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||||
|
|
||||||
// test caName false
|
// test caName false
|
||||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-bad", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-bad", ""))
|
||||||
assert.True(t, fw.Drop([]byte{}, p, true, &c, cp))
|
assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkFirewallTable_match(b *testing.B) {
|
func BenchmarkFirewallTable_match(b *testing.B) {
|
||||||
|
@ -286,7 +299,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||||
func TestFirewall_Drop2(t *testing.T) {
|
func TestFirewall_Drop2(t *testing.T) {
|
||||||
p := FirewallPacket{
|
p := FirewallPacket{
|
||||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||||
101,
|
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||||
10,
|
10,
|
||||||
90,
|
90,
|
||||||
fwProtoUDP,
|
fwProtoUDP,
|
||||||
|
@ -305,6 +318,12 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||||
InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
|
InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
h := HostInfo{
|
||||||
|
ConnectionState: &ConnectionState{
|
||||||
|
peerCert: &c,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h.CreateRemoteCIDR(&c)
|
||||||
|
|
||||||
c1 := cert.NebulaCertificate{
|
c1 := cert.NebulaCertificate{
|
||||||
Details: cert.NebulaCertificateDetails{
|
Details: cert.NebulaCertificateDetails{
|
||||||
|
@ -313,15 +332,21 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||||
InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
|
InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
h1 := HostInfo{
|
||||||
|
ConnectionState: &ConnectionState{
|
||||||
|
peerCert: &c1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h1.CreateRemoteCIDR(&c1)
|
||||||
|
|
||||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// c1 lacks the proper groups
|
// h1/c1 lacks the proper groups
|
||||||
assert.True(t, fw.Drop([]byte{}, p, true, &c1, cp))
|
assert.True(t, fw.Drop([]byte{}, p, true, &h1, cp))
|
||||||
// c has the proper groups
|
// c has the proper groups
|
||||||
assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
|
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkLookup(b *testing.B) {
|
func BenchmarkLookup(b *testing.B) {
|
||||||
|
|
|
@ -205,6 +205,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
|
|
||||||
//hostinfo.ClearRemotes()
|
//hostinfo.ClearRemotes()
|
||||||
hostinfo.AddRemote(*addr)
|
hostinfo.AddRemote(*addr)
|
||||||
|
hostinfo.CreateRemoteCIDR(remoteCert)
|
||||||
f.lightHouse.AddRemoteAndReset(ip, addr)
|
f.lightHouse.AddRemoteAndReset(ip, addr)
|
||||||
if f.serveDns {
|
if f.serveDns {
|
||||||
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
|
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
|
||||||
|
@ -314,6 +315,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
|
|
||||||
//hostinfo.ClearRemotes()
|
//hostinfo.ClearRemotes()
|
||||||
f.hostMap.AddRemote(ip, addr)
|
f.hostMap.AddRemote(ip, addr)
|
||||||
|
hostinfo.CreateRemoteCIDR(remoteCert)
|
||||||
f.lightHouse.AddRemoteAndReset(ip, addr)
|
f.lightHouse.AddRemoteAndReset(ip, addr)
|
||||||
if f.serveDns {
|
if f.serveDns {
|
||||||
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
|
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
|
||||||
|
|
|
@ -11,12 +11,13 @@ import (
|
||||||
var indexes []uint32 = []uint32{1000, 2000, 3000, 4000}
|
var indexes []uint32 = []uint32{1000, 2000, 3000, 4000}
|
||||||
|
|
||||||
//var ips []uint32 = []uint32{9000, 9999999, 3, 292394923}
|
//var ips []uint32 = []uint32{9000, 9999999, 3, 292394923}
|
||||||
var ips []uint32 = []uint32{9000}
|
var ips []uint32
|
||||||
|
|
||||||
func Test_NewHandshakeManagerIndex(t *testing.T) {
|
func Test_NewHandshakeManagerIndex(t *testing.T) {
|
||||||
_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
|
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
|
||||||
preferredRanges := []*net.IPNet{localrange}
|
preferredRanges := []*net.IPNet{localrange}
|
||||||
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
||||||
|
|
||||||
|
@ -54,9 +55,10 @@ func Test_NewHandshakeManagerIndex(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_NewHandshakeManagerVpnIP(t *testing.T) {
|
func Test_NewHandshakeManagerVpnIP(t *testing.T) {
|
||||||
_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
|
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
|
||||||
preferredRanges := []*net.IPNet{localrange}
|
preferredRanges := []*net.IPNet{localrange}
|
||||||
mw := &mockEncWriter{}
|
mw := &mockEncWriter{}
|
||||||
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
||||||
|
@ -102,9 +104,10 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
|
func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
|
||||||
_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
|
vpnIP = ip2int(net.ParseIP("172.1.1.2"))
|
||||||
preferredRanges := []*net.IPNet{localrange}
|
preferredRanges := []*net.IPNet{localrange}
|
||||||
mw := &mockEncWriter{}
|
mw := &mockEncWriter{}
|
||||||
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
||||||
|
@ -114,7 +117,7 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
blah.NextOutboundHandshakeTimerTick(now, mw)
|
blah.NextOutboundHandshakeTimerTick(now, mw)
|
||||||
|
|
||||||
hostinfo := blah.AddVpnIP(101010)
|
hostinfo := blah.AddVpnIP(vpnIP)
|
||||||
// Pretned we have an index too
|
// Pretned we have an index too
|
||||||
blah.AddIndexHostInfo(12341234, hostinfo)
|
blah.AddIndexHostInfo(12341234, hostinfo)
|
||||||
assert.Contains(t, blah.pendingHostMap.Indexes, uint32(12341234))
|
assert.Contains(t, blah.pendingHostMap.Indexes, uint32(12341234))
|
||||||
|
@ -147,12 +150,12 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
|
||||||
l.Infoln(cumulative, next_tick)
|
l.Infoln(cumulative, next_tick)
|
||||||
blah.NextOutboundHandshakeTimerTick(next_tick)
|
blah.NextOutboundHandshakeTimerTick(next_tick)
|
||||||
*/
|
*/
|
||||||
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010))
|
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(vpnIP))
|
||||||
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234))
|
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234))
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
|
func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
|
||||||
_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
preferredRanges := []*net.IPNet{localrange}
|
preferredRanges := []*net.IPNet{localrange}
|
||||||
|
|
38
hostmap.go
38
hostmap.go
|
@ -29,6 +29,7 @@ type HostMap struct {
|
||||||
preferredRanges []*net.IPNet
|
preferredRanges []*net.IPNet
|
||||||
vpnCIDR *net.IPNet
|
vpnCIDR *net.IPNet
|
||||||
defaultRoute uint32
|
defaultRoute uint32
|
||||||
|
unsafeRoutes *CIDRTree
|
||||||
}
|
}
|
||||||
|
|
||||||
type HostInfo struct {
|
type HostInfo struct {
|
||||||
|
@ -46,6 +47,7 @@ type HostInfo struct {
|
||||||
localIndexId uint32
|
localIndexId uint32
|
||||||
hostId uint32
|
hostId uint32
|
||||||
recvError int
|
recvError int
|
||||||
|
remoteCidr *CIDRTree
|
||||||
|
|
||||||
lastRoam time.Time
|
lastRoam time.Time
|
||||||
lastRoamRemote *udpAddr
|
lastRoamRemote *udpAddr
|
||||||
|
@ -82,6 +84,7 @@ func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *
|
||||||
preferredRanges: preferredRanges,
|
preferredRanges: preferredRanges,
|
||||||
vpnCIDR: vpnCIDR,
|
vpnCIDR: vpnCIDR,
|
||||||
defaultRoute: 0,
|
defaultRoute: 0,
|
||||||
|
unsafeRoutes: NewCIDRTree(),
|
||||||
}
|
}
|
||||||
return &m
|
return &m
|
||||||
}
|
}
|
||||||
|
@ -286,13 +289,6 @@ func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostIn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) {
|
func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) {
|
||||||
if hm.vpnCIDR.Contains(int2ip(vpnIp)) == false && hm.defaultRoute != 0 {
|
|
||||||
// FIXME: this shouldn't ship
|
|
||||||
d := hm.Hosts[hm.defaultRoute]
|
|
||||||
if d != nil {
|
|
||||||
return hm.Hosts[hm.defaultRoute], nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
hm.RLock()
|
hm.RLock()
|
||||||
if h, ok := hm.Hosts[vpnIp]; ok {
|
if h, ok := hm.Hosts[vpnIp]; ok {
|
||||||
if promoteIfce != nil {
|
if promoteIfce != nil {
|
||||||
|
@ -314,6 +310,15 @@ func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 {
|
||||||
|
r := hm.unsafeRoutes.MostSpecificContains(ip)
|
||||||
|
if r != nil {
|
||||||
|
return r.(uint32)
|
||||||
|
} else {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (hm *HostMap) CheckHandshakeCompleteIP(vpnIP uint32) bool {
|
func (hm *HostMap) CheckHandshakeCompleteIP(vpnIP uint32) bool {
|
||||||
hm.RLock()
|
hm.RLock()
|
||||||
if i, ok := hm.Hosts[vpnIP]; ok {
|
if i, ok := hm.Hosts[vpnIP]; ok {
|
||||||
|
@ -387,6 +392,13 @@ func (hm *HostMap) Punchy(conn *udpConn) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
|
||||||
|
for _, r := range *routes {
|
||||||
|
l.WithField("route", r.route).WithField("via", r.via).Error("Adding UNSAFE Route")
|
||||||
|
hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (i *HostInfo) MarshalJSON() ([]byte, error) {
|
func (i *HostInfo) MarshalJSON() ([]byte, error) {
|
||||||
return json.Marshal(m{
|
return json.Marshal(m{
|
||||||
"remote": i.remote,
|
"remote": i.remote,
|
||||||
|
@ -610,6 +622,18 @@ func (i *HostInfo) RecvErrorExceeded() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
|
||||||
|
remoteCidr := NewCIDRTree()
|
||||||
|
for _, ip := range c.Details.Ips {
|
||||||
|
remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, n := range c.Details.Subnets {
|
||||||
|
remoteCidr.AddCIDR(n, struct{}{})
|
||||||
|
}
|
||||||
|
i.remoteCidr = remoteCidr
|
||||||
|
}
|
||||||
|
|
||||||
//########################
|
//########################
|
||||||
|
|
||||||
func NewHostInfoDest(addr *udpAddr) *HostInfoDest {
|
func NewHostInfoDest(addr *udpAddr) *HostInfoDest {
|
||||||
|
|
|
@ -74,26 +74,26 @@ func TestHostmap(t *testing.T) {
|
||||||
a := NewUDPAddrFromString("10.127.0.3:11111")
|
a := NewUDPAddrFromString("10.127.0.3:11111")
|
||||||
b := NewUDPAddrFromString("1.0.0.1:22222")
|
b := NewUDPAddrFromString("1.0.0.1:22222")
|
||||||
y := NewUDPAddrFromString("10.128.0.3:11111")
|
y := NewUDPAddrFromString("10.128.0.3:11111")
|
||||||
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), a)
|
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
|
||||||
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), b)
|
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b)
|
||||||
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
|
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
|
||||||
|
|
||||||
info, _ := m.QueryVpnIP(ip2int(net.ParseIP("127.0.0.1")))
|
info, _ := m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1")))
|
||||||
|
|
||||||
// There should be three remotes in the host map
|
// There should be three remotes in the host map
|
||||||
assert.Equal(t, 3, len(info.Remotes))
|
assert.Equal(t, 3, len(info.Remotes))
|
||||||
|
|
||||||
// Adding an identical remote should not change the count
|
// Adding an identical remote should not change the count
|
||||||
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
|
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
|
||||||
assert.Equal(t, 3, len(info.Remotes))
|
assert.Equal(t, 3, len(info.Remotes))
|
||||||
|
|
||||||
// Adding a fresh remote should add one
|
// Adding a fresh remote should add one
|
||||||
y = NewUDPAddrFromString("10.18.0.3:11111")
|
y = NewUDPAddrFromString("10.18.0.3:11111")
|
||||||
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
|
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
|
||||||
assert.Equal(t, 4, len(info.Remotes))
|
assert.Equal(t, 4, len(info.Remotes))
|
||||||
|
|
||||||
// Query and reference remote should get the first one (and not nil)
|
// Query and reference remote should get the first one (and not nil)
|
||||||
info, _ = m.QueryVpnIP(ip2int(net.ParseIP("127.0.0.1")))
|
info, _ = m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1")))
|
||||||
assert.NotNil(t, info.remote)
|
assert.NotNil(t, info.remote)
|
||||||
|
|
||||||
// Promotion should ensure that the best remote is chosen (y)
|
// Promotion should ensure that the best remote is chosen (y)
|
||||||
|
@ -111,9 +111,9 @@ func TestHostmapdebug(t *testing.T) {
|
||||||
a := NewUDPAddrFromString("10.127.0.3:11111")
|
a := NewUDPAddrFromString("10.127.0.3:11111")
|
||||||
b := NewUDPAddrFromString("1.0.0.1:22222")
|
b := NewUDPAddrFromString("1.0.0.1:22222")
|
||||||
y := NewUDPAddrFromString("10.128.0.3:11111")
|
y := NewUDPAddrFromString("10.128.0.3:11111")
|
||||||
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), a)
|
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
|
||||||
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), b)
|
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b)
|
||||||
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
|
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
|
||||||
|
|
||||||
//t.Errorf("%s", m.DebugRemotes(1))
|
//t.Errorf("%s", m.DebugRemotes(1))
|
||||||
}
|
}
|
||||||
|
@ -157,9 +157,9 @@ func BenchmarkHostmappromote2(b *testing.B) {
|
||||||
y := NewUDPAddrFromString("10.128.0.3:11111")
|
y := NewUDPAddrFromString("10.128.0.3:11111")
|
||||||
a := NewUDPAddrFromString("10.127.0.3:11111")
|
a := NewUDPAddrFromString("10.127.0.3:11111")
|
||||||
g := NewUDPAddrFromString("1.0.0.1:22222")
|
g := NewUDPAddrFromString("1.0.0.1:22222")
|
||||||
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), a)
|
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
|
||||||
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), g)
|
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), g)
|
||||||
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
|
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
|
||||||
}
|
}
|
||||||
b.Errorf("hi")
|
b.Errorf("hi")
|
||||||
|
|
||||||
|
|
|
@ -39,7 +39,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
|
||||||
ci.queueLock.Unlock()
|
ci.queueLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
if !f.firewall.Drop(packet, *fwPacket, false, ci.peerCert, trustedCAs) {
|
if !f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs) {
|
||||||
f.send(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out)
|
f.send(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out)
|
||||||
if f.lightHouse != nil && *ci.messageCounter%5000 == 0 {
|
if f.lightHouse != nil && *ci.messageCounter%5000 == 0 {
|
||||||
f.lightHouse.Query(fwPacket.RemoteIP, f)
|
f.lightHouse.Query(fwPacket.RemoteIP, f)
|
||||||
|
@ -52,6 +52,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
|
func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
|
||||||
|
if f.hostMap.vpnCIDR.Contains(int2ip(vpnIp)) == false {
|
||||||
|
vpnIp = f.hostMap.queryUnsafeRoute(vpnIp)
|
||||||
|
}
|
||||||
hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f)
|
hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f)
|
||||||
|
|
||||||
//if err != nil || hostinfo.ConnectionState == nil {
|
//if err != nil || hostinfo.ConnectionState == nil {
|
||||||
|
@ -97,7 +100,7 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if packet is in outbound fw rules
|
// check if packet is in outbound fw rules
|
||||||
if f.firewall.Drop(p, *fp, false, hostInfo.ConnectionState.peerCert, trustedCAs) {
|
if f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs) {
|
||||||
l.WithField("fwPacket", fp).Debugln("dropping cached packet")
|
l.WithField("fwPacket", fp).Debugln("dropping cached packet")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
9
main.go
9
main.go
|
@ -82,6 +82,10 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Fatal("Could not parse tun.routes")
|
l.WithError(err).Fatal("Could not parse tun.routes")
|
||||||
}
|
}
|
||||||
|
unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
|
||||||
|
if err != nil {
|
||||||
|
l.WithError(err).Fatal("Could not parse tun.unsafe_routes")
|
||||||
|
}
|
||||||
|
|
||||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
||||||
wireSSHReload(ssh, config)
|
wireSSHReload(ssh, config)
|
||||||
|
@ -107,8 +111,9 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
||||||
tun, err := newTun(
|
tun, err := newTun(
|
||||||
config.GetString("tun.dev", ""),
|
config.GetString("tun.dev", ""),
|
||||||
tunCidr,
|
tunCidr,
|
||||||
config.GetInt("tun.mtu", 1300),
|
config.GetInt("tun.mtu", DEFAULT_MTU),
|
||||||
routes,
|
routes,
|
||||||
|
unsafeRoutes,
|
||||||
config.GetInt("tun.tx_queue", 500),
|
config.GetInt("tun.tx_queue", 500),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -163,6 +168,8 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
||||||
|
|
||||||
hostMap := NewHostMap("main", tunCidr, preferredRanges)
|
hostMap := NewHostMap("main", tunCidr, preferredRanges)
|
||||||
hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
|
hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
|
||||||
|
hostMap.addUnsafeRoutes(&unsafeRoutes)
|
||||||
|
|
||||||
l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
|
l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
|
@ -255,13 +255,6 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
||||||
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
|
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// TODO: This breaks subnet routing and needs to also check range of ip subnet
|
|
||||||
/*
|
|
||||||
if len(res) > 16 && binary.BigEndian.Uint32(res[12:16]) != ip2int(ci.peerCert.Details.Ips[0].IP) {
|
|
||||||
l.Debugf("Host %s tried to spoof packet as %s.", ci.peerCert.Details.Ips[0].IP, IntIp(binary.BigEndian.Uint32(res[12:16])))
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
|
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).Error("Failed to decrypt packet")
|
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).Error("Failed to decrypt packet")
|
||||||
|
@ -283,7 +276,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if f.firewall.Drop(out, *fwPacket, true, hostinfo.ConnectionState.peerCert, trustedCAs) {
|
if f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs) {
|
||||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket).
|
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket).
|
||||||
Debugln("dropping inbound packet")
|
Debugln("dropping inbound packet")
|
||||||
return
|
return
|
||||||
|
|
|
@ -6,9 +6,12 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const DEFAULT_MTU = 1300
|
||||||
|
|
||||||
type route struct {
|
type route struct {
|
||||||
mtu int
|
mtu int
|
||||||
route *net.IPNet
|
route *net.IPNet
|
||||||
|
via *net.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseRoutes(config *Config, network *net.IPNet) ([]route, error) {
|
func parseRoutes(config *Config, network *net.IPNet) ([]route, error) {
|
||||||
|
@ -81,6 +84,91 @@ func parseRoutes(config *Config, network *net.IPNet) ([]route, error) {
|
||||||
return routes, nil
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
r := config.Get("tun.unsafe_routes")
|
||||||
|
if r == nil {
|
||||||
|
return []route{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rawRoutes, ok := r.([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("tun.unsafe_routes is not an array")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rawRoutes) < 1 {
|
||||||
|
return []route{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
routes := make([]route, len(rawRoutes))
|
||||||
|
for i, r := range rawRoutes {
|
||||||
|
m, ok := r.(map[interface{}]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
rMtu, ok := m["mtu"]
|
||||||
|
if !ok {
|
||||||
|
rMtu = config.GetInt("tun.mtu", DEFAULT_MTU)
|
||||||
|
}
|
||||||
|
|
||||||
|
mtu, ok := rMtu.(int)
|
||||||
|
if !ok {
|
||||||
|
mtu, err = strconv.Atoi(rMtu.(string))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("entry %v.mtu in tun.unsafe_routes is not an integer: %v", i+1, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if mtu < 500 {
|
||||||
|
return nil, fmt.Errorf("entry %v.mtu in tun.unsafe_routes is below 500: %v", i+1, mtu)
|
||||||
|
}
|
||||||
|
|
||||||
|
rVia, ok := m["via"]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
via, ok := rVia.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: %v", i+1, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nVia := net.ParseIP(via)
|
||||||
|
if nVia == nil {
|
||||||
|
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, via)
|
||||||
|
}
|
||||||
|
|
||||||
|
rRoute, ok := m["route"]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := route{
|
||||||
|
via: &nVia,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, r.route, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ipWithin(network, r.route) {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v",
|
||||||
|
i+1,
|
||||||
|
r.route.String(),
|
||||||
|
network.String(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
routes[i] = r
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes, nil
|
||||||
|
}
|
||||||
|
|
||||||
func ipWithin(o *net.IPNet, i *net.IPNet) bool {
|
func ipWithin(o *net.IPNet, i *net.IPNet) bool {
|
||||||
// Make sure o contains the lowest form of i
|
// Make sure o contains the lowest form of i
|
||||||
if !o.Contains(i.IP.Mask(i.Mask)) {
|
if !o.Contains(i.IP.Mask(i.Mask)) {
|
||||||
|
|
|
@ -17,10 +17,13 @@ type Tun struct {
|
||||||
*water.Interface
|
*water.Interface
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, txQueueLen int) (ifce *Tun, err error) {
|
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
if len(routes) > 0 {
|
if len(routes) > 0 {
|
||||||
return nil, fmt.Errorf("Route MTU not supported in Darwin")
|
return nil, fmt.Errorf("Route MTU not supported in Darwin")
|
||||||
}
|
}
|
||||||
|
if len(unsafeRoutes) > 0 {
|
||||||
|
return nil, fmt.Errorf("unsafeRoutes not supported in Darwin")
|
||||||
|
}
|
||||||
// NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate()
|
// NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate()
|
||||||
return &Tun{
|
return &Tun{
|
||||||
Cidr: cidr,
|
Cidr: cidr,
|
||||||
|
|
19
tun_linux.go
19
tun_linux.go
|
@ -21,6 +21,7 @@ type Tun struct {
|
||||||
DefaultMTU int
|
DefaultMTU int
|
||||||
TXQueueLen int
|
TXQueueLen int
|
||||||
Routes []route
|
Routes []route
|
||||||
|
UnsafeRoutes []route
|
||||||
}
|
}
|
||||||
|
|
||||||
type ifReq struct {
|
type ifReq struct {
|
||||||
|
@ -74,7 +75,7 @@ type ifreqQLEN struct {
|
||||||
pad [8]byte
|
pad [8]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, txQueueLen int) (ifce *Tun, err error) {
|
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -106,6 +107,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
|
||||||
DefaultMTU: defaultMTU,
|
DefaultMTU: defaultMTU,
|
||||||
TXQueueLen: txQueueLen,
|
TXQueueLen: txQueueLen,
|
||||||
Routes: routes,
|
Routes: routes,
|
||||||
|
UnsafeRoutes: unsafeRoutes,
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -238,6 +240,21 @@ func (c Tun) Activate() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unsafe path routes
|
||||||
|
for _, r := range c.UnsafeRoutes {
|
||||||
|
nr := netlink.Route{
|
||||||
|
LinkIndex: link.Attrs().Index,
|
||||||
|
Dst: r.route,
|
||||||
|
MTU: r.mtu,
|
||||||
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = netlink.RouteAdd(&nr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to set mtu %v on route %v; %v", r.mtu, r.route, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Run the interface
|
// Run the interface
|
||||||
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
|
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
|
||||||
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||||
|
|
|
@ -16,10 +16,13 @@ type Tun struct {
|
||||||
*water.Interface
|
*water.Interface
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, txQueueLen int) (ifce *Tun, err error) {
|
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
if len(routes) > 0 {
|
if len(routes) > 0 {
|
||||||
return nil, fmt.Errorf("Route MTU not supported in Windows")
|
return nil, fmt.Errorf("Route MTU not supported in Windows")
|
||||||
}
|
}
|
||||||
|
if len(unsafeRoutes) > 0 {
|
||||||
|
return nil, fmt.Errorf("unsafeRoutes not supported in Windows")
|
||||||
|
}
|
||||||
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
|
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
|
||||||
return &Tun{
|
return &Tun{
|
||||||
Cidr: cidr,
|
Cidr: cidr,
|
||||||
|
|
Loading…
Reference in New Issue