diff --git a/.gitignore b/.gitignore index 059c35d..515eb78 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,6 @@ *.so *.dylib wesher -wg # Test binary, build with `go test -c` *.test diff --git a/tests/wg b/tests/wg new file mode 100755 index 0000000..bb95c6f --- /dev/null +++ b/tests/wg @@ -0,0 +1,19 @@ +#!/bin/sh + +# docker run --rm -it --hostname test3 --name test3 --network wesher_test --volume `pwd`:/app golang:1.12-alpine +# cd /app +# apk add git +# export PATH=/app:$PATH +# go run . --loglevel debug --clusterkey 'ILICZ3yBMCGAWNIq5Pn0bewBVimW3Q2yRVJ/Be+b1Uc=' + +case $1 in + genkey) + echo "ILICZ3yBMCGAWNIq5Pn0bewBVimW3Q2yRVJ/Be+b1Uc=" + ;; + pubkey) + read x + echo "VceweY6x/QdGXEQ6frXrSd8CwUAInUmqIc6G/qi8FHo=" + ;; + *) + ;; +esac \ No newline at end of file diff --git a/wireguard.go b/wireguard.go index 3ef2b5c..bd4b7b7 100644 --- a/wireguard.go +++ b/wireguard.go @@ -1,8 +1,8 @@ package main import ( - "crypto/md5" "fmt" + "hash/fnv" "net" "os" "os/exec" @@ -33,8 +33,11 @@ type wgState struct { PubKey string } +var wgPath = "wg" +var wgQuickPath = "wg-quick" + func newWGConfig(iface string, port int) (*wgState, error) { - if err := exec.Command("wg").Run(); err != nil { + if err := exec.Command(wgPath).Run(); err != nil { return nil, fmt.Errorf("could not exec wireguard: %s", err) } @@ -54,16 +57,18 @@ func newWGConfig(iface string, port int) (*wgState, error) { func (wg *wgState) assignOverlayAddr(ipnet *net.IPNet, name string) { // TODO: this is way too brittle and opaque - ip := []byte(ipnet.IP) bits, size := ipnet.Mask.Size() + ip := make([]byte, net.IPv6len) + copy(ip, []byte(ipnet.IP)) - h := md5.New() + h := fnv.New128a() h.Write([]byte(name)) hb := h.Sum(nil) - for i := 0; i < (size-bits)/8; i++ { - ip[size/8-i-1] = hb[i] + for i := 1; i <= (size-bits)/8; i++ { + ip[len(ip)-i] = hb[len(hb)-i] } + wg.OverlayAddr = net.IP(ip) } @@ -84,25 +89,25 @@ func (wg *wgState) writeConf(nodes []node) error { } func (wg *wgState) downInterface() error { - if err := exec.Command("wg", "show", wg.iface).Run(); err != nil { + if err := exec.Command(wgPath, "show", wg.iface).Run(); err != nil { return nil // assume a failure means the interface is not there } - return exec.Command("wg-quick", "down", wg.iface).Run() + return exec.Command(wgQuickPath, "down", wg.iface).Run() } func (wg *wgState) upInterface() error { - return exec.Command("wg-quick", "up", wg.iface).Run() + return exec.Command(wgQuickPath, "up", wg.iface).Run() } func wgKeyPair() (string, string, error) { - cmd := exec.Command("wg", "genkey") + cmd := exec.Command(wgPath, "genkey") outPriv := strings.Builder{} cmd.Stdout = &outPriv if err := cmd.Run(); err != nil { return "", "", err } - cmd = exec.Command("wg", "pubkey") + cmd = exec.Command(wgPath, "pubkey") outPub := strings.Builder{} cmd.Stdout = &outPub cmd.Stdin = strings.NewReader(outPriv.String()) diff --git a/wireguard_test.go b/wireguard_test.go new file mode 100644 index 0000000..7f71d1b --- /dev/null +++ b/wireguard_test.go @@ -0,0 +1,87 @@ +package main + +import ( + "net" + "reflect" + "testing" +) + +func init() { + wgPath = "tests/wg" + wgQuickPath = "tests/wg-quick" +} + +func Test_wgKeyPair(t *testing.T) { + tests := []struct { + name string + want string + want1 string + wantErr bool + }{ + // see tests/wg for values + {"generate fixed values", "ILICZ3yBMCGAWNIq5Pn0bewBVimW3Q2yRVJ/Be+b1Uc=", "VceweY6x/QdGXEQ6frXrSd8CwUAInUmqIc6G/qi8FHo=", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1, err := wgKeyPair() + if (err != nil) != tt.wantErr { + t.Errorf("wgKeyPair() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("wgKeyPair() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("wgKeyPair() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func Test_wgState_assignOverlayAddr(t *testing.T) { + type args struct { + ipnet *net.IPNet + name string + } + tests := []struct { + name string + args args + want net.IP + }{ + { + "assign in big ipv4 net", + args{&net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(8, 32)}, "test"}, + net.ParseIP("10.221.153.165"), // if we ever have to change this, we should probably also mark it as a breaking change + }, + { + "assign in ipv6 net", + args{&net.IPNet{IP: net.ParseIP("2001:db8::"), Mask: net.CIDRMask(32, 128)}, "test"}, + net.ParseIP("2001:db8:c575:7277:b806:e994:13dd:99a5"), // if we ever have to change this, we should probably also mark it as a breaking change + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + wg := &wgState{} + wg.assignOverlayAddr(tt.args.ipnet, tt.args.name) + + if !reflect.DeepEqual(wg.OverlayAddr, tt.want) { + t.Errorf("assignOverlayAddr() set = %s, want %s", wg.OverlayAddr, tt.want) + } + }) + } +} + +// This is just to ensure - if we ever change the hashing function - that it spreads the results in a way that at least +// avoids the most obvious collisions. +func Test_wgState_assignOverlayAddr_no_obvious_collisions(t *testing.T) { + ipnet := &net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(24, 32)} + assignments := make(map[string]string) + for _, n := range []string{"test", "test1", "test2", "1test", "2test"} { + wg := &wgState{} + wg.assignOverlayAddr(ipnet, n) + if assigned, ok := assignments[wg.OverlayAddr.String()]; ok { + t.Errorf("IP assignment collision: hash(%s) = hash(%s)", n, assigned) + } + assignments[wg.OverlayAddr.String()] = n + } +}