diff --git a/overlay/route.go b/overlay/route.go index 5f0033c..b40067f 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -9,13 +9,14 @@ import ( "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/iputil" ) type Route struct { MTU int Metric int Cidr *net.IPNet - Via *net.IP + Via *iputil.VpnIp } func makeRouteTree(routes []Route, allowMTU bool) (*cidr.Tree4, error) { @@ -26,7 +27,7 @@ func makeRouteTree(routes []Route, allowMTU bool) (*cidr.Tree4, error) { } if r.Via != nil { - routeTree.AddCIDR(r.Cidr, r.Via) + routeTree.AddCIDR(r.Cidr, *r.Via) } } return routeTree, nil @@ -180,8 +181,10 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1) } + viaVpnIp := iputil.Ip2VpnIp(nVia) + r := Route{ - Via: &nVia, + Via: &viaVpnIp, MTU: mtu, Metric: metric, } diff --git a/overlay/route_test.go b/overlay/route_test.go index 2128ddb..04a8da0 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) @@ -235,3 +236,35 @@ func Test_parseUnsafeRoutes(t *testing.T) { t.Fatal("Did not see both unsafe_routes") } } + +func Test_makeRouteTree(t *testing.T) { + l := test.NewLogger() + c := config.NewC(l) + _, n, _ := net.ParseCIDR("10.0.0.0/24") + + c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ + map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, + map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"}, + }} + routes, err := parseUnsafeRoutes(c, n) + assert.NoError(t, err) + assert.Len(t, routes, 2) + routeTree, err := makeRouteTree(routes, true) + assert.NoError(t, err) + + ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2")) + r := routeTree.MostSpecificContains(ip) + assert.NotNil(t, r) + assert.IsType(t, iputil.VpnIp(0), r) + assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r) + + ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1")) + r = routeTree.MostSpecificContains(ip) + assert.NotNil(t, r) + assert.IsType(t, iputil.VpnIp(0), r) + assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r) + + ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1")) + r = routeTree.MostSpecificContains(ip) + assert.Nil(t, r) +} diff --git a/overlay/tun_wintun_windows.go b/overlay/tun_wintun_windows.go index 745f554..a86586d 100644 --- a/overlay/tun_wintun_windows.go +++ b/overlay/tun_wintun_windows.go @@ -97,7 +97,7 @@ func (t *winTun) Activate() error { // Add our unsafe route routes = append(routes, &winipcfg.RouteData{ Destination: *r.Cidr, - NextHop: *r.Via, + NextHop: r.Via.ToIP(), Metric: uint32(r.Metric), }) }