diff --git a/connection_manager.go b/connection_manager.go index 33d3265..1785444 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -1,9 +1,10 @@ package nebula import ( - "github.com/sirupsen/logrus" "sync" "time" + + "github.com/sirupsen/logrus" ) // TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet diff --git a/connection_manager_test.go b/connection_manager_test.go index 81bb049..68b9d02 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -10,12 +10,13 @@ import ( "github.com/stretchr/testify/assert" ) -var vpnIP uint32 = uint32(12341234) +var vpnIP uint32 func Test_NewConnectionManagerTest(t *testing.T) { //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") + vpnIP = ip2int(net.ParseIP("172.1.1.2")) preferredRanges := []*net.IPNet{localrange} // Very incomplete mock objects diff --git a/examples/config.yml b/examples/config.yml index d48c347..5491eac 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -100,6 +100,13 @@ tun: routes: #- mtu: 8800 # route: 10.0.0.0/16 + # Unsafe routes allows you to route traffic over nebula to non-nebula nodes + # Unsafe routes should be avoided unless you have hosts/services that cannot run nebula + # NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate + unsafe_routes: + - route: 172.16.1.0/24 + via: 192.168.100.99 + # TODO # Configure logging level diff --git a/firewall.go b/firewall.go index 0ada61b..1a05256 100644 --- a/firewall.go +++ b/firewall.go @@ -343,12 +343,17 @@ func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterfa return nil } -func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { +func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool { // Check if we spoke to this tuple, if we did then allow this packet if f.inConns(packet, fp, incoming) { return false } + // Make sure remote address matches nebula certificate + if h.remoteCidr.Contains(fp.RemoteIP) == nil { + return true + } + // Make sure we are supposed to be handling this local ip address if f.localIps.Contains(fp.LocalIP) == nil { return true @@ -360,7 +365,7 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, c *cert } // We now know which firewall table to check against - if !table.match(fp, incoming, c, caPool) { + if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) { return true } diff --git a/firewall_test.go b/firewall_test.go index 0a18254..7ae1c56 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -3,13 +3,14 @@ package nebula import ( "encoding/binary" "errors" - "github.com/rcrowley/go-metrics" - "github.com/slackhq/nebula/cert" - "github.com/stretchr/testify/assert" "math" "net" "testing" "time" + + "github.com/rcrowley/go-metrics" + "github.com/slackhq/nebula/cert" + "github.com/stretchr/testify/assert" ) func TestNewFirewall(t *testing.T) { @@ -134,7 +135,7 @@ func TestFirewall_AddRule(t *testing.T) { func TestFirewall_Drop(t *testing.T) { p := FirewallPacket{ ip2int(net.IPv4(1, 2, 3, 4)), - 101, + ip2int(net.IPv4(1, 2, 3, 4)), 10, 90, fwProtoUDP, @@ -154,39 +155,51 @@ func TestFirewall_Drop(t *testing.T) { Issuer: "signer-shasum", }, } + h := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &c, + }, + } + h.CreateRemoteCIDR(&c) fw := NewFirewall(time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) cp := cert.NewCAPool() // Drop outbound - assert.True(t, fw.Drop([]byte{}, p, false, &c, cp)) + assert.True(t, fw.Drop([]byte{}, p, false, &h, cp)) // Allow inbound - assert.False(t, fw.Drop([]byte{}, p, true, &c, cp)) + assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) // Allow outbound because conntrack - assert.False(t, fw.Drop([]byte{}, p, false, &c, cp)) + assert.False(t, fw.Drop([]byte{}, p, false, &h, cp)) + + // test remote mismatch + oldRemote := p.RemoteIP + p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10)) + assert.True(t, fw.Drop([]byte{}, p, false, &h, cp)) + p.RemoteIP = oldRemote // test caSha assertions true fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum")) - assert.False(t, fw.Drop([]byte{}, p, true, &c, cp)) + assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) // test caSha assertions false fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum-nope")) - assert.True(t, fw.Drop([]byte{}, p, true, &c, cp)) + assert.True(t, fw.Drop([]byte{}, p, true, &h, cp)) // test caName true cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-good", "")) - assert.False(t, fw.Drop([]byte{}, p, true, &c, cp)) + assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) // test caName false cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-bad", "")) - assert.True(t, fw.Drop([]byte{}, p, true, &c, cp)) + assert.True(t, fw.Drop([]byte{}, p, true, &h, cp)) } func BenchmarkFirewallTable_match(b *testing.B) { @@ -286,7 +299,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { func TestFirewall_Drop2(t *testing.T) { p := FirewallPacket{ ip2int(net.IPv4(1, 2, 3, 4)), - 101, + ip2int(net.IPv4(1, 2, 3, 4)), 10, 90, fwProtoUDP, @@ -305,6 +318,12 @@ func TestFirewall_Drop2(t *testing.T) { InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}}, }, } + h := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &c, + }, + } + h.CreateRemoteCIDR(&c) c1 := cert.NebulaCertificate{ Details: cert.NebulaCertificateDetails{ @@ -313,15 +332,21 @@ func TestFirewall_Drop2(t *testing.T) { InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}}, }, } + h1 := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &c1, + }, + } + h1.CreateRemoteCIDR(&c1) fw := NewFirewall(time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", "")) cp := cert.NewCAPool() - // c1 lacks the proper groups - assert.True(t, fw.Drop([]byte{}, p, true, &c1, cp)) + // h1/c1 lacks the proper groups + assert.True(t, fw.Drop([]byte{}, p, true, &h1, cp)) // c has the proper groups - assert.False(t, fw.Drop([]byte{}, p, true, &c, cp)) + assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) } func BenchmarkLookup(b *testing.B) { diff --git a/handshake_ix.go b/handshake_ix.go index 54a4239..0e29032 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -205,6 +205,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ //hostinfo.ClearRemotes() hostinfo.AddRemote(*addr) + hostinfo.CreateRemoteCIDR(remoteCert) f.lightHouse.AddRemoteAndReset(ip, addr) if f.serveDns { dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) @@ -314,6 +315,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ //hostinfo.ClearRemotes() f.hostMap.AddRemote(ip, addr) + hostinfo.CreateRemoteCIDR(remoteCert) f.lightHouse.AddRemoteAndReset(ip, addr) if f.serveDns { dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 6822b7c..bc4fc95 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -11,12 +11,13 @@ import ( var indexes []uint32 = []uint32{1000, 2000, 3000, 4000} //var ips []uint32 = []uint32{9000, 9999999, 3, 292394923} -var ips []uint32 = []uint32{9000} +var ips []uint32 func Test_NewHandshakeManagerIndex(t *testing.T) { - _, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") + _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") + ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))} preferredRanges := []*net.IPNet{localrange} mainHM := NewHostMap("test", vpncidr, preferredRanges) @@ -54,9 +55,10 @@ func Test_NewHandshakeManagerIndex(t *testing.T) { } func Test_NewHandshakeManagerVpnIP(t *testing.T) { - _, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") + _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") + ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))} preferredRanges := []*net.IPNet{localrange} mw := &mockEncWriter{} mainHM := NewHostMap("test", vpncidr, preferredRanges) @@ -102,9 +104,10 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) { } func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) { - _, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") + _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") + vpnIP = ip2int(net.ParseIP("172.1.1.2")) preferredRanges := []*net.IPNet{localrange} mw := &mockEncWriter{} mainHM := NewHostMap("test", vpncidr, preferredRanges) @@ -114,7 +117,7 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) { now := time.Now() blah.NextOutboundHandshakeTimerTick(now, mw) - hostinfo := blah.AddVpnIP(101010) + hostinfo := blah.AddVpnIP(vpnIP) // Pretned we have an index too blah.AddIndexHostInfo(12341234, hostinfo) assert.Contains(t, blah.pendingHostMap.Indexes, uint32(12341234)) @@ -147,12 +150,12 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) { l.Infoln(cumulative, next_tick) blah.NextOutboundHandshakeTimerTick(next_tick) */ - assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010)) + assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(vpnIP)) assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234)) } func Test_NewHandshakeManagerIndexcleanup(t *testing.T) { - _, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") + _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") preferredRanges := []*net.IPNet{localrange} diff --git a/hostmap.go b/hostmap.go index 20f8ce5..fbe1d64 100644 --- a/hostmap.go +++ b/hostmap.go @@ -29,6 +29,7 @@ type HostMap struct { preferredRanges []*net.IPNet vpnCIDR *net.IPNet defaultRoute uint32 + unsafeRoutes *CIDRTree } type HostInfo struct { @@ -46,6 +47,7 @@ type HostInfo struct { localIndexId uint32 hostId uint32 recvError int + remoteCidr *CIDRTree lastRoam time.Time lastRoamRemote *udpAddr @@ -82,6 +84,7 @@ func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) * preferredRanges: preferredRanges, vpnCIDR: vpnCIDR, defaultRoute: 0, + unsafeRoutes: NewCIDRTree(), } return &m } @@ -286,13 +289,6 @@ func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostIn } func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) { - if hm.vpnCIDR.Contains(int2ip(vpnIp)) == false && hm.defaultRoute != 0 { - // FIXME: this shouldn't ship - d := hm.Hosts[hm.defaultRoute] - if d != nil { - return hm.Hosts[hm.defaultRoute], nil - } - } hm.RLock() if h, ok := hm.Hosts[vpnIp]; ok { if promoteIfce != nil { @@ -314,6 +310,15 @@ func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, } } +func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 { + r := hm.unsafeRoutes.MostSpecificContains(ip) + if r != nil { + return r.(uint32) + } else { + return 0 + } +} + func (hm *HostMap) CheckHandshakeCompleteIP(vpnIP uint32) bool { hm.RLock() if i, ok := hm.Hosts[vpnIP]; ok { @@ -387,6 +392,13 @@ func (hm *HostMap) Punchy(conn *udpConn) { } } +func (hm *HostMap) addUnsafeRoutes(routes *[]route) { + for _, r := range *routes { + l.WithField("route", r.route).WithField("via", r.via).Error("Adding UNSAFE Route") + hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via)) + } +} + func (i *HostInfo) MarshalJSON() ([]byte, error) { return json.Marshal(m{ "remote": i.remote, @@ -610,6 +622,18 @@ func (i *HostInfo) RecvErrorExceeded() bool { return true } +func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) { + remoteCidr := NewCIDRTree() + for _, ip := range c.Details.Ips { + remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) + } + + for _, n := range c.Details.Subnets { + remoteCidr.AddCIDR(n, struct{}{}) + } + i.remoteCidr = remoteCidr +} + //######################## func NewHostInfoDest(addr *udpAddr) *HostInfoDest { diff --git a/hostmap_test.go b/hostmap_test.go index de5e198..f6579c7 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -74,26 +74,26 @@ func TestHostmap(t *testing.T) { a := NewUDPAddrFromString("10.127.0.3:11111") b := NewUDPAddrFromString("1.0.0.1:22222") y := NewUDPAddrFromString("10.128.0.3:11111") - m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), a) - m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), b) - m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y) + m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a) + m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b) + m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y) - info, _ := m.QueryVpnIP(ip2int(net.ParseIP("127.0.0.1"))) + info, _ := m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1"))) // There should be three remotes in the host map assert.Equal(t, 3, len(info.Remotes)) // Adding an identical remote should not change the count - m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y) + m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y) assert.Equal(t, 3, len(info.Remotes)) // Adding a fresh remote should add one y = NewUDPAddrFromString("10.18.0.3:11111") - m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y) + m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y) assert.Equal(t, 4, len(info.Remotes)) // Query and reference remote should get the first one (and not nil) - info, _ = m.QueryVpnIP(ip2int(net.ParseIP("127.0.0.1"))) + info, _ = m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1"))) assert.NotNil(t, info.remote) // Promotion should ensure that the best remote is chosen (y) @@ -111,9 +111,9 @@ func TestHostmapdebug(t *testing.T) { a := NewUDPAddrFromString("10.127.0.3:11111") b := NewUDPAddrFromString("1.0.0.1:22222") y := NewUDPAddrFromString("10.128.0.3:11111") - m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), a) - m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), b) - m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y) + m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a) + m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b) + m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y) //t.Errorf("%s", m.DebugRemotes(1)) } @@ -157,9 +157,9 @@ func BenchmarkHostmappromote2(b *testing.B) { y := NewUDPAddrFromString("10.128.0.3:11111") a := NewUDPAddrFromString("10.127.0.3:11111") g := NewUDPAddrFromString("1.0.0.1:22222") - m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), a) - m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), g) - m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y) + m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a) + m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), g) + m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y) } b.Errorf("hi") diff --git a/inside.go b/inside.go index 34022aa..6930035 100644 --- a/inside.go +++ b/inside.go @@ -39,7 +39,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, ci.queueLock.Unlock() } - if !f.firewall.Drop(packet, *fwPacket, false, ci.peerCert, trustedCAs) { + if !f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs) { f.send(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out) if f.lightHouse != nil && *ci.messageCounter%5000 == 0 { f.lightHouse.Query(fwPacket.RemoteIP, f) @@ -52,6 +52,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, } func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo { + if f.hostMap.vpnCIDR.Contains(int2ip(vpnIp)) == false { + vpnIp = f.hostMap.queryUnsafeRoute(vpnIp) + } hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f) //if err != nil || hostinfo.ConnectionState == nil { @@ -97,7 +100,7 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType, } // check if packet is in outbound fw rules - if f.firewall.Drop(p, *fp, false, hostInfo.ConnectionState.peerCert, trustedCAs) { + if f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs) { l.WithField("fwPacket", fp).Debugln("dropping cached packet") return } diff --git a/main.go b/main.go index 39c10ab..7169d66 100644 --- a/main.go +++ b/main.go @@ -79,6 +79,7 @@ func Main(configPath string, configTest bool, buildVersion string) { // TODO: make sure mask is 4 bytes tunCidr := cs.certificate.Details.Ips[0] routes, err := parseRoutes(config, tunCidr) + unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr) if err != nil { l.WithError(err).Fatal("Could not parse tun.routes") } @@ -109,6 +110,7 @@ func Main(configPath string, configTest bool, buildVersion string) { tunCidr, config.GetInt("tun.mtu", 1300), routes, + unsafeRoutes, config.GetInt("tun.tx_queue", 500), ) if err != nil { @@ -163,6 +165,8 @@ func Main(configPath string, configTest bool, buildVersion string) { hostMap := NewHostMap("main", tunCidr, preferredRanges) hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0")))) + hostMap.addUnsafeRoutes(&unsafeRoutes) + l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created") /* diff --git a/outside.go b/outside.go index 968b8d2..7f9544c 100644 --- a/outside.go +++ b/outside.go @@ -255,13 +255,6 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) { var err error - // TODO: This breaks subnet routing and needs to also check range of ip subnet - /* - if len(res) > 16 && binary.BigEndian.Uint32(res[12:16]) != ip2int(ci.peerCert.Details.Ips[0].IP) { - l.Debugf("Host %s tried to spoof packet as %s.", ci.peerCert.Details.Ips[0].IP, IntIp(binary.BigEndian.Uint32(res[12:16]))) - } - */ - out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb) if err != nil { l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).Error("Failed to decrypt packet") @@ -283,7 +276,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return } - if f.firewall.Drop(out, *fwPacket, true, hostinfo.ConnectionState.peerCert, trustedCAs) { + if f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs) { l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket). Debugln("dropping inbound packet") return diff --git a/tun_common.go b/tun_common.go index 0731968..57855fa 100644 --- a/tun_common.go +++ b/tun_common.go @@ -9,6 +9,7 @@ import ( type route struct { mtu int route *net.IPNet + via *net.IP } func parseRoutes(config *Config, network *net.IPNet) ([]route, error) { @@ -81,6 +82,74 @@ func parseRoutes(config *Config, network *net.IPNet) ([]route, error) { return routes, nil } +func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) { + var err error + + r := config.Get("tun.unsafe_routes") + if r == nil { + return []route{}, nil + } + + rawRoutes, ok := r.([]interface{}) + if !ok { + return nil, fmt.Errorf("tun.unsafe_routes is not an array") + } + + if len(rawRoutes) < 1 { + return []route{}, nil + } + + routes := make([]route, len(rawRoutes)) + for i, r := range rawRoutes { + m, ok := r.(map[interface{}]interface{}) + if !ok { + return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1) + } + + rVia, ok := m["via"] + if !ok { + return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1) + } + + via, ok := rVia.(string) + if !ok { + return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: %v", i+1, err) + } + + nVia := net.ParseIP(via) + if nVia == nil { + return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, via) + } + + rRoute, ok := m["route"] + if !ok { + return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1) + } + + r := route{ + via: &nVia, + } + + _, r.route, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) + if err != nil { + return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err) + } + + if ipWithin(network, r.route) { + return nil, fmt.Errorf( + "entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v", + i+1, + r.route.String(), + network.String(), + ) + } + + routes[i] = r + } + + return routes, nil +} + func ipWithin(o *net.IPNet, i *net.IPNet) bool { // Make sure o contains the lowest form of i if !o.Contains(i.IP.Mask(i.Mask)) { diff --git a/tun_darwin.go b/tun_darwin.go index 43fc4fd..907add1 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -17,10 +17,13 @@ type Tun struct { *water.Interface } -func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, txQueueLen int) (ifce *Tun, err error) { +func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { if len(routes) > 0 { return nil, fmt.Errorf("Route MTU not supported in Darwin") } + if len(unsafeRoutes) > 0 { + return nil, fmt.Errorf("unsafeRoutes not supported in Darwin") + } // NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate() return &Tun{ Cidr: cidr, diff --git a/tun_linux.go b/tun_linux.go index f62d6a7..a527abf 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -14,13 +14,14 @@ import ( type Tun struct { io.ReadWriteCloser - fd int - Device string - Cidr *net.IPNet - MaxMTU int - DefaultMTU int - TXQueueLen int - Routes []route + fd int + Device string + Cidr *net.IPNet + MaxMTU int + DefaultMTU int + TXQueueLen int + Routes []route + UnsafeRoutes []route } type ifReq struct { @@ -74,7 +75,7 @@ type ifreqQLEN struct { pad [8]byte } -func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, txQueueLen int) (ifce *Tun, err error) { +func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { return nil, err @@ -106,6 +107,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, DefaultMTU: defaultMTU, TXQueueLen: txQueueLen, Routes: routes, + UnsafeRoutes: unsafeRoutes, } return } @@ -238,6 +240,20 @@ func (c Tun) Activate() error { } } + // Unsafe path routes + for _, r := range c.UnsafeRoutes { + nr := netlink.Route{ + LinkIndex: link.Attrs().Index, + Dst: r.route, + Scope: unix.RT_SCOPE_LINK, + } + + err = netlink.RouteAdd(&nr) + if err != nil { + return fmt.Errorf("failed to set mtu %v on route %v; %v", r.mtu, r.route, err) + } + } + // Run the interface ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { diff --git a/tun_windows.go b/tun_windows.go index 6c740a0..301a3f2 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -16,10 +16,13 @@ type Tun struct { *water.Interface } -func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, txQueueLen int) (ifce *Tun, err error) { +func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { if len(routes) > 0 { return nil, fmt.Errorf("Route MTU not supported in Windows") } + if len(unsafeRoutes) > 0 { + return nil, fmt.Errorf("unsafeRoutes not supported in Windows") + } // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() return &Tun{ Cidr: cidr,