Split the application into modules
Splitting into modules will help keep concerns separate, at the cost of a slightly more verbose code.
This commit is contained in:
parent
740a9c44c6
commit
dadfbee083
|
@ -1,4 +1,4 @@
|
|||
package main
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
|
@ -10,30 +10,33 @@ import (
|
|||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/costela/wesher/common"
|
||||
"github.com/hashicorp/memberlist"
|
||||
"github.com/mattn/go-isatty"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ClusterState keeps track of information needed to rejoin the cluster
|
||||
type ClusterState struct {
|
||||
const KeyLen = 32
|
||||
|
||||
// State keeps track of information needed to rejoin the cluster
|
||||
type State struct {
|
||||
ClusterKey []byte
|
||||
Nodes []node
|
||||
Nodes []common.Node
|
||||
}
|
||||
|
||||
type cluster struct {
|
||||
localName string // used to avoid LocalNode(); should not change
|
||||
type Cluster struct {
|
||||
LocalName string // used to avoid LocalNode(); should not change
|
||||
ml *memberlist.Memberlist
|
||||
getMeta func(int) []byte
|
||||
state *ClusterState
|
||||
state *State
|
||||
events chan memberlist.NodeEvent
|
||||
}
|
||||
|
||||
const statePath = "/var/lib/wesher/state.json"
|
||||
|
||||
func newCluster(init bool, clusterKey []byte, bindAddr string, bindPort int, useIPAsName bool, getMeta func(int) []byte) (*cluster, error) {
|
||||
state := &ClusterState{}
|
||||
func New(init bool, clusterKey []byte, bindAddr string, bindPort int, useIPAsName bool, getMeta func(int) []byte) (*Cluster, error) {
|
||||
state := &State{}
|
||||
if !init {
|
||||
loadState(state)
|
||||
}
|
||||
|
@ -58,8 +61,8 @@ func newCluster(init bool, clusterKey []byte, bindAddr string, bindPort int, use
|
|||
return nil, err
|
||||
}
|
||||
|
||||
cluster := cluster{
|
||||
localName: ml.LocalNode().Name,
|
||||
cluster := Cluster{
|
||||
LocalName: ml.LocalNode().Name,
|
||||
ml: ml,
|
||||
getMeta: getMeta,
|
||||
// The big channel buffer is a work-around for https://github.com/hashicorp/memberlist/issues/23
|
||||
|
@ -74,21 +77,21 @@ func newCluster(init bool, clusterKey []byte, bindAddr string, bindPort int, use
|
|||
return &cluster, nil
|
||||
}
|
||||
|
||||
func (c *cluster) NotifyConflict(node, other *memberlist.Node) {
|
||||
func (c *Cluster) NotifyConflict(node, other *memberlist.Node) {
|
||||
logrus.Errorf("node name conflict detected: %s", other.Name)
|
||||
}
|
||||
|
||||
func (c *cluster) NodeMeta(limit int) []byte {
|
||||
func (c *Cluster) NodeMeta(limit int) []byte {
|
||||
return c.getMeta(limit)
|
||||
}
|
||||
|
||||
// none of these are used
|
||||
func (c *cluster) NotifyMsg([]byte) {}
|
||||
func (c *cluster) GetBroadcasts(overhead, limit int) [][]byte { return nil }
|
||||
func (c *cluster) LocalState(join bool) []byte { return nil }
|
||||
func (c *cluster) MergeRemoteState(buf []byte, join bool) {}
|
||||
func (c *Cluster) NotifyMsg([]byte) {}
|
||||
func (c *Cluster) GetBroadcasts(overhead, limit int) [][]byte { return nil }
|
||||
func (c *Cluster) LocalState(join bool) []byte { return nil }
|
||||
func (c *Cluster) MergeRemoteState(buf []byte, join bool) {}
|
||||
|
||||
func (c *cluster) join(addrs []string) error {
|
||||
func (c *Cluster) Join(addrs []string) error {
|
||||
if len(addrs) == 0 {
|
||||
for _, n := range c.state.Nodes {
|
||||
addrs = append(addrs, n.Addr.String())
|
||||
|
@ -103,22 +106,22 @@ func (c *cluster) join(addrs []string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *cluster) leave() {
|
||||
func (c *Cluster) Leave() {
|
||||
c.saveState()
|
||||
c.ml.Leave(10 * time.Second)
|
||||
c.ml.Shutdown() //nolint: errcheck
|
||||
}
|
||||
|
||||
func (c *cluster) update() {
|
||||
func (c *Cluster) Update() {
|
||||
c.ml.UpdateNode(1 * time.Second) // we currently do not update after creation
|
||||
}
|
||||
|
||||
func (c *cluster) members() <-chan []node {
|
||||
changes := make(chan []node)
|
||||
func (c *Cluster) Members() <-chan []common.Node {
|
||||
changes := make(chan []common.Node)
|
||||
go func() {
|
||||
for {
|
||||
event := <-c.events
|
||||
if event.Node.Name == c.localName {
|
||||
if event.Node.Name == c.LocalName {
|
||||
// ignore events about ourselves
|
||||
continue
|
||||
}
|
||||
|
@ -131,12 +134,12 @@ func (c *cluster) members() <-chan []node {
|
|||
logrus.Infof("node %s left", event.Node)
|
||||
}
|
||||
|
||||
nodes := make([]node, 0)
|
||||
nodes := make([]common.Node, 0)
|
||||
for _, n := range c.ml.Members() {
|
||||
if n.Name == c.localName {
|
||||
if n.Name == c.LocalName {
|
||||
continue
|
||||
}
|
||||
nodes = append(nodes, node{
|
||||
nodes = append(nodes, common.Node{
|
||||
Name: n.Name,
|
||||
Addr: n.Addr,
|
||||
Meta: n.Meta,
|
||||
|
@ -150,12 +153,12 @@ func (c *cluster) members() <-chan []node {
|
|||
return changes
|
||||
}
|
||||
|
||||
func computeClusterKey(state *ClusterState, clusterKey []byte) ([]byte, error) {
|
||||
func computeClusterKey(state *State, clusterKey []byte) ([]byte, error) {
|
||||
if len(clusterKey) == 0 {
|
||||
clusterKey = state.ClusterKey
|
||||
}
|
||||
if len(clusterKey) == 0 {
|
||||
clusterKey = make([]byte, clusterKeyLen)
|
||||
clusterKey = make([]byte, KeyLen)
|
||||
_, err := rand.Read(clusterKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -169,7 +172,7 @@ func computeClusterKey(state *ClusterState, clusterKey []byte) ([]byte, error) {
|
|||
return clusterKey, nil
|
||||
}
|
||||
|
||||
func (c *cluster) saveState() error {
|
||||
func (c *Cluster) saveState() error {
|
||||
if err := os.MkdirAll(path.Dir(statePath), 0700); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -182,7 +185,7 @@ func (c *cluster) saveState() error {
|
|||
return ioutil.WriteFile(statePath, stateOut, 0600)
|
||||
}
|
||||
|
||||
func loadState(cs *ClusterState) {
|
||||
func loadState(cs *State) {
|
||||
content, err := ioutil.ReadFile(statePath)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
|
@ -192,7 +195,7 @@ func loadState(cs *ClusterState) {
|
|||
}
|
||||
|
||||
// avoid partially unmarshalled content by using a temp var
|
||||
csTmp := &ClusterState{}
|
||||
csTmp := &State{}
|
||||
if err := json.Unmarshal(content, csTmp); err != nil {
|
||||
logrus.Warnf("could not decode state: %s", err)
|
||||
} else {
|
|
@ -1,4 +1,4 @@
|
|||
package main
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
@ -9,25 +9,25 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// nodeMeta holds metadata sent over the cluster
|
||||
type nodeMeta struct {
|
||||
// NodeMeta holds metadata sent over the cluster
|
||||
type NodeMeta struct {
|
||||
OverlayAddr net.IPNet
|
||||
PubKey string
|
||||
}
|
||||
|
||||
// Node holds the memberlist node structure
|
||||
type node struct {
|
||||
type Node struct {
|
||||
Name string
|
||||
Addr net.IP
|
||||
Meta []byte
|
||||
nodeMeta
|
||||
NodeMeta
|
||||
}
|
||||
|
||||
func (n *node) String() string {
|
||||
func (n *Node) String() string {
|
||||
return n.Addr.String()
|
||||
}
|
||||
|
||||
func encodeNodeMeta(nm nodeMeta, limit int) []byte {
|
||||
func EncodeNodeMeta(nm NodeMeta, limit int) []byte {
|
||||
buf := &bytes.Buffer{}
|
||||
if err := gob.NewEncoder(buf).Encode(nm); err != nil {
|
||||
logrus.Errorf("could not encode local state: %s", err)
|
||||
|
@ -40,10 +40,10 @@ func encodeNodeMeta(nm nodeMeta, limit int) []byte {
|
|||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func decodeNodeMeta(b []byte) (nodeMeta, error) {
|
||||
func DecodeNodeMeta(b []byte) (NodeMeta, error) {
|
||||
// TODO: we blindly trust the info we get from the peers; We should be more defensive to limit the damage a leaked
|
||||
// PSK can cause.
|
||||
nm := nodeMeta{}
|
||||
nm := NodeMeta{}
|
||||
if err := gob.NewDecoder(bytes.NewReader(b)).Decode(&nm); err != nil {
|
||||
return nm, errors.Wrap(err, "could not decode node meta")
|
||||
}
|
|
@ -4,12 +4,11 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/costela/wesher/cluster"
|
||||
"github.com/hashicorp/go-sockaddr"
|
||||
"github.com/stevenroose/gonfig"
|
||||
)
|
||||
|
||||
const clusterKeyLen = 32
|
||||
|
||||
type config struct {
|
||||
ClusterKey []byte `id:"cluster-key" desc:"shared key for cluster membership; must be 32 bytes base64 encoded; will be generated if not provided"`
|
||||
Join []string `desc:"comma separated list of hostnames or IP addresses to existing cluster members; if not provided, will attempt resuming any known state or otherwise wait for further members."`
|
||||
|
@ -36,8 +35,8 @@ func loadConfig() (*config, error) {
|
|||
}
|
||||
|
||||
// perform some validation
|
||||
if len(config.ClusterKey) != 0 && len(config.ClusterKey) != clusterKeyLen {
|
||||
return nil, fmt.Errorf("unsupported cluster key length; expected %d, got %d", clusterKeyLen, len(config.ClusterKey))
|
||||
if len(config.ClusterKey) != 0 && len(config.ClusterKey) != cluster.KeyLen {
|
||||
return nil, fmt.Errorf("unsupported cluster key length; expected %d, got %d", cluster.KeyLen, len(config.ClusterKey))
|
||||
}
|
||||
|
||||
if bits, _ := ((*net.IPNet)(config.OverlayNet)).Mask.Size(); bits%8 != 0 {
|
||||
|
|
33
main.go
33
main.go
|
@ -9,7 +9,10 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff"
|
||||
"github.com/costela/wesher/cluster"
|
||||
"github.com/costela/wesher/common"
|
||||
"github.com/costela/wesher/etchosts"
|
||||
"github.com/costela/wesher/wg"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
@ -30,28 +33,28 @@ func main() {
|
|||
}
|
||||
logrus.SetLevel(logLevel)
|
||||
|
||||
wg, err := newWGConfig(config.Interface, config.WireguardPort)
|
||||
wg, err := wg.NewWGConfig(config.Interface, config.WireguardPort)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Fatal("could not instantiate wireguard controller")
|
||||
}
|
||||
|
||||
getMeta := func(limit int) []byte {
|
||||
return encodeNodeMeta(nodeMeta{
|
||||
return common.EncodeNodeMeta(common.NodeMeta{
|
||||
OverlayAddr: wg.OverlayAddr,
|
||||
PubKey: wg.PubKey.String(),
|
||||
}, limit)
|
||||
}
|
||||
|
||||
cluster, err := newCluster(config.Init, config.ClusterKey, config.BindAddr, config.ClusterPort, config.UseIPAsName, getMeta)
|
||||
cluster, err := cluster.New(config.Init, config.ClusterKey, config.BindAddr, config.ClusterPort, config.UseIPAsName, getMeta)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Fatal("could not create cluster")
|
||||
}
|
||||
wg.assignOverlayAddr((*net.IPNet)(config.OverlayNet), cluster.localName)
|
||||
cluster.update()
|
||||
wg.AssignOverlayAddr((*net.IPNet)(config.OverlayNet), cluster.LocalName)
|
||||
cluster.Update()
|
||||
|
||||
nodec := cluster.members() // avoid deadlocks by starting before join
|
||||
nodec := cluster.Members() // avoid deadlocks by starting before join
|
||||
if err := backoff.RetryNotify(
|
||||
func() error { return cluster.join(config.Join) },
|
||||
func() error { return cluster.Join(config.Join) },
|
||||
backoff.NewExponentialBackOff(),
|
||||
func(err error, dur time.Duration) {
|
||||
logrus.WithError(err).Errorf("could not join cluster, retrying in %s", dur)
|
||||
|
@ -67,20 +70,20 @@ func main() {
|
|||
select {
|
||||
case rawNodes := <-nodec:
|
||||
logrus.Info("cluster members:\n")
|
||||
nodes := make([]node, 0, len(rawNodes))
|
||||
nodes := make([]common.Node, 0, len(rawNodes))
|
||||
for _, node := range rawNodes {
|
||||
meta, err := decodeNodeMeta(node.Meta)
|
||||
meta, err := common.DecodeNodeMeta(node.Meta)
|
||||
if err != nil {
|
||||
logrus.Warnf("\t addr: %s, could not decode metadata", node.Addr)
|
||||
continue
|
||||
}
|
||||
node.nodeMeta = meta
|
||||
node.NodeMeta = meta
|
||||
nodes = append(nodes, node)
|
||||
logrus.Infof("\taddr: %s, overlay: %s, pubkey: %s", node.Addr, node.OverlayAddr, node.PubKey)
|
||||
}
|
||||
if err := wg.setUpInterface(nodes); err != nil {
|
||||
if err := wg.SetUpInterface(nodes); err != nil {
|
||||
logrus.WithError(err).Error("could not up interface")
|
||||
wg.downInterface()
|
||||
wg.DownInterface()
|
||||
}
|
||||
if !config.NoEtcHosts {
|
||||
if err := writeToEtcHosts(nodes); err != nil {
|
||||
|
@ -89,13 +92,13 @@ func main() {
|
|||
}
|
||||
case <-incomingSigs:
|
||||
logrus.Info("terminating...")
|
||||
cluster.leave()
|
||||
cluster.Leave()
|
||||
if !config.NoEtcHosts {
|
||||
if err := writeToEtcHosts(nil); err != nil {
|
||||
logrus.WithError(err).Error("could not remove stale hosts entries")
|
||||
}
|
||||
}
|
||||
if err := wg.downInterface(); err != nil {
|
||||
if err := wg.DownInterface(); err != nil {
|
||||
logrus.WithError(err).Error("could not down interface")
|
||||
}
|
||||
os.Exit(0)
|
||||
|
@ -103,7 +106,7 @@ func main() {
|
|||
}
|
||||
}
|
||||
|
||||
func writeToEtcHosts(nodes []node) error {
|
||||
func writeToEtcHosts(nodes []common.Node) error {
|
||||
hosts := make(map[string][]string, len(nodes))
|
||||
for _, n := range nodes {
|
||||
hosts[n.OverlayAddr.IP.String()] = []string{n.Name}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package main
|
||||
package wg
|
||||
|
||||
import "github.com/vishvananda/netlink"
|
||||
|
|
@ -1,17 +1,18 @@
|
|||
package main
|
||||
package wg
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/costela/wesher/common"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type wgState struct {
|
||||
type WgState struct {
|
||||
iface string
|
||||
client *wgctrl.Client
|
||||
OverlayAddr net.IPNet
|
||||
|
@ -20,7 +21,7 @@ type wgState struct {
|
|||
PubKey wgtypes.Key
|
||||
}
|
||||
|
||||
func newWGConfig(iface string, port int) (*wgState, error) {
|
||||
func NewWGConfig(iface string, port int) (*WgState, error) {
|
||||
client, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not instantiate wireguard client")
|
||||
|
@ -32,7 +33,7 @@ func newWGConfig(iface string, port int) (*wgState, error) {
|
|||
}
|
||||
pubKey := privKey.PublicKey()
|
||||
|
||||
wgState := wgState{
|
||||
wgState := WgState{
|
||||
iface: iface,
|
||||
client: client,
|
||||
Port: port,
|
||||
|
@ -42,7 +43,7 @@ func newWGConfig(iface string, port int) (*wgState, error) {
|
|||
return &wgState, nil
|
||||
}
|
||||
|
||||
func (wg *wgState) assignOverlayAddr(ipnet *net.IPNet, name string) {
|
||||
func (wg *WgState) AssignOverlayAddr(ipnet *net.IPNet, name string) {
|
||||
// TODO: this is way too brittle and opaque
|
||||
bits, size := ipnet.Mask.Size()
|
||||
ip := make([]byte, len(ipnet.IP))
|
||||
|
@ -62,7 +63,7 @@ func (wg *wgState) assignOverlayAddr(ipnet *net.IPNet, name string) {
|
|||
}
|
||||
}
|
||||
|
||||
func (wg *wgState) downInterface() error {
|
||||
func (wg *WgState) DownInterface() error {
|
||||
if _, err := wg.client.Device(wg.iface); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil // device already gone; noop
|
||||
|
@ -76,12 +77,12 @@ func (wg *wgState) downInterface() error {
|
|||
return netlink.LinkDel(link)
|
||||
}
|
||||
|
||||
func (wg *wgState) setUpInterface(nodes []node) error {
|
||||
func (wg *WgState) SetUpInterface(nodes []common.Node) error {
|
||||
if err := netlink.LinkAdd(&wireguard{LinkAttrs: netlink.LinkAttrs{Name: wg.iface}}); err != nil && !os.IsExist(err) {
|
||||
return errors.Wrapf(err, "could not create interface %s", wg.iface)
|
||||
}
|
||||
|
||||
peerCfgs, err := wg.nodesToPeerConfigs(nodes)
|
||||
peerCfgs, err := wg.NodesToPeerConfigs(nodes)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error converting received node information to wireguard format")
|
||||
}
|
||||
|
@ -121,7 +122,7 @@ func (wg *wgState) setUpInterface(nodes []node) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (wg *wgState) nodesToPeerConfigs(nodes []node) ([]wgtypes.PeerConfig, error) {
|
||||
func (wg *WgState) NodesToPeerConfigs(nodes []common.Node) ([]wgtypes.PeerConfig, error) {
|
||||
peerCfgs := make([]wgtypes.PeerConfig, len(nodes))
|
||||
for i, node := range nodes {
|
||||
pubKey, err := wgtypes.ParseKey(node.PubKey)
|
|
@ -1,4 +1,4 @@
|
|||
package main
|
||||
package wg
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
@ -31,8 +31,8 @@ func Test_wgState_assignOverlayAddr(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
wg := &wgState{}
|
||||
wg.assignOverlayAddr(tt.args.ipnet, tt.args.name)
|
||||
wg := &WgState{}
|
||||
wg.AssignOverlayAddr(tt.args.ipnet, tt.args.name)
|
||||
|
||||
if !reflect.DeepEqual(wg.OverlayAddr.IP.String(), tt.want) {
|
||||
t.Errorf("assignOverlayAddr() set = %s, want %s", wg.OverlayAddr, tt.want)
|
||||
|
@ -47,8 +47,8 @@ func Test_wgState_assignOverlayAddr_no_obvious_collisions(t *testing.T) {
|
|||
_, ipnet, _ := net.ParseCIDR("10.0.0.0/24")
|
||||
assignments := make(map[string]string)
|
||||
for _, n := range []string{"test", "test1", "test2", "1test", "2test"} {
|
||||
wg := &wgState{}
|
||||
wg.assignOverlayAddr(ipnet, n)
|
||||
wg := &WgState{}
|
||||
wg.AssignOverlayAddr(ipnet, n)
|
||||
if assigned, ok := assignments[wg.OverlayAddr.String()]; ok {
|
||||
t.Errorf("IP assignment collision: hash(%s) = hash(%s)", n, assigned)
|
||||
}
|
||||
|
@ -59,10 +59,10 @@ func Test_wgState_assignOverlayAddr_no_obvious_collisions(t *testing.T) {
|
|||
// This should ensure the obvious fact that the same name should map to the same IP if called twice.
|
||||
func Test_wgState_assignOverlayAddr_consistent(t *testing.T) {
|
||||
_, ipnet, _ := net.ParseCIDR("10.0.0.0/8")
|
||||
wg1 := &wgState{}
|
||||
wg1.assignOverlayAddr(ipnet, "test")
|
||||
wg2 := &wgState{}
|
||||
wg2.assignOverlayAddr(ipnet, "test")
|
||||
wg1 := &WgState{}
|
||||
wg1.AssignOverlayAddr(ipnet, "test")
|
||||
wg2 := &WgState{}
|
||||
wg2.AssignOverlayAddr(ipnet, "test")
|
||||
if wg1.OverlayAddr.String() != wg2.OverlayAddr.String() {
|
||||
t.Errorf("assignOverlayAddr() %s != %s", wg1.OverlayAddr, wg2.OverlayAddr)
|
||||
}
|
||||
|
@ -70,10 +70,10 @@ func Test_wgState_assignOverlayAddr_consistent(t *testing.T) {
|
|||
|
||||
func Test_wgState_assignOverlayAddr_repeatable(t *testing.T) {
|
||||
_, ipnet, _ := net.ParseCIDR("10.0.0.0/8")
|
||||
wg := &wgState{}
|
||||
wg.assignOverlayAddr(ipnet, "test")
|
||||
wg := &WgState{}
|
||||
wg.AssignOverlayAddr(ipnet, "test")
|
||||
gen1 := wg.OverlayAddr.String()
|
||||
wg.assignOverlayAddr(ipnet, "test")
|
||||
wg.AssignOverlayAddr(ipnet, "test")
|
||||
gen2 := wg.OverlayAddr.String()
|
||||
if gen1 != gen2 {
|
||||
t.Errorf("assignOverlayAddr() %s != %s", gen1, gen2)
|
Loading…
Reference in New Issue