More like a library (#279)

This commit is contained in:
Nathan Brown 2020-09-18 09:20:09 -05:00 committed by GitHub
parent 6238f1550b
commit 68e3e84fdc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 608 additions and 153 deletions

View File

@ -468,6 +468,63 @@ func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) {
return json.Marshal(jc)
}
//func (nc *NebulaCertificate) Copy() *NebulaCertificate {
// r, err := nc.Marshal()
// if err != nil {
// //TODO
// return nil
// }
//
// c, err := UnmarshalNebulaCertificate(r)
// return c
//}
func (nc *NebulaCertificate) Copy() *NebulaCertificate {
c := &NebulaCertificate{
Details: NebulaCertificateDetails{
Name: nc.Details.Name,
Groups: make([]string, len(nc.Details.Groups)),
Ips: make([]*net.IPNet, len(nc.Details.Ips)),
Subnets: make([]*net.IPNet, len(nc.Details.Subnets)),
NotBefore: nc.Details.NotBefore,
NotAfter: nc.Details.NotAfter,
PublicKey: make([]byte, len(nc.Details.PublicKey)),
IsCA: nc.Details.IsCA,
Issuer: nc.Details.Issuer,
InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)),
},
Signature: make([]byte, len(nc.Signature)),
}
copy(c.Signature, nc.Signature)
copy(c.Details.Groups, nc.Details.Groups)
copy(c.Details.PublicKey, nc.Details.PublicKey)
for i, p := range nc.Details.Ips {
c.Details.Ips[i] = &net.IPNet{
IP: make(net.IP, len(p.IP)),
Mask: make(net.IPMask, len(p.Mask)),
}
copy(c.Details.Ips[i].IP, p.IP)
copy(c.Details.Ips[i].Mask, p.Mask)
}
for i, p := range nc.Details.Subnets {
c.Details.Subnets[i] = &net.IPNet{
IP: make(net.IP, len(p.IP)),
Mask: make(net.IPMask, len(p.Mask)),
}
copy(c.Details.Subnets[i].IP, p.IP)
copy(c.Details.Subnets[i].Mask, p.Mask)
}
for g := range nc.Details.InvertedGroups {
c.Details.InvertedGroups[g] = struct{}{}
}
return c
}
func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool {
for _, net := range rootIps {
if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) {

View File

@ -9,6 +9,7 @@ import (
"time"
"github.com/golang/protobuf/proto"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
@ -487,6 +488,17 @@ func TestMarshalingNebulaCertificateConsistency(t *testing.T) {
assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
}
func TestNebulaCertificate_Copy(t *testing.T) {
ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
assert.Nil(t, err)
c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
assert.Nil(t, err)
cc := c.Copy()
util.AssertDeepCopyEqual(t, c, cc)
}
func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if before.IsZero() {
@ -499,10 +511,11 @@ func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
nc := &NebulaCertificate{
Details: NebulaCertificateDetails{
Name: "test ca",
NotBefore: before,
NotAfter: after,
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
IsCA: true,
InvertedGroups: make(map[string]struct{}),
},
}
@ -544,17 +557,17 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
if len(ips) == 0 {
ips = []*net.IPNet{
{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
{IP: net.ParseIP("10.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
{IP: net.ParseIP("10.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
{IP: net.ParseIP("10.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
}
}
if len(subnets) == 0 {
subnets = []*net.IPNet{
{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
{IP: net.ParseIP("9.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
{IP: net.ParseIP("9.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
{IP: net.ParseIP("9.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
}
}
@ -566,11 +579,12 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
Ips: ips,
Subnets: subnets,
Groups: groups,
NotBefore: before,
NotAfter: after,
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
IsCA: false,
Issuer: issuer,
InvertedGroups: make(map[string]struct{}),
},
}

View File

@ -55,7 +55,7 @@ func main() {
l := logrus.New()
l.Out = os.Stdout
err = nebula.Main(config, *configTest, true, Build, l, nil, nil)
c, err := nebula.Main(config, *configTest, Build, l, nil)
switch v := err.(type) {
case nebula.ContextualError:
@ -66,5 +66,10 @@ func main() {
os.Exit(1)
}
if !*configTest {
c.Start()
c.ShutdownBlock()
}
os.Exit(0)
}

View File

@ -14,21 +14,16 @@ import (
var logger service.Logger
type program struct {
exit chan struct{}
configPath *string
configTest *bool
build string
control *nebula.Control
}
func (p *program) Start(s service.Service) error {
logger.Info("Nebula service starting.")
p.exit = make(chan struct{})
// Start should not block.
go p.run()
return nil
}
logger.Info("Nebula service starting.")
func (p *program) run() error {
config := nebula.NewConfig()
err := config.Load(*p.configPath)
if err != nil {
@ -37,17 +32,22 @@ func (p *program) run() error {
l := logrus.New()
l.Out = os.Stdout
return nebula.Main(config, *p.configTest, true, Build, l, nil, nil)
p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
if err != nil {
return err
}
p.control.Start()
return nil
}
func (p *program) Stop(s service.Service) error {
logger.Info("Nebula service stopping.")
close(p.exit)
p.control.Stop()
return nil
}
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {
if *configPath == "" {
ex, err := os.Executable()
if err != nil {

View File

@ -49,7 +49,7 @@ func main() {
l := logrus.New()
l.Out = os.Stdout
err = nebula.Main(config, *configTest, true, Build, l, nil, nil)
c, err := nebula.Main(config, *configTest, Build, l, nil)
switch v := err.(type) {
case nebula.ContextualError:
@ -60,5 +60,10 @@ func main() {
os.Exit(1)
}
if !*configTest {
c.Start()
c.ShutdownBlock()
}
os.Exit(0)
}

169
control.go Normal file
View File

@ -0,0 +1,169 @@
package nebula
import (
"net"
"os"
"os/signal"
"syscall"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
)
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
type Control struct {
f *Interface
l *logrus.Logger
}
type ControlHostInfo struct {
VpnIP net.IP `json:"vpnIp"`
LocalIndex uint32 `json:"localIndex"`
RemoteIndex uint32 `json:"remoteIndex"`
RemoteAddrs []udpAddr `json:"remoteAddrs"`
CachedPackets int `json:"cachedPackets"`
Cert *cert.NebulaCertificate `json:"cert"`
MessageCounter uint64 `json:"messageCounter"`
CurrentRemote udpAddr `json:"currentRemote"`
}
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
func (c *Control) Start() {
c.f.run()
}
// Stop signals nebula to shutdown, returns after the shutdown is complete
func (c *Control) Stop() {
//TODO: stop tun and udp routines, the lock on hostMap effectively does that though
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
c.f.hostMap.Lock()
for _, h := range c.f.hostMap.Hosts {
if h.ConnectionState.ready {
c.f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message")
}
}
c.f.hostMap.Unlock()
c.l.Info("Goodbye")
}
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
func (c *Control) ShutdownBlock() {
sigChan := make(chan os.Signal)
signal.Notify(sigChan, syscall.SIGTERM)
signal.Notify(sigChan, syscall.SIGINT)
rawSig := <-sigChan
sig := rawSig.String()
c.l.WithField("signal", sig).Info("Caught signal, shutting down")
c.Stop()
}
// RebindUDPServer asks the UDP listener to rebind it's listener. Mainly used on mobile clients when interfaces change
func (c *Control) RebindUDPServer() {
_ = c.f.outside.Rebind()
}
// ListHostmap returns details about the actual or pending (handshaking) hostmap
func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
var hm *HostMap
if pendingMap {
hm = c.f.handshakeManager.pendingHostMap
} else {
hm = c.f.hostMap
}
hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Hosts))
i := 0
for _, v := range hm.Hosts {
hosts[i] = copyHostInfo(v)
i++
}
hm.RUnlock()
return hosts
}
// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found
func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInfo {
var hm *HostMap
if pending {
hm = c.f.handshakeManager.pendingHostMap
} else {
hm = c.f.hostMap
}
h, err := hm.QueryVpnIP(vpnIP)
if err != nil {
return nil
}
ch := copyHostInfo(h)
return &ch
}
// SetRemoteForTunnel forces a tunnel to use a specific remote
func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInfo {
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
if err != nil {
return nil
}
hostInfo.SetRemote(addr.Copy())
ch := copyHostInfo(hostInfo)
return &ch
}
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
func (c *Control) CloseTunnel(vpnIP uint32, localOnly bool) bool {
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
if err != nil {
return false
}
if !localOnly {
c.f.send(
closeTunnel,
0,
hostInfo.ConnectionState,
hostInfo,
hostInfo.remote,
[]byte{},
make([]byte, 12, 12),
make([]byte, mtu),
)
}
c.f.closeTunnel(hostInfo)
return true
}
func copyHostInfo(h *HostInfo) ControlHostInfo {
addrs := h.RemoteUDPAddrs()
chi := ControlHostInfo{
VpnIP: int2ip(h.hostId),
LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId,
RemoteAddrs: make([]udpAddr, len(addrs), len(addrs)),
CachedPackets: len(h.packetStore),
MessageCounter: *h.ConnectionState.messageCounter,
}
if c := h.GetCert(); c != nil {
chi.Cert = c.Copy()
}
if h.remote != nil {
chi.CurrentRemote = *h.remote
}
for i, addr := range addrs {
chi.RemoteAddrs[i] = addr.Copy()
}
return chi
}

111
control_test.go Normal file
View File

@ -0,0 +1,111 @@
package nebula
import (
"net"
"reflect"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
)
func TestControl_GetHostInfoByVpnIP(t *testing.T) {
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller
hm := NewHostMap("test", &net.IPNet{}, make([]*net.IPNet, 0))
remote1 := NewUDPAddr(100, 4444)
remote2 := NewUDPAddr(101, 4444)
ipNet := net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
}
ipNet2 := net.IPNet{
IP: net.IPv4(1, 2, 3, 5),
Mask: net.IPMask{255, 255, 255, 0},
}
crt := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "test",
Ips: []*net.IPNet{&ipNet},
Subnets: []*net.IPNet{},
Groups: []string{"default-group"},
NotBefore: time.Unix(1, 0),
NotAfter: time.Unix(2, 0),
PublicKey: []byte{5, 6, 7, 8},
IsCA: false,
Issuer: "the-issuer",
InvertedGroups: map[string]struct{}{"default-group": {}},
},
Signature: []byte{1, 2, 1, 2, 1, 3},
}
counter := uint64(0)
remotes := []*HostInfoDest{NewHostInfoDest(remote1), NewHostInfoDest(remote2)}
hm.Add(ip2int(ipNet.IP), &HostInfo{
remote: remote1,
Remotes: remotes,
ConnectionState: &ConnectionState{
peerCert: crt,
messageCounter: &counter,
},
remoteIndexId: 200,
localIndexId: 201,
hostId: ip2int(ipNet.IP),
})
hm.Add(ip2int(ipNet2.IP), &HostInfo{
remote: remote1,
Remotes: remotes,
ConnectionState: &ConnectionState{
peerCert: nil,
messageCounter: &counter,
},
remoteIndexId: 200,
localIndexId: 201,
hostId: ip2int(ipNet2.IP),
})
c := Control{
f: &Interface{
hostMap: hm,
},
l: logrus.New(),
}
thi := c.GetHostInfoByVpnIP(ip2int(ipNet.IP), false)
expectedInfo := ControlHostInfo{
VpnIP: net.IPv4(1, 2, 3, 4).To4(),
LocalIndex: 201,
RemoteIndex: 200,
RemoteAddrs: []udpAddr{*remote1, *remote2},
CachedPackets: 0,
Cert: crt.Copy(),
MessageCounter: 0,
CurrentRemote: *NewUDPAddr(100, 4444),
}
// Make sure we don't have any unexpected fields
assertFields(t, []string{"VpnIP", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
util.AssertDeepCopyEqual(t, &expectedInfo, thi)
// Make sure we don't panic if the host info doesn't have a cert yet
assert.NotPanics(t, func() {
thi = c.GetHostInfoByVpnIP(ip2int(ipNet2.IP), false)
})
}
func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
val := reflect.ValueOf(actualStruct).Elem()
fields := make([]string, val.NumField())
for i := 0; i < val.NumField(); i++ {
fields[i] = val.Type().Field(i).Name
}
assert.Equal(t, expected, fields)
}

View File

@ -221,11 +221,17 @@ func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, er
// AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
// https://github.com/golang/go/issues/14131
sIp := ""
if ip != nil {
sIp = ip.String()
}
// We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf(
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, ip, caName, caSha,
incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha,
)
f.rules += ruleString + "\n"
@ -233,7 +239,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
if !incoming {
direction = "outgoing"
}
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": ip, "caName": caName, "caSha": caSha}).
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
Info("Firewall rule added")
var (

2
go.mod
View File

@ -22,7 +22,7 @@ require (
github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563
github.com/sirupsen/logrus v1.4.2
github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b
github.com/stretchr/testify v1.4.0
github.com/stretchr/testify v1.6.1
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975

8
go.sum
View File

@ -103,8 +103,8 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a h1:Bt1IVPhiCDMqwGrc2nnbIN4QKvJGx6SK2NzWBmW00ao=
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
@ -112,8 +112,6 @@ github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY=
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 h1:ULYEB3JvPRE/IfO+9uO7vKV/xzVTO7XPAwm8xbf4w2g=
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975 h1:/Tl7pH94bvbAAHBdZJT947M/+gp0+CqQXDtMRC0fseo=
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@ -154,3 +152,5 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -35,7 +35,10 @@ type InterfaceConfig struct {
DropLocalBroadcast bool
DropMulticast bool
UDPBatchSize int
udpQueues int
tunQueues int
MessageMetrics *MessageMetrics
version string
}
type Interface struct {
@ -54,6 +57,8 @@ type Interface struct {
dropLocalBroadcast bool
dropMulticast bool
udpBatchSize int
udpQueues int
tunQueues int
version string
metricHandshakes metrics.Histogram
@ -89,6 +94,9 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
dropLocalBroadcast: c.DropLocalBroadcast,
dropMulticast: c.DropMulticast,
udpBatchSize: c.UDPBatchSize,
udpQueues: c.udpQueues,
tunQueues: c.tunQueues,
version: c.version,
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
messageMetrics: c.MessageMetrics,
@ -99,29 +107,28 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
return ifce, nil
}
func (f *Interface) Run(tunRoutines, udpRoutines int, buildVersion string) {
func (f *Interface) run() {
// actually turn on tun dev
if err := f.inside.Activate(); err != nil {
l.Fatal(err)
}
f.version = buildVersion
addr, err := f.outside.LocalAddr()
if err != nil {
l.WithError(err).Error("Failed to get udp listen address")
}
l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
WithField("build", buildVersion).WithField("udpAddr", addr).
WithField("build", f.version).WithField("udpAddr", addr).
Info("Nebula interface is active")
// Launch n queues to read packets from udp
for i := 0; i < udpRoutines; i++ {
for i := 0; i < f.udpQueues; i++ {
go f.listenOut(i)
}
// Launch n queues to read packets from tun dev
for i := 0; i < tunRoutines; i++ {
for i := 0; i < f.tunQueues; i++ {
go f.listenIn(i)
}
}

View File

@ -1,6 +1,8 @@
package nebula
import (
"errors"
"github.com/sirupsen/logrus"
)
@ -15,10 +17,16 @@ func NewContextualError(msg string, fields map[string]interface{}, realError err
}
func (ce ContextualError) Error() string {
if ce.RealError == nil {
return ce.Context
}
return ce.RealError.Error()
}
func (ce ContextualError) Unwrap() error {
if ce.RealError == nil {
return errors.New(ce.Context)
}
return ce.RealError
}

123
main.go
View File

@ -4,11 +4,8 @@ import (
"encoding/binary"
"fmt"
"net"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
"github.com/sirupsen/logrus"
@ -21,12 +18,7 @@ var l = logrus.New()
type m map[string]interface{}
type CommandRequest struct {
Command string
Callback chan error
}
func Main(config *Config, configTest bool, block bool, buildVersion string, logger *logrus.Logger, tunFd *int, commandChan <-chan CommandRequest) error {
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) {
l = logger
l.Formatter = &logrus.TextFormatter{
FullTimestamp: true,
@ -36,7 +28,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
if configTest {
b, err := yaml.Marshal(config.Settings)
if err != nil {
return err
return nil, err
}
// Print the final config
@ -45,7 +37,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
err := configLogger(config)
if err != nil {
return NewContextualError("Failed to configure the logger", nil, err)
return nil, NewContextualError("Failed to configure the logger", nil, err)
}
config.RegisterReloadCallback(func(c *Config) {
@ -59,20 +51,20 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
trustedCAs, err = loadCAFromConfig(config)
if err != nil {
//The errors coming out of loadCA are already nicely formatted
return NewContextualError("Failed to load ca from config", nil, err)
return nil, NewContextualError("Failed to load ca from config", nil, err)
}
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
cs, err := NewCertStateFromConfig(config)
if err != nil {
//The errors coming out of NewCertStateFromConfig are already nicely formatted
return NewContextualError("Failed to load certificate from config", nil, err)
return nil, NewContextualError("Failed to load certificate from config", nil, err)
}
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
fw, err := NewFirewallFromConfig(cs.certificate, config)
if err != nil {
return NewContextualError("Error while loading firewall rules", nil, err)
return nil, NewContextualError("Error while loading firewall rules", nil, err)
}
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
@ -80,11 +72,11 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
tunCidr := cs.certificate.Details.Ips[0]
routes, err := parseRoutes(config, tunCidr)
if err != nil {
return NewContextualError("Could not parse tun.routes", nil, err)
return nil, NewContextualError("Could not parse tun.routes", nil, err)
}
unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
if err != nil {
return NewContextualError("Could not parse tun.unsafe_routes", nil, err)
return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
}
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
@ -92,7 +84,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
if config.GetBool("sshd.enabled", false) {
err = configSSH(ssh, config)
if err != nil {
return NewContextualError("Error while configuring the sshd", nil, err)
return nil, NewContextualError("Error while configuring the sshd", nil, err)
}
}
@ -129,7 +121,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
}
if err != nil {
return NewContextualError("Failed to get a tun/tap device", nil, err)
return nil, NewContextualError("Failed to get a tun/tap device", nil, err)
}
}
@ -140,28 +132,11 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
if !configTest {
udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1)
if err != nil {
return NewContextualError("Failed to open udp listener", nil, err)
return nil, NewContextualError("Failed to open udp listener", nil, err)
}
udpServer.reloadConfig(config)
}
sigChan := make(chan os.Signal)
killChan := make(chan CommandRequest)
if commandChan != nil {
go func() {
cmd := CommandRequest{}
for {
cmd = <-commandChan
switch cmd.Command {
case "rebind":
udpServer.Rebind()
case "exit":
killChan <- cmd
}
}
}()
}
// Set up my internal host map
var preferredRanges []*net.IPNet
rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{})
@ -170,7 +145,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
for _, rawPreferredRange := range rawPreferredRanges {
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
if err != nil {
return NewContextualError("Failed to parse preferred ranges", nil, err)
return nil, NewContextualError("Failed to parse preferred ranges", nil, err)
}
preferredRanges = append(preferredRanges, preferredRange)
}
@ -183,7 +158,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
if rawLocalRange != "" {
_, localRange, err := net.ParseCIDR(rawLocalRange)
if err != nil {
return NewContextualError("Failed to parse local_range", nil, err)
return nil, NewContextualError("Failed to parse local_range", nil, err)
}
// Check if the entry for local_range was already specified in
@ -223,7 +198,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
if port == 0 && !configTest {
uPort, err := udpServer.LocalAddr()
if err != nil {
return NewContextualError("Failed to get listening port", nil, err)
return nil, NewContextualError("Failed to get listening port", nil, err)
}
port = int(uPort.Port)
}
@ -240,10 +215,10 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
for i, host := range rawLighthouseHosts {
ip := net.ParseIP(host)
if ip == nil {
return NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
return nil, NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
}
if !tunCidr.Contains(ip) {
return NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
}
lighthouseHosts[i] = ip2int(ip)
}
@ -263,13 +238,13 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false)
if err != nil {
return NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
}
lightHouse.SetRemoteAllowList(remoteAllowList)
localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true)
if err != nil {
return NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
}
lightHouse.SetLocalAllowList(localAllowList)
@ -277,7 +252,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
if !tunCidr.Contains(vpnIp) {
return NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
}
vals, ok := v.([]interface{})
if ok {
@ -288,7 +263,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
ip := addr.IP
port, err := strconv.Atoi(parts[1])
if err != nil {
return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
}
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
}
@ -301,7 +276,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
ip := addr.IP
port, err := strconv.Atoi(parts[1])
if err != nil {
return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
}
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
}
@ -354,7 +329,10 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false),
DropMulticast: config.GetBool("tun.drop_multicast", false),
UDPBatchSize: config.GetInt("listen.batch", 64),
udpQueues: udpQueues,
tunQueues: config.GetInt("tun.routines", 1),
MessageMetrics: messageMetrics,
version: buildVersion,
}
switch ifConfig.Cipher {
@ -363,14 +341,14 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
case "chachapoly":
noiseEndianness = binary.LittleEndian
default:
return fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
}
var ifce *Interface
if !configTest {
ifce, err = NewInterface(ifConfig)
if err != nil {
return fmt.Errorf("failed to initialize interface: %s", err)
return nil, fmt.Errorf("failed to initialize interface: %s", err)
}
ifce.RegisterConfigChangeCallbacks(config)
@ -381,18 +359,17 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
err = startStats(config, configTest)
if err != nil {
return NewContextualError("Failed to start stats emitter", nil, err)
return nil, NewContextualError("Failed to start stats emitter", nil, err)
}
if configTest {
return nil
return nil, nil
}
//TODO: check if we _should_ be emitting stats
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
ifce.Run(config.GetInt("tun.routines", 1), udpQueues, buildVersion)
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
if amLighthouse && serveDns {
@ -400,47 +377,5 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
go dnsMain(hostMap, config)
}
if block {
// Just sit here and be friendly, main thread.
shutdownBlock(ifce, sigChan, killChan)
} else {
// Even though we aren't blocking we still want to shutdown gracefully
go shutdownBlock(ifce, sigChan, killChan)
}
return nil
}
func shutdownBlock(ifce *Interface, sigChan chan os.Signal, killChan chan CommandRequest) {
var cmd CommandRequest
var sig string
signal.Notify(sigChan, syscall.SIGTERM)
signal.Notify(sigChan, syscall.SIGINT)
select {
case rawSig := <-sigChan:
sig = rawSig.String()
case cmd = <-killChan:
sig = "controlling app"
}
l.WithField("signal", sig).Info("Caught signal, shutting down")
//TODO: stop tun and udp routines, the lock on hostMap effectively does that though
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
ifce.hostMap.Lock()
for _, h := range ifce.hostMap.Hosts {
if h.ConnectionState.ready {
ifce.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message")
}
}
ifce.hostMap.Unlock()
l.WithField("signal", sig).Info("Goodbye")
select {
case cmd.Callback <- nil:
default:
}
return &Control{ifce, l}, nil
}

View File

@ -31,6 +31,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
}
}
func (u *udpConn) Rebind() {
return
func (u *udpConn) Rebind() error {
return nil
}

View File

@ -33,6 +33,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
}
}
func (u *udpConn) Rebind() {
return
func (u *udpConn) Rebind() error {
return nil
}

View File

@ -65,6 +65,17 @@ func (ua *udpAddr) Equals(t *udpAddr) bool {
return ua.IP.Equal(t.IP) && ua.Port == t.Port
}
func (ua *udpAddr) Copy() udpAddr {
nu := udpAddr{net.UDPAddr{
Port: ua.Port,
Zone: ua.Zone,
IP: make(net.IP, len(ua.IP)),
}}
copy(nu.IP, ua.IP)
return nu
}
func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error {
_, err := uc.UDPConn.WriteToUDP(b, &addr.UDPAddr)
return err

View File

@ -89,8 +89,12 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
return &udpConn{sysFd: fd}, err
}
func (u *udpConn) Rebind() {
return
func (u *udpConn) Rebind() error {
return nil
}
func (ua *udpAddr) Copy() udpAddr {
return *ua
}
func (u *udpConn) SetRecvBuffer(n int) error {
@ -282,13 +286,6 @@ func (ua *udpAddr) Equals(t *udpAddr) bool {
return ua.IP == t.IP && ua.Port == t.Port
}
func (ua *udpAddr) Copy() *udpAddr {
return &udpAddr{
Port: ua.Port,
IP: ua.IP,
}
}
func (ua *udpAddr) String() string {
return fmt.Sprintf("%s:%v", int2ip(ua.IP), ua.Port)
}

View File

@ -21,6 +21,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
}
}
func (u *udpConn) Rebind() {
return
func (u *udpConn) Rebind() error {
return nil
}

130
util/assert.go Normal file
View File

@ -0,0 +1,130 @@
package util
import (
"fmt"
"reflect"
"testing"
"time"
"unsafe"
"github.com/stretchr/testify/assert"
)
// AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory
// There is currently a special case for `time.loc` (as this code traverses into unexported fields)
func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) {
v1 := reflect.ValueOf(a)
v2 := reflect.ValueOf(b)
if !assert.Equal(t, v1.Type(), v2.Type()) {
return
}
traverseDeepCopy(t, v1, v2, v1.Type().String())
}
func traverseDeepCopy(t *testing.T, v1 reflect.Value, v2 reflect.Value, name string) bool {
switch v1.Kind() {
case reflect.Array:
for i := 0; i < v1.Len(); i++ {
if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
return false
}
}
return true
case reflect.Slice:
if v1.IsNil() || v2.IsNil() {
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil %+v, %+v", name, v1, v2)
}
if !assert.Equal(t, v1.Len(), v2.Len(), "%s did not have the same length", name) {
return false
}
// A slice with cap 0
if v1.Cap() != 0 && !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same slice %v == %v", name, v1.Pointer(), v2.Pointer()) {
return false
}
v1c := v1.Cap()
v2c := v2.Cap()
if v1c > 0 && v2c > 0 && v1.Slice(0, v1c).Slice(v1c-1, v1c-1).Pointer() == v2.Slice(0, v2c).Slice(v2c-1, v2c-1).Pointer() {
return assert.Fail(t, "", "%s share some underlying memory", name)
}
for i := 0; i < v1.Len(); i++ {
if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
return false
}
}
return true
case reflect.Interface:
if v1.IsNil() || v2.IsNil() {
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
}
return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
case reflect.Ptr:
local := reflect.ValueOf(time.Local).Pointer()
if local == v1.Pointer() && local == v2.Pointer() {
return true
}
if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s points to the same memory", name) {
return false
}
return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
case reflect.Struct:
for i, n := 0, v1.NumField(); i < n; i++ {
if !traverseDeepCopy(t, v1.Field(i), v2.Field(i), name+"."+v1.Type().Field(i).Name) {
return false
}
}
return true
case reflect.Map:
if v1.IsNil() || v2.IsNil() {
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
}
if !assert.Equal(t, v1.Len(), v2.Len(), "%s are not the same length", name) {
return false
}
if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same memory", name) {
return false
}
for _, k := range v1.MapKeys() {
val1 := v1.MapIndex(k)
val2 := v2.MapIndex(k)
if !assert.True(t, val1.IsValid(), "%s is an invalid key in %s", k, name) {
return false
}
if !assert.True(t, val2.IsValid(), "%s is an invalid key in %s", k, name) {
return false
}
if !traverseDeepCopy(t, val1, val2, name+fmt.Sprintf("%s[%s]", name, k)) {
return false
}
}
return true
default:
if v1.CanInterface() && v2.CanInterface() {
return assert.Equal(t, v1.Interface(), v2.Interface(), "%s was not equal", name)
}
e1 := reflect.NewAt(v1.Type(), unsafe.Pointer(v1.UnsafeAddr())).Elem().Interface()
e2 := reflect.NewAt(v2.Type(), unsafe.Pointer(v2.UnsafeAddr())).Elem().Interface()
return assert.Equal(t, e1, e2, "%s (unexported) was not equal", name)
}
}