Split the application into modules
Splitting into modules will help keep concerns separate, at the cost of a slightly more verbose code.
This commit is contained in:
parent
740a9c44c6
commit
dadfbee083
|
@ -1,4 +1,4 @@
|
||||||
package main
|
package cluster
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
@ -10,30 +10,33 @@ import (
|
||||||
"path"
|
"path"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/costela/wesher/common"
|
||||||
"github.com/hashicorp/memberlist"
|
"github.com/hashicorp/memberlist"
|
||||||
"github.com/mattn/go-isatty"
|
"github.com/mattn/go-isatty"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ClusterState keeps track of information needed to rejoin the cluster
|
const KeyLen = 32
|
||||||
type ClusterState struct {
|
|
||||||
|
// State keeps track of information needed to rejoin the cluster
|
||||||
|
type State struct {
|
||||||
ClusterKey []byte
|
ClusterKey []byte
|
||||||
Nodes []node
|
Nodes []common.Node
|
||||||
}
|
}
|
||||||
|
|
||||||
type cluster struct {
|
type Cluster struct {
|
||||||
localName string // used to avoid LocalNode(); should not change
|
LocalName string // used to avoid LocalNode(); should not change
|
||||||
ml *memberlist.Memberlist
|
ml *memberlist.Memberlist
|
||||||
getMeta func(int) []byte
|
getMeta func(int) []byte
|
||||||
state *ClusterState
|
state *State
|
||||||
events chan memberlist.NodeEvent
|
events chan memberlist.NodeEvent
|
||||||
}
|
}
|
||||||
|
|
||||||
const statePath = "/var/lib/wesher/state.json"
|
const statePath = "/var/lib/wesher/state.json"
|
||||||
|
|
||||||
func newCluster(init bool, clusterKey []byte, bindAddr string, bindPort int, useIPAsName bool, getMeta func(int) []byte) (*cluster, error) {
|
func New(init bool, clusterKey []byte, bindAddr string, bindPort int, useIPAsName bool, getMeta func(int) []byte) (*Cluster, error) {
|
||||||
state := &ClusterState{}
|
state := &State{}
|
||||||
if !init {
|
if !init {
|
||||||
loadState(state)
|
loadState(state)
|
||||||
}
|
}
|
||||||
|
@ -58,8 +61,8 @@ func newCluster(init bool, clusterKey []byte, bindAddr string, bindPort int, use
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
cluster := cluster{
|
cluster := Cluster{
|
||||||
localName: ml.LocalNode().Name,
|
LocalName: ml.LocalNode().Name,
|
||||||
ml: ml,
|
ml: ml,
|
||||||
getMeta: getMeta,
|
getMeta: getMeta,
|
||||||
// The big channel buffer is a work-around for https://github.com/hashicorp/memberlist/issues/23
|
// The big channel buffer is a work-around for https://github.com/hashicorp/memberlist/issues/23
|
||||||
|
@ -74,21 +77,21 @@ func newCluster(init bool, clusterKey []byte, bindAddr string, bindPort int, use
|
||||||
return &cluster, nil
|
return &cluster, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cluster) NotifyConflict(node, other *memberlist.Node) {
|
func (c *Cluster) NotifyConflict(node, other *memberlist.Node) {
|
||||||
logrus.Errorf("node name conflict detected: %s", other.Name)
|
logrus.Errorf("node name conflict detected: %s", other.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cluster) NodeMeta(limit int) []byte {
|
func (c *Cluster) NodeMeta(limit int) []byte {
|
||||||
return c.getMeta(limit)
|
return c.getMeta(limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
// none of these are used
|
// none of these are used
|
||||||
func (c *cluster) NotifyMsg([]byte) {}
|
func (c *Cluster) NotifyMsg([]byte) {}
|
||||||
func (c *cluster) GetBroadcasts(overhead, limit int) [][]byte { return nil }
|
func (c *Cluster) GetBroadcasts(overhead, limit int) [][]byte { return nil }
|
||||||
func (c *cluster) LocalState(join bool) []byte { return nil }
|
func (c *Cluster) LocalState(join bool) []byte { return nil }
|
||||||
func (c *cluster) MergeRemoteState(buf []byte, join bool) {}
|
func (c *Cluster) MergeRemoteState(buf []byte, join bool) {}
|
||||||
|
|
||||||
func (c *cluster) join(addrs []string) error {
|
func (c *Cluster) Join(addrs []string) error {
|
||||||
if len(addrs) == 0 {
|
if len(addrs) == 0 {
|
||||||
for _, n := range c.state.Nodes {
|
for _, n := range c.state.Nodes {
|
||||||
addrs = append(addrs, n.Addr.String())
|
addrs = append(addrs, n.Addr.String())
|
||||||
|
@ -103,22 +106,22 @@ func (c *cluster) join(addrs []string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cluster) leave() {
|
func (c *Cluster) Leave() {
|
||||||
c.saveState()
|
c.saveState()
|
||||||
c.ml.Leave(10 * time.Second)
|
c.ml.Leave(10 * time.Second)
|
||||||
c.ml.Shutdown() //nolint: errcheck
|
c.ml.Shutdown() //nolint: errcheck
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cluster) update() {
|
func (c *Cluster) Update() {
|
||||||
c.ml.UpdateNode(1 * time.Second) // we currently do not update after creation
|
c.ml.UpdateNode(1 * time.Second) // we currently do not update after creation
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cluster) members() <-chan []node {
|
func (c *Cluster) Members() <-chan []common.Node {
|
||||||
changes := make(chan []node)
|
changes := make(chan []common.Node)
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
event := <-c.events
|
event := <-c.events
|
||||||
if event.Node.Name == c.localName {
|
if event.Node.Name == c.LocalName {
|
||||||
// ignore events about ourselves
|
// ignore events about ourselves
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -131,12 +134,12 @@ func (c *cluster) members() <-chan []node {
|
||||||
logrus.Infof("node %s left", event.Node)
|
logrus.Infof("node %s left", event.Node)
|
||||||
}
|
}
|
||||||
|
|
||||||
nodes := make([]node, 0)
|
nodes := make([]common.Node, 0)
|
||||||
for _, n := range c.ml.Members() {
|
for _, n := range c.ml.Members() {
|
||||||
if n.Name == c.localName {
|
if n.Name == c.LocalName {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
nodes = append(nodes, node{
|
nodes = append(nodes, common.Node{
|
||||||
Name: n.Name,
|
Name: n.Name,
|
||||||
Addr: n.Addr,
|
Addr: n.Addr,
|
||||||
Meta: n.Meta,
|
Meta: n.Meta,
|
||||||
|
@ -150,12 +153,12 @@ func (c *cluster) members() <-chan []node {
|
||||||
return changes
|
return changes
|
||||||
}
|
}
|
||||||
|
|
||||||
func computeClusterKey(state *ClusterState, clusterKey []byte) ([]byte, error) {
|
func computeClusterKey(state *State, clusterKey []byte) ([]byte, error) {
|
||||||
if len(clusterKey) == 0 {
|
if len(clusterKey) == 0 {
|
||||||
clusterKey = state.ClusterKey
|
clusterKey = state.ClusterKey
|
||||||
}
|
}
|
||||||
if len(clusterKey) == 0 {
|
if len(clusterKey) == 0 {
|
||||||
clusterKey = make([]byte, clusterKeyLen)
|
clusterKey = make([]byte, KeyLen)
|
||||||
_, err := rand.Read(clusterKey)
|
_, err := rand.Read(clusterKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -169,7 +172,7 @@ func computeClusterKey(state *ClusterState, clusterKey []byte) ([]byte, error) {
|
||||||
return clusterKey, nil
|
return clusterKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cluster) saveState() error {
|
func (c *Cluster) saveState() error {
|
||||||
if err := os.MkdirAll(path.Dir(statePath), 0700); err != nil {
|
if err := os.MkdirAll(path.Dir(statePath), 0700); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -182,7 +185,7 @@ func (c *cluster) saveState() error {
|
||||||
return ioutil.WriteFile(statePath, stateOut, 0600)
|
return ioutil.WriteFile(statePath, stateOut, 0600)
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadState(cs *ClusterState) {
|
func loadState(cs *State) {
|
||||||
content, err := ioutil.ReadFile(statePath)
|
content, err := ioutil.ReadFile(statePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !os.IsNotExist(err) {
|
if !os.IsNotExist(err) {
|
||||||
|
@ -192,7 +195,7 @@ func loadState(cs *ClusterState) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// avoid partially unmarshalled content by using a temp var
|
// avoid partially unmarshalled content by using a temp var
|
||||||
csTmp := &ClusterState{}
|
csTmp := &State{}
|
||||||
if err := json.Unmarshal(content, csTmp); err != nil {
|
if err := json.Unmarshal(content, csTmp); err != nil {
|
||||||
logrus.Warnf("could not decode state: %s", err)
|
logrus.Warnf("could not decode state: %s", err)
|
||||||
} else {
|
} else {
|
|
@ -1,4 +1,4 @@
|
||||||
package main
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
@ -9,25 +9,25 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// nodeMeta holds metadata sent over the cluster
|
// NodeMeta holds metadata sent over the cluster
|
||||||
type nodeMeta struct {
|
type NodeMeta struct {
|
||||||
OverlayAddr net.IPNet
|
OverlayAddr net.IPNet
|
||||||
PubKey string
|
PubKey string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Node holds the memberlist node structure
|
// Node holds the memberlist node structure
|
||||||
type node struct {
|
type Node struct {
|
||||||
Name string
|
Name string
|
||||||
Addr net.IP
|
Addr net.IP
|
||||||
Meta []byte
|
Meta []byte
|
||||||
nodeMeta
|
NodeMeta
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *node) String() string {
|
func (n *Node) String() string {
|
||||||
return n.Addr.String()
|
return n.Addr.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func encodeNodeMeta(nm nodeMeta, limit int) []byte {
|
func EncodeNodeMeta(nm NodeMeta, limit int) []byte {
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
if err := gob.NewEncoder(buf).Encode(nm); err != nil {
|
if err := gob.NewEncoder(buf).Encode(nm); err != nil {
|
||||||
logrus.Errorf("could not encode local state: %s", err)
|
logrus.Errorf("could not encode local state: %s", err)
|
||||||
|
@ -40,10 +40,10 @@ func encodeNodeMeta(nm nodeMeta, limit int) []byte {
|
||||||
return buf.Bytes()
|
return buf.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeNodeMeta(b []byte) (nodeMeta, error) {
|
func DecodeNodeMeta(b []byte) (NodeMeta, error) {
|
||||||
// TODO: we blindly trust the info we get from the peers; We should be more defensive to limit the damage a leaked
|
// TODO: we blindly trust the info we get from the peers; We should be more defensive to limit the damage a leaked
|
||||||
// PSK can cause.
|
// PSK can cause.
|
||||||
nm := nodeMeta{}
|
nm := NodeMeta{}
|
||||||
if err := gob.NewDecoder(bytes.NewReader(b)).Decode(&nm); err != nil {
|
if err := gob.NewDecoder(bytes.NewReader(b)).Decode(&nm); err != nil {
|
||||||
return nm, errors.Wrap(err, "could not decode node meta")
|
return nm, errors.Wrap(err, "could not decode node meta")
|
||||||
}
|
}
|
|
@ -4,12 +4,11 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
"github.com/costela/wesher/cluster"
|
||||||
"github.com/hashicorp/go-sockaddr"
|
"github.com/hashicorp/go-sockaddr"
|
||||||
"github.com/stevenroose/gonfig"
|
"github.com/stevenroose/gonfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
const clusterKeyLen = 32
|
|
||||||
|
|
||||||
type config struct {
|
type config struct {
|
||||||
ClusterKey []byte `id:"cluster-key" desc:"shared key for cluster membership; must be 32 bytes base64 encoded; will be generated if not provided"`
|
ClusterKey []byte `id:"cluster-key" desc:"shared key for cluster membership; must be 32 bytes base64 encoded; will be generated if not provided"`
|
||||||
Join []string `desc:"comma separated list of hostnames or IP addresses to existing cluster members; if not provided, will attempt resuming any known state or otherwise wait for further members."`
|
Join []string `desc:"comma separated list of hostnames or IP addresses to existing cluster members; if not provided, will attempt resuming any known state or otherwise wait for further members."`
|
||||||
|
@ -36,8 +35,8 @@ func loadConfig() (*config, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// perform some validation
|
// perform some validation
|
||||||
if len(config.ClusterKey) != 0 && len(config.ClusterKey) != clusterKeyLen {
|
if len(config.ClusterKey) != 0 && len(config.ClusterKey) != cluster.KeyLen {
|
||||||
return nil, fmt.Errorf("unsupported cluster key length; expected %d, got %d", clusterKeyLen, len(config.ClusterKey))
|
return nil, fmt.Errorf("unsupported cluster key length; expected %d, got %d", cluster.KeyLen, len(config.ClusterKey))
|
||||||
}
|
}
|
||||||
|
|
||||||
if bits, _ := ((*net.IPNet)(config.OverlayNet)).Mask.Size(); bits%8 != 0 {
|
if bits, _ := ((*net.IPNet)(config.OverlayNet)).Mask.Size(); bits%8 != 0 {
|
||||||
|
|
33
main.go
33
main.go
|
@ -9,7 +9,10 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff"
|
"github.com/cenkalti/backoff"
|
||||||
|
"github.com/costela/wesher/cluster"
|
||||||
|
"github.com/costela/wesher/common"
|
||||||
"github.com/costela/wesher/etchosts"
|
"github.com/costela/wesher/etchosts"
|
||||||
|
"github.com/costela/wesher/wg"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -30,28 +33,28 @@ func main() {
|
||||||
}
|
}
|
||||||
logrus.SetLevel(logLevel)
|
logrus.SetLevel(logLevel)
|
||||||
|
|
||||||
wg, err := newWGConfig(config.Interface, config.WireguardPort)
|
wg, err := wg.NewWGConfig(config.Interface, config.WireguardPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Fatal("could not instantiate wireguard controller")
|
logrus.WithError(err).Fatal("could not instantiate wireguard controller")
|
||||||
}
|
}
|
||||||
|
|
||||||
getMeta := func(limit int) []byte {
|
getMeta := func(limit int) []byte {
|
||||||
return encodeNodeMeta(nodeMeta{
|
return common.EncodeNodeMeta(common.NodeMeta{
|
||||||
OverlayAddr: wg.OverlayAddr,
|
OverlayAddr: wg.OverlayAddr,
|
||||||
PubKey: wg.PubKey.String(),
|
PubKey: wg.PubKey.String(),
|
||||||
}, limit)
|
}, limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
cluster, err := newCluster(config.Init, config.ClusterKey, config.BindAddr, config.ClusterPort, config.UseIPAsName, getMeta)
|
cluster, err := cluster.New(config.Init, config.ClusterKey, config.BindAddr, config.ClusterPort, config.UseIPAsName, getMeta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Fatal("could not create cluster")
|
logrus.WithError(err).Fatal("could not create cluster")
|
||||||
}
|
}
|
||||||
wg.assignOverlayAddr((*net.IPNet)(config.OverlayNet), cluster.localName)
|
wg.AssignOverlayAddr((*net.IPNet)(config.OverlayNet), cluster.LocalName)
|
||||||
cluster.update()
|
cluster.Update()
|
||||||
|
|
||||||
nodec := cluster.members() // avoid deadlocks by starting before join
|
nodec := cluster.Members() // avoid deadlocks by starting before join
|
||||||
if err := backoff.RetryNotify(
|
if err := backoff.RetryNotify(
|
||||||
func() error { return cluster.join(config.Join) },
|
func() error { return cluster.Join(config.Join) },
|
||||||
backoff.NewExponentialBackOff(),
|
backoff.NewExponentialBackOff(),
|
||||||
func(err error, dur time.Duration) {
|
func(err error, dur time.Duration) {
|
||||||
logrus.WithError(err).Errorf("could not join cluster, retrying in %s", dur)
|
logrus.WithError(err).Errorf("could not join cluster, retrying in %s", dur)
|
||||||
|
@ -67,20 +70,20 @@ func main() {
|
||||||
select {
|
select {
|
||||||
case rawNodes := <-nodec:
|
case rawNodes := <-nodec:
|
||||||
logrus.Info("cluster members:\n")
|
logrus.Info("cluster members:\n")
|
||||||
nodes := make([]node, 0, len(rawNodes))
|
nodes := make([]common.Node, 0, len(rawNodes))
|
||||||
for _, node := range rawNodes {
|
for _, node := range rawNodes {
|
||||||
meta, err := decodeNodeMeta(node.Meta)
|
meta, err := common.DecodeNodeMeta(node.Meta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Warnf("\t addr: %s, could not decode metadata", node.Addr)
|
logrus.Warnf("\t addr: %s, could not decode metadata", node.Addr)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
node.nodeMeta = meta
|
node.NodeMeta = meta
|
||||||
nodes = append(nodes, node)
|
nodes = append(nodes, node)
|
||||||
logrus.Infof("\taddr: %s, overlay: %s, pubkey: %s", node.Addr, node.OverlayAddr, node.PubKey)
|
logrus.Infof("\taddr: %s, overlay: %s, pubkey: %s", node.Addr, node.OverlayAddr, node.PubKey)
|
||||||
}
|
}
|
||||||
if err := wg.setUpInterface(nodes); err != nil {
|
if err := wg.SetUpInterface(nodes); err != nil {
|
||||||
logrus.WithError(err).Error("could not up interface")
|
logrus.WithError(err).Error("could not up interface")
|
||||||
wg.downInterface()
|
wg.DownInterface()
|
||||||
}
|
}
|
||||||
if !config.NoEtcHosts {
|
if !config.NoEtcHosts {
|
||||||
if err := writeToEtcHosts(nodes); err != nil {
|
if err := writeToEtcHosts(nodes); err != nil {
|
||||||
|
@ -89,13 +92,13 @@ func main() {
|
||||||
}
|
}
|
||||||
case <-incomingSigs:
|
case <-incomingSigs:
|
||||||
logrus.Info("terminating...")
|
logrus.Info("terminating...")
|
||||||
cluster.leave()
|
cluster.Leave()
|
||||||
if !config.NoEtcHosts {
|
if !config.NoEtcHosts {
|
||||||
if err := writeToEtcHosts(nil); err != nil {
|
if err := writeToEtcHosts(nil); err != nil {
|
||||||
logrus.WithError(err).Error("could not remove stale hosts entries")
|
logrus.WithError(err).Error("could not remove stale hosts entries")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := wg.downInterface(); err != nil {
|
if err := wg.DownInterface(); err != nil {
|
||||||
logrus.WithError(err).Error("could not down interface")
|
logrus.WithError(err).Error("could not down interface")
|
||||||
}
|
}
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
@ -103,7 +106,7 @@ func main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeToEtcHosts(nodes []node) error {
|
func writeToEtcHosts(nodes []common.Node) error {
|
||||||
hosts := make(map[string][]string, len(nodes))
|
hosts := make(map[string][]string, len(nodes))
|
||||||
for _, n := range nodes {
|
for _, n := range nodes {
|
||||||
hosts[n.OverlayAddr.IP.String()] = []string{n.Name}
|
hosts[n.OverlayAddr.IP.String()] = []string{n.Name}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package main
|
package wg
|
||||||
|
|
||||||
import "github.com/vishvananda/netlink"
|
import "github.com/vishvananda/netlink"
|
||||||
|
|
|
@ -1,17 +1,18 @@
|
||||||
package main
|
package wg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/costela/wesher/common"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
type wgState struct {
|
type WgState struct {
|
||||||
iface string
|
iface string
|
||||||
client *wgctrl.Client
|
client *wgctrl.Client
|
||||||
OverlayAddr net.IPNet
|
OverlayAddr net.IPNet
|
||||||
|
@ -20,7 +21,7 @@ type wgState struct {
|
||||||
PubKey wgtypes.Key
|
PubKey wgtypes.Key
|
||||||
}
|
}
|
||||||
|
|
||||||
func newWGConfig(iface string, port int) (*wgState, error) {
|
func NewWGConfig(iface string, port int) (*WgState, error) {
|
||||||
client, err := wgctrl.New()
|
client, err := wgctrl.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "could not instantiate wireguard client")
|
return nil, errors.Wrap(err, "could not instantiate wireguard client")
|
||||||
|
@ -32,7 +33,7 @@ func newWGConfig(iface string, port int) (*wgState, error) {
|
||||||
}
|
}
|
||||||
pubKey := privKey.PublicKey()
|
pubKey := privKey.PublicKey()
|
||||||
|
|
||||||
wgState := wgState{
|
wgState := WgState{
|
||||||
iface: iface,
|
iface: iface,
|
||||||
client: client,
|
client: client,
|
||||||
Port: port,
|
Port: port,
|
||||||
|
@ -42,7 +43,7 @@ func newWGConfig(iface string, port int) (*wgState, error) {
|
||||||
return &wgState, nil
|
return &wgState, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wg *wgState) assignOverlayAddr(ipnet *net.IPNet, name string) {
|
func (wg *WgState) AssignOverlayAddr(ipnet *net.IPNet, name string) {
|
||||||
// TODO: this is way too brittle and opaque
|
// TODO: this is way too brittle and opaque
|
||||||
bits, size := ipnet.Mask.Size()
|
bits, size := ipnet.Mask.Size()
|
||||||
ip := make([]byte, len(ipnet.IP))
|
ip := make([]byte, len(ipnet.IP))
|
||||||
|
@ -62,7 +63,7 @@ func (wg *wgState) assignOverlayAddr(ipnet *net.IPNet, name string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wg *wgState) downInterface() error {
|
func (wg *WgState) DownInterface() error {
|
||||||
if _, err := wg.client.Device(wg.iface); err != nil {
|
if _, err := wg.client.Device(wg.iface); err != nil {
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
return nil // device already gone; noop
|
return nil // device already gone; noop
|
||||||
|
@ -76,12 +77,12 @@ func (wg *wgState) downInterface() error {
|
||||||
return netlink.LinkDel(link)
|
return netlink.LinkDel(link)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wg *wgState) setUpInterface(nodes []node) error {
|
func (wg *WgState) SetUpInterface(nodes []common.Node) error {
|
||||||
if err := netlink.LinkAdd(&wireguard{LinkAttrs: netlink.LinkAttrs{Name: wg.iface}}); err != nil && !os.IsExist(err) {
|
if err := netlink.LinkAdd(&wireguard{LinkAttrs: netlink.LinkAttrs{Name: wg.iface}}); err != nil && !os.IsExist(err) {
|
||||||
return errors.Wrapf(err, "could not create interface %s", wg.iface)
|
return errors.Wrapf(err, "could not create interface %s", wg.iface)
|
||||||
}
|
}
|
||||||
|
|
||||||
peerCfgs, err := wg.nodesToPeerConfigs(nodes)
|
peerCfgs, err := wg.NodesToPeerConfigs(nodes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "error converting received node information to wireguard format")
|
return errors.Wrap(err, "error converting received node information to wireguard format")
|
||||||
}
|
}
|
||||||
|
@ -121,7 +122,7 @@ func (wg *wgState) setUpInterface(nodes []node) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wg *wgState) nodesToPeerConfigs(nodes []node) ([]wgtypes.PeerConfig, error) {
|
func (wg *WgState) NodesToPeerConfigs(nodes []common.Node) ([]wgtypes.PeerConfig, error) {
|
||||||
peerCfgs := make([]wgtypes.PeerConfig, len(nodes))
|
peerCfgs := make([]wgtypes.PeerConfig, len(nodes))
|
||||||
for i, node := range nodes {
|
for i, node := range nodes {
|
||||||
pubKey, err := wgtypes.ParseKey(node.PubKey)
|
pubKey, err := wgtypes.ParseKey(node.PubKey)
|
|
@ -1,4 +1,4 @@
|
||||||
package main
|
package wg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
@ -31,8 +31,8 @@ func Test_wgState_assignOverlayAddr(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
wg := &wgState{}
|
wg := &WgState{}
|
||||||
wg.assignOverlayAddr(tt.args.ipnet, tt.args.name)
|
wg.AssignOverlayAddr(tt.args.ipnet, tt.args.name)
|
||||||
|
|
||||||
if !reflect.DeepEqual(wg.OverlayAddr.IP.String(), tt.want) {
|
if !reflect.DeepEqual(wg.OverlayAddr.IP.String(), tt.want) {
|
||||||
t.Errorf("assignOverlayAddr() set = %s, want %s", wg.OverlayAddr, tt.want)
|
t.Errorf("assignOverlayAddr() set = %s, want %s", wg.OverlayAddr, tt.want)
|
||||||
|
@ -47,8 +47,8 @@ func Test_wgState_assignOverlayAddr_no_obvious_collisions(t *testing.T) {
|
||||||
_, ipnet, _ := net.ParseCIDR("10.0.0.0/24")
|
_, ipnet, _ := net.ParseCIDR("10.0.0.0/24")
|
||||||
assignments := make(map[string]string)
|
assignments := make(map[string]string)
|
||||||
for _, n := range []string{"test", "test1", "test2", "1test", "2test"} {
|
for _, n := range []string{"test", "test1", "test2", "1test", "2test"} {
|
||||||
wg := &wgState{}
|
wg := &WgState{}
|
||||||
wg.assignOverlayAddr(ipnet, n)
|
wg.AssignOverlayAddr(ipnet, n)
|
||||||
if assigned, ok := assignments[wg.OverlayAddr.String()]; ok {
|
if assigned, ok := assignments[wg.OverlayAddr.String()]; ok {
|
||||||
t.Errorf("IP assignment collision: hash(%s) = hash(%s)", n, assigned)
|
t.Errorf("IP assignment collision: hash(%s) = hash(%s)", n, assigned)
|
||||||
}
|
}
|
||||||
|
@ -59,10 +59,10 @@ func Test_wgState_assignOverlayAddr_no_obvious_collisions(t *testing.T) {
|
||||||
// This should ensure the obvious fact that the same name should map to the same IP if called twice.
|
// This should ensure the obvious fact that the same name should map to the same IP if called twice.
|
||||||
func Test_wgState_assignOverlayAddr_consistent(t *testing.T) {
|
func Test_wgState_assignOverlayAddr_consistent(t *testing.T) {
|
||||||
_, ipnet, _ := net.ParseCIDR("10.0.0.0/8")
|
_, ipnet, _ := net.ParseCIDR("10.0.0.0/8")
|
||||||
wg1 := &wgState{}
|
wg1 := &WgState{}
|
||||||
wg1.assignOverlayAddr(ipnet, "test")
|
wg1.AssignOverlayAddr(ipnet, "test")
|
||||||
wg2 := &wgState{}
|
wg2 := &WgState{}
|
||||||
wg2.assignOverlayAddr(ipnet, "test")
|
wg2.AssignOverlayAddr(ipnet, "test")
|
||||||
if wg1.OverlayAddr.String() != wg2.OverlayAddr.String() {
|
if wg1.OverlayAddr.String() != wg2.OverlayAddr.String() {
|
||||||
t.Errorf("assignOverlayAddr() %s != %s", wg1.OverlayAddr, wg2.OverlayAddr)
|
t.Errorf("assignOverlayAddr() %s != %s", wg1.OverlayAddr, wg2.OverlayAddr)
|
||||||
}
|
}
|
||||||
|
@ -70,10 +70,10 @@ func Test_wgState_assignOverlayAddr_consistent(t *testing.T) {
|
||||||
|
|
||||||
func Test_wgState_assignOverlayAddr_repeatable(t *testing.T) {
|
func Test_wgState_assignOverlayAddr_repeatable(t *testing.T) {
|
||||||
_, ipnet, _ := net.ParseCIDR("10.0.0.0/8")
|
_, ipnet, _ := net.ParseCIDR("10.0.0.0/8")
|
||||||
wg := &wgState{}
|
wg := &WgState{}
|
||||||
wg.assignOverlayAddr(ipnet, "test")
|
wg.AssignOverlayAddr(ipnet, "test")
|
||||||
gen1 := wg.OverlayAddr.String()
|
gen1 := wg.OverlayAddr.String()
|
||||||
wg.assignOverlayAddr(ipnet, "test")
|
wg.AssignOverlayAddr(ipnet, "test")
|
||||||
gen2 := wg.OverlayAddr.String()
|
gen2 := wg.OverlayAddr.String()
|
||||||
if gen1 != gen2 {
|
if gen1 != gen2 {
|
||||||
t.Errorf("assignOverlayAddr() %s != %s", gen1, gen2)
|
t.Errorf("assignOverlayAddr() %s != %s", gen1, gen2)
|
Loading…
Reference in New Issue