diff --git a/cluster.go b/cluster/cluster.go similarity index 72% rename from cluster.go rename to cluster/cluster.go index e3dd4df..fe73b39 100644 --- a/cluster.go +++ b/cluster/cluster.go @@ -1,4 +1,4 @@ -package main +package cluster import ( "crypto/rand" @@ -10,30 +10,33 @@ import ( "path" "time" + "github.com/costela/wesher/common" "github.com/hashicorp/memberlist" "github.com/mattn/go-isatty" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) -// ClusterState keeps track of information needed to rejoin the cluster -type ClusterState struct { +const KeyLen = 32 + +// State keeps track of information needed to rejoin the cluster +type State struct { ClusterKey []byte - Nodes []node + Nodes []common.Node } -type cluster struct { - localName string // used to avoid LocalNode(); should not change +type Cluster struct { + LocalName string // used to avoid LocalNode(); should not change ml *memberlist.Memberlist getMeta func(int) []byte - state *ClusterState + state *State events chan memberlist.NodeEvent } const statePath = "/var/lib/wesher/state.json" -func newCluster(init bool, clusterKey []byte, bindAddr string, bindPort int, useIPAsName bool, getMeta func(int) []byte) (*cluster, error) { - state := &ClusterState{} +func New(init bool, clusterKey []byte, bindAddr string, bindPort int, useIPAsName bool, getMeta func(int) []byte) (*Cluster, error) { + state := &State{} if !init { loadState(state) } @@ -58,8 +61,8 @@ func newCluster(init bool, clusterKey []byte, bindAddr string, bindPort int, use return nil, err } - cluster := cluster{ - localName: ml.LocalNode().Name, + cluster := Cluster{ + LocalName: ml.LocalNode().Name, ml: ml, getMeta: getMeta, // The big channel buffer is a work-around for https://github.com/hashicorp/memberlist/issues/23 @@ -74,21 +77,21 @@ func newCluster(init bool, clusterKey []byte, bindAddr string, bindPort int, use return &cluster, nil } -func (c *cluster) NotifyConflict(node, other *memberlist.Node) { +func (c *Cluster) NotifyConflict(node, other *memberlist.Node) { logrus.Errorf("node name conflict detected: %s", other.Name) } -func (c *cluster) NodeMeta(limit int) []byte { +func (c *Cluster) NodeMeta(limit int) []byte { return c.getMeta(limit) } // none of these are used -func (c *cluster) NotifyMsg([]byte) {} -func (c *cluster) GetBroadcasts(overhead, limit int) [][]byte { return nil } -func (c *cluster) LocalState(join bool) []byte { return nil } -func (c *cluster) MergeRemoteState(buf []byte, join bool) {} +func (c *Cluster) NotifyMsg([]byte) {} +func (c *Cluster) GetBroadcasts(overhead, limit int) [][]byte { return nil } +func (c *Cluster) LocalState(join bool) []byte { return nil } +func (c *Cluster) MergeRemoteState(buf []byte, join bool) {} -func (c *cluster) join(addrs []string) error { +func (c *Cluster) Join(addrs []string) error { if len(addrs) == 0 { for _, n := range c.state.Nodes { addrs = append(addrs, n.Addr.String()) @@ -103,22 +106,22 @@ func (c *cluster) join(addrs []string) error { return nil } -func (c *cluster) leave() { +func (c *Cluster) Leave() { c.saveState() c.ml.Leave(10 * time.Second) c.ml.Shutdown() //nolint: errcheck } -func (c *cluster) update() { +func (c *Cluster) Update() { c.ml.UpdateNode(1 * time.Second) // we currently do not update after creation } -func (c *cluster) members() <-chan []node { - changes := make(chan []node) +func (c *Cluster) Members() <-chan []common.Node { + changes := make(chan []common.Node) go func() { for { event := <-c.events - if event.Node.Name == c.localName { + if event.Node.Name == c.LocalName { // ignore events about ourselves continue } @@ -131,12 +134,12 @@ func (c *cluster) members() <-chan []node { logrus.Infof("node %s left", event.Node) } - nodes := make([]node, 0) + nodes := make([]common.Node, 0) for _, n := range c.ml.Members() { - if n.Name == c.localName { + if n.Name == c.LocalName { continue } - nodes = append(nodes, node{ + nodes = append(nodes, common.Node{ Name: n.Name, Addr: n.Addr, Meta: n.Meta, @@ -150,12 +153,12 @@ func (c *cluster) members() <-chan []node { return changes } -func computeClusterKey(state *ClusterState, clusterKey []byte) ([]byte, error) { +func computeClusterKey(state *State, clusterKey []byte) ([]byte, error) { if len(clusterKey) == 0 { clusterKey = state.ClusterKey } if len(clusterKey) == 0 { - clusterKey = make([]byte, clusterKeyLen) + clusterKey = make([]byte, KeyLen) _, err := rand.Read(clusterKey) if err != nil { return nil, err @@ -169,7 +172,7 @@ func computeClusterKey(state *ClusterState, clusterKey []byte) ([]byte, error) { return clusterKey, nil } -func (c *cluster) saveState() error { +func (c *Cluster) saveState() error { if err := os.MkdirAll(path.Dir(statePath), 0700); err != nil { return err } @@ -182,7 +185,7 @@ func (c *cluster) saveState() error { return ioutil.WriteFile(statePath, stateOut, 0600) } -func loadState(cs *ClusterState) { +func loadState(cs *State) { content, err := ioutil.ReadFile(statePath) if err != nil { if !os.IsNotExist(err) { @@ -192,7 +195,7 @@ func loadState(cs *ClusterState) { } // avoid partially unmarshalled content by using a temp var - csTmp := &ClusterState{} + csTmp := &State{} if err := json.Unmarshal(content, csTmp); err != nil { logrus.Warnf("could not decode state: %s", err) } else { diff --git a/node.go b/common/node.go similarity index 75% rename from node.go rename to common/node.go index 49da3fc..812dc18 100644 --- a/node.go +++ b/common/node.go @@ -1,4 +1,4 @@ -package main +package common import ( "bytes" @@ -9,25 +9,25 @@ import ( "github.com/sirupsen/logrus" ) -// nodeMeta holds metadata sent over the cluster -type nodeMeta struct { +// NodeMeta holds metadata sent over the cluster +type NodeMeta struct { OverlayAddr net.IPNet PubKey string } // Node holds the memberlist node structure -type node struct { +type Node struct { Name string Addr net.IP Meta []byte - nodeMeta + NodeMeta } -func (n *node) String() string { +func (n *Node) String() string { return n.Addr.String() } -func encodeNodeMeta(nm nodeMeta, limit int) []byte { +func EncodeNodeMeta(nm NodeMeta, limit int) []byte { buf := &bytes.Buffer{} if err := gob.NewEncoder(buf).Encode(nm); err != nil { logrus.Errorf("could not encode local state: %s", err) @@ -40,10 +40,10 @@ func encodeNodeMeta(nm nodeMeta, limit int) []byte { return buf.Bytes() } -func decodeNodeMeta(b []byte) (nodeMeta, error) { +func DecodeNodeMeta(b []byte) (NodeMeta, error) { // TODO: we blindly trust the info we get from the peers; We should be more defensive to limit the damage a leaked // PSK can cause. - nm := nodeMeta{} + nm := NodeMeta{} if err := gob.NewDecoder(bytes.NewReader(b)).Decode(&nm); err != nil { return nm, errors.Wrap(err, "could not decode node meta") } diff --git a/config.go b/config.go index 4a4b1d7..037b1f1 100644 --- a/config.go +++ b/config.go @@ -4,12 +4,11 @@ import ( "fmt" "net" + "github.com/costela/wesher/cluster" "github.com/hashicorp/go-sockaddr" "github.com/stevenroose/gonfig" ) -const clusterKeyLen = 32 - type config struct { ClusterKey []byte `id:"cluster-key" desc:"shared key for cluster membership; must be 32 bytes base64 encoded; will be generated if not provided"` Join []string `desc:"comma separated list of hostnames or IP addresses to existing cluster members; if not provided, will attempt resuming any known state or otherwise wait for further members."` @@ -36,8 +35,8 @@ func loadConfig() (*config, error) { } // perform some validation - if len(config.ClusterKey) != 0 && len(config.ClusterKey) != clusterKeyLen { - return nil, fmt.Errorf("unsupported cluster key length; expected %d, got %d", clusterKeyLen, len(config.ClusterKey)) + if len(config.ClusterKey) != 0 && len(config.ClusterKey) != cluster.KeyLen { + return nil, fmt.Errorf("unsupported cluster key length; expected %d, got %d", cluster.KeyLen, len(config.ClusterKey)) } if bits, _ := ((*net.IPNet)(config.OverlayNet)).Mask.Size(); bits%8 != 0 { diff --git a/main.go b/main.go index 53be385..c540e21 100644 --- a/main.go +++ b/main.go @@ -9,7 +9,10 @@ import ( "time" "github.com/cenkalti/backoff" + "github.com/costela/wesher/cluster" + "github.com/costela/wesher/common" "github.com/costela/wesher/etchosts" + "github.com/costela/wesher/wg" "github.com/sirupsen/logrus" ) @@ -30,28 +33,28 @@ func main() { } logrus.SetLevel(logLevel) - wg, err := newWGConfig(config.Interface, config.WireguardPort) + wg, err := wg.NewWGConfig(config.Interface, config.WireguardPort) if err != nil { logrus.WithError(err).Fatal("could not instantiate wireguard controller") } getMeta := func(limit int) []byte { - return encodeNodeMeta(nodeMeta{ + return common.EncodeNodeMeta(common.NodeMeta{ OverlayAddr: wg.OverlayAddr, PubKey: wg.PubKey.String(), }, limit) } - cluster, err := newCluster(config.Init, config.ClusterKey, config.BindAddr, config.ClusterPort, config.UseIPAsName, getMeta) + cluster, err := cluster.New(config.Init, config.ClusterKey, config.BindAddr, config.ClusterPort, config.UseIPAsName, getMeta) if err != nil { logrus.WithError(err).Fatal("could not create cluster") } - wg.assignOverlayAddr((*net.IPNet)(config.OverlayNet), cluster.localName) - cluster.update() + wg.AssignOverlayAddr((*net.IPNet)(config.OverlayNet), cluster.LocalName) + cluster.Update() - nodec := cluster.members() // avoid deadlocks by starting before join + nodec := cluster.Members() // avoid deadlocks by starting before join if err := backoff.RetryNotify( - func() error { return cluster.join(config.Join) }, + func() error { return cluster.Join(config.Join) }, backoff.NewExponentialBackOff(), func(err error, dur time.Duration) { logrus.WithError(err).Errorf("could not join cluster, retrying in %s", dur) @@ -67,20 +70,20 @@ func main() { select { case rawNodes := <-nodec: logrus.Info("cluster members:\n") - nodes := make([]node, 0, len(rawNodes)) + nodes := make([]common.Node, 0, len(rawNodes)) for _, node := range rawNodes { - meta, err := decodeNodeMeta(node.Meta) + meta, err := common.DecodeNodeMeta(node.Meta) if err != nil { logrus.Warnf("\t addr: %s, could not decode metadata", node.Addr) continue } - node.nodeMeta = meta + node.NodeMeta = meta nodes = append(nodes, node) logrus.Infof("\taddr: %s, overlay: %s, pubkey: %s", node.Addr, node.OverlayAddr, node.PubKey) } - if err := wg.setUpInterface(nodes); err != nil { + if err := wg.SetUpInterface(nodes); err != nil { logrus.WithError(err).Error("could not up interface") - wg.downInterface() + wg.DownInterface() } if !config.NoEtcHosts { if err := writeToEtcHosts(nodes); err != nil { @@ -89,13 +92,13 @@ func main() { } case <-incomingSigs: logrus.Info("terminating...") - cluster.leave() + cluster.Leave() if !config.NoEtcHosts { if err := writeToEtcHosts(nil); err != nil { logrus.WithError(err).Error("could not remove stale hosts entries") } } - if err := wg.downInterface(); err != nil { + if err := wg.DownInterface(); err != nil { logrus.WithError(err).Error("could not down interface") } os.Exit(0) @@ -103,7 +106,7 @@ func main() { } } -func writeToEtcHosts(nodes []node) error { +func writeToEtcHosts(nodes []common.Node) error { hosts := make(map[string][]string, len(nodes)) for _, n := range nodes { hosts[n.OverlayAddr.IP.String()] = []string{n.Name} diff --git a/netlink.go b/wg/netlink.go similarity index 96% rename from netlink.go rename to wg/netlink.go index 0bd1c6a..391d838 100644 --- a/netlink.go +++ b/wg/netlink.go @@ -1,4 +1,4 @@ -package main +package wg import "github.com/vishvananda/netlink" diff --git a/wireguard.go b/wg/wireguard.go similarity index 87% rename from wireguard.go rename to wg/wireguard.go index d14e334..42357d1 100644 --- a/wireguard.go +++ b/wg/wireguard.go @@ -1,17 +1,18 @@ -package main +package wg import ( "hash/fnv" "net" "os" + "github.com/costela/wesher/common" "github.com/pkg/errors" "github.com/vishvananda/netlink" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -type wgState struct { +type WgState struct { iface string client *wgctrl.Client OverlayAddr net.IPNet @@ -20,7 +21,7 @@ type wgState struct { PubKey wgtypes.Key } -func newWGConfig(iface string, port int) (*wgState, error) { +func NewWGConfig(iface string, port int) (*WgState, error) { client, err := wgctrl.New() if err != nil { return nil, errors.Wrap(err, "could not instantiate wireguard client") @@ -32,7 +33,7 @@ func newWGConfig(iface string, port int) (*wgState, error) { } pubKey := privKey.PublicKey() - wgState := wgState{ + wgState := WgState{ iface: iface, client: client, Port: port, @@ -42,7 +43,7 @@ func newWGConfig(iface string, port int) (*wgState, error) { return &wgState, nil } -func (wg *wgState) assignOverlayAddr(ipnet *net.IPNet, name string) { +func (wg *WgState) AssignOverlayAddr(ipnet *net.IPNet, name string) { // TODO: this is way too brittle and opaque bits, size := ipnet.Mask.Size() ip := make([]byte, len(ipnet.IP)) @@ -62,7 +63,7 @@ func (wg *wgState) assignOverlayAddr(ipnet *net.IPNet, name string) { } } -func (wg *wgState) downInterface() error { +func (wg *WgState) DownInterface() error { if _, err := wg.client.Device(wg.iface); err != nil { if os.IsNotExist(err) { return nil // device already gone; noop @@ -76,12 +77,12 @@ func (wg *wgState) downInterface() error { return netlink.LinkDel(link) } -func (wg *wgState) setUpInterface(nodes []node) error { +func (wg *WgState) SetUpInterface(nodes []common.Node) error { if err := netlink.LinkAdd(&wireguard{LinkAttrs: netlink.LinkAttrs{Name: wg.iface}}); err != nil && !os.IsExist(err) { return errors.Wrapf(err, "could not create interface %s", wg.iface) } - peerCfgs, err := wg.nodesToPeerConfigs(nodes) + peerCfgs, err := wg.NodesToPeerConfigs(nodes) if err != nil { return errors.Wrap(err, "error converting received node information to wireguard format") } @@ -121,7 +122,7 @@ func (wg *wgState) setUpInterface(nodes []node) error { return nil } -func (wg *wgState) nodesToPeerConfigs(nodes []node) ([]wgtypes.PeerConfig, error) { +func (wg *WgState) NodesToPeerConfigs(nodes []common.Node) ([]wgtypes.PeerConfig, error) { peerCfgs := make([]wgtypes.PeerConfig, len(nodes)) for i, node := range nodes { pubKey, err := wgtypes.ParseKey(node.PubKey) diff --git a/wireguard_test.go b/wg/wireguard_test.go similarity index 85% rename from wireguard_test.go rename to wg/wireguard_test.go index 6e9295c..3d672ae 100644 --- a/wireguard_test.go +++ b/wg/wireguard_test.go @@ -1,4 +1,4 @@ -package main +package wg import ( "net" @@ -31,8 +31,8 @@ func Test_wgState_assignOverlayAddr(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - wg := &wgState{} - wg.assignOverlayAddr(tt.args.ipnet, tt.args.name) + wg := &WgState{} + wg.AssignOverlayAddr(tt.args.ipnet, tt.args.name) if !reflect.DeepEqual(wg.OverlayAddr.IP.String(), tt.want) { t.Errorf("assignOverlayAddr() set = %s, want %s", wg.OverlayAddr, tt.want) @@ -47,8 +47,8 @@ func Test_wgState_assignOverlayAddr_no_obvious_collisions(t *testing.T) { _, ipnet, _ := net.ParseCIDR("10.0.0.0/24") assignments := make(map[string]string) for _, n := range []string{"test", "test1", "test2", "1test", "2test"} { - wg := &wgState{} - wg.assignOverlayAddr(ipnet, n) + wg := &WgState{} + wg.AssignOverlayAddr(ipnet, n) if assigned, ok := assignments[wg.OverlayAddr.String()]; ok { t.Errorf("IP assignment collision: hash(%s) = hash(%s)", n, assigned) } @@ -59,10 +59,10 @@ func Test_wgState_assignOverlayAddr_no_obvious_collisions(t *testing.T) { // This should ensure the obvious fact that the same name should map to the same IP if called twice. func Test_wgState_assignOverlayAddr_consistent(t *testing.T) { _, ipnet, _ := net.ParseCIDR("10.0.0.0/8") - wg1 := &wgState{} - wg1.assignOverlayAddr(ipnet, "test") - wg2 := &wgState{} - wg2.assignOverlayAddr(ipnet, "test") + wg1 := &WgState{} + wg1.AssignOverlayAddr(ipnet, "test") + wg2 := &WgState{} + wg2.AssignOverlayAddr(ipnet, "test") if wg1.OverlayAddr.String() != wg2.OverlayAddr.String() { t.Errorf("assignOverlayAddr() %s != %s", wg1.OverlayAddr, wg2.OverlayAddr) } @@ -70,10 +70,10 @@ func Test_wgState_assignOverlayAddr_consistent(t *testing.T) { func Test_wgState_assignOverlayAddr_repeatable(t *testing.T) { _, ipnet, _ := net.ParseCIDR("10.0.0.0/8") - wg := &wgState{} - wg.assignOverlayAddr(ipnet, "test") + wg := &WgState{} + wg.AssignOverlayAddr(ipnet, "test") gen1 := wg.OverlayAddr.String() - wg.assignOverlayAddr(ipnet, "test") + wg.AssignOverlayAddr(ipnet, "test") gen2 := wg.OverlayAddr.String() if gen1 != gen2 { t.Errorf("assignOverlayAddr() %s != %s", gen1, gen2)