fix makeRouteTree allowMTU (#611)
With the previous implementation, we check if route.MTU is greater than zero, but it will always be because we set it to the default MTU in parseUnsafeRoutes. This change leaves it as zero in parseUnsafeRoutes so it can be examined later.
This commit is contained in:
parent
15fdabc3ab
commit
068a93d1f4
|
@ -7,6 +7,7 @@ import (
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/cidr"
|
"github.com/slackhq/nebula/cidr"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
@ -19,11 +20,11 @@ type Route struct {
|
||||||
Via *iputil.VpnIp
|
Via *iputil.VpnIp
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeRouteTree(routes []Route, allowMTU bool) (*cidr.Tree4, error) {
|
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4, error) {
|
||||||
routeTree := cidr.NewTree4()
|
routeTree := cidr.NewTree4()
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !allowMTU && r.MTU > 0 {
|
if !allowMTU && r.MTU > 0 {
|
||||||
return nil, fmt.Errorf("route MTU is not supported in %s", runtime.GOOS)
|
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Via != nil {
|
if r.Via != nil {
|
||||||
|
@ -127,21 +128,19 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
|
||||||
return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1)
|
return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
rMtu, ok := m["mtu"]
|
var mtu int
|
||||||
if !ok {
|
if rMtu, ok := m["mtu"]; ok {
|
||||||
rMtu = c.GetInt("tun.mtu", DefaultMTU)
|
mtu, ok = rMtu.(int)
|
||||||
}
|
if !ok {
|
||||||
|
mtu, err = strconv.Atoi(rMtu.(string))
|
||||||
mtu, ok := rMtu.(int)
|
if err != nil {
|
||||||
if !ok {
|
return nil, fmt.Errorf("entry %v.mtu in tun.unsafe_routes is not an integer: %v", i+1, err)
|
||||||
mtu, err = strconv.Atoi(rMtu.(string))
|
}
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("entry %v.mtu in tun.unsafe_routes is not an integer: %v", i+1, err)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if mtu < 500 {
|
if mtu != 0 && mtu < 500 {
|
||||||
return nil, fmt.Errorf("entry %v.mtu in tun.unsafe_routes is below 500: %v", i+1, mtu)
|
return nil, fmt.Errorf("entry %v.mtu in tun.unsafe_routes is below 500: %v", i+1, mtu)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rMetric, ok := m["metric"]
|
rMetric, ok := m["metric"]
|
||||||
|
|
|
@ -191,7 +191,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, n)
|
routes, err = parseUnsafeRoutes(c, n)
|
||||||
assert.Len(t, routes, 1)
|
assert.Len(t, routes, 1)
|
||||||
assert.Equal(t, DefaultMTU, routes[0].MTU)
|
assert.Equal(t, 0, routes[0].MTU)
|
||||||
|
|
||||||
// bad mtu
|
// bad mtu
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
|
||||||
|
@ -249,7 +249,7 @@ func Test_makeRouteTree(t *testing.T) {
|
||||||
routes, err := parseUnsafeRoutes(c, n)
|
routes, err := parseUnsafeRoutes(c, n)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Len(t, routes, 2)
|
assert.Len(t, routes, 2)
|
||||||
routeTree, err := makeRouteTree(routes, true)
|
routeTree, err := makeRouteTree(l, routes, true)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2"))
|
ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2"))
|
||||||
|
|
|
@ -77,7 +77,7 @@ type ifreqMTU struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
|
func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
|
||||||
routeTree, err := makeRouteTree(routes, false)
|
routeTree, err := makeRouteTree(l, routes, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,7 +43,7 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
|
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
|
||||||
routeTree, err := makeRouteTree(routes, false)
|
routeTree, err := makeRouteTree(l, routes, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -64,7 +64,7 @@ type ifreqQLEN struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int) (*tun, error) {
|
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int) (*tun, error) {
|
||||||
routeTree, err := makeRouteTree(routes, true)
|
routeTree, err := makeRouteTree(l, routes, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -105,12 +105,16 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
|
||||||
|
|
||||||
maxMTU := defaultMTU
|
maxMTU := defaultMTU
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
|
if r.MTU == 0 {
|
||||||
|
r.MTU = defaultMTU
|
||||||
|
}
|
||||||
|
|
||||||
if r.MTU > maxMTU {
|
if r.MTU > maxMTU {
|
||||||
maxMTU = r.MTU
|
maxMTU = r.MTU
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
routeTree, err := makeRouteTree(routes, true)
|
routeTree, err := makeRouteTree(l, routes, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,7 @@ type TestTun struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*TestTun, error) {
|
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*TestTun, error) {
|
||||||
routeTree, err := makeRouteTree(routes, false)
|
routeTree, err := makeRouteTree(l, routes, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/cidr"
|
"github.com/slackhq/nebula/cidr"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/songgao/water"
|
"github.com/songgao/water"
|
||||||
|
@ -22,8 +23,8 @@ type waterTun struct {
|
||||||
*water.Interface
|
*water.Interface
|
||||||
}
|
}
|
||||||
|
|
||||||
func newWaterTun(cidr *net.IPNet, defaultMTU int, routes []Route) (*waterTun, error) {
|
func newWaterTun(l *logrus.Logger, cidr *net.IPNet, defaultMTU int, routes []Route) (*waterTun, error) {
|
||||||
routeTree, err := makeRouteTree(routes, false)
|
routeTree, err := makeRouteTree(l, routes, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,14 +30,14 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
|
||||||
}
|
}
|
||||||
|
|
||||||
if useWintun {
|
if useWintun {
|
||||||
device, err := newWinTun(deviceName, cidr, defaultMTU, routes)
|
device, err := newWinTun(l, deviceName, cidr, defaultMTU, routes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create Wintun interface failed, %w", err)
|
return nil, fmt.Errorf("create Wintun interface failed, %w", err)
|
||||||
}
|
}
|
||||||
return device, nil
|
return device, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
device, err := newWaterTun(cidr, defaultMTU, routes)
|
device, err := newWaterTun(l, cidr, defaultMTU, routes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create wintap driver failed, %w", err)
|
return nil, fmt.Errorf("create wintap driver failed, %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/cidr"
|
"github.com/slackhq/nebula/cidr"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/slackhq/nebula/wintun"
|
"github.com/slackhq/nebula/wintun"
|
||||||
|
@ -45,7 +46,7 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
|
||||||
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
|
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newWinTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route) (*winTun, error) {
|
func newWinTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route) (*winTun, error) {
|
||||||
guid, err := generateGUIDByDeviceName(deviceName)
|
guid, err := generateGUIDByDeviceName(deviceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("generate GUID failed: %w", err)
|
return nil, fmt.Errorf("generate GUID failed: %w", err)
|
||||||
|
@ -56,7 +57,7 @@ func newWinTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []Rout
|
||||||
return nil, fmt.Errorf("create TUN device failed: %w", err)
|
return nil, fmt.Errorf("create TUN device failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
routeTree, err := makeRouteTree(routes, false)
|
routeTree, err := makeRouteTree(l, routes, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue