More like a library (#279)
This commit is contained in:
parent
6238f1550b
commit
68e3e84fdc
57
cert/cert.go
57
cert/cert.go
|
@ -468,6 +468,63 @@ func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) {
|
||||||
return json.Marshal(jc)
|
return json.Marshal(jc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//func (nc *NebulaCertificate) Copy() *NebulaCertificate {
|
||||||
|
// r, err := nc.Marshal()
|
||||||
|
// if err != nil {
|
||||||
|
// //TODO
|
||||||
|
// return nil
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// c, err := UnmarshalNebulaCertificate(r)
|
||||||
|
// return c
|
||||||
|
//}
|
||||||
|
|
||||||
|
func (nc *NebulaCertificate) Copy() *NebulaCertificate {
|
||||||
|
c := &NebulaCertificate{
|
||||||
|
Details: NebulaCertificateDetails{
|
||||||
|
Name: nc.Details.Name,
|
||||||
|
Groups: make([]string, len(nc.Details.Groups)),
|
||||||
|
Ips: make([]*net.IPNet, len(nc.Details.Ips)),
|
||||||
|
Subnets: make([]*net.IPNet, len(nc.Details.Subnets)),
|
||||||
|
NotBefore: nc.Details.NotBefore,
|
||||||
|
NotAfter: nc.Details.NotAfter,
|
||||||
|
PublicKey: make([]byte, len(nc.Details.PublicKey)),
|
||||||
|
IsCA: nc.Details.IsCA,
|
||||||
|
Issuer: nc.Details.Issuer,
|
||||||
|
InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)),
|
||||||
|
},
|
||||||
|
Signature: make([]byte, len(nc.Signature)),
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(c.Signature, nc.Signature)
|
||||||
|
copy(c.Details.Groups, nc.Details.Groups)
|
||||||
|
copy(c.Details.PublicKey, nc.Details.PublicKey)
|
||||||
|
|
||||||
|
for i, p := range nc.Details.Ips {
|
||||||
|
c.Details.Ips[i] = &net.IPNet{
|
||||||
|
IP: make(net.IP, len(p.IP)),
|
||||||
|
Mask: make(net.IPMask, len(p.Mask)),
|
||||||
|
}
|
||||||
|
copy(c.Details.Ips[i].IP, p.IP)
|
||||||
|
copy(c.Details.Ips[i].Mask, p.Mask)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, p := range nc.Details.Subnets {
|
||||||
|
c.Details.Subnets[i] = &net.IPNet{
|
||||||
|
IP: make(net.IP, len(p.IP)),
|
||||||
|
Mask: make(net.IPMask, len(p.Mask)),
|
||||||
|
}
|
||||||
|
copy(c.Details.Subnets[i].IP, p.IP)
|
||||||
|
copy(c.Details.Subnets[i].Mask, p.Mask)
|
||||||
|
}
|
||||||
|
|
||||||
|
for g := range nc.Details.InvertedGroups {
|
||||||
|
c.Details.InvertedGroups[g] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool {
|
func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool {
|
||||||
for _, net := range rootIps {
|
for _, net := range rootIps {
|
||||||
if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) {
|
if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) {
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.org/x/crypto/curve25519"
|
"golang.org/x/crypto/curve25519"
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
|
@ -487,6 +488,17 @@ func TestMarshalingNebulaCertificateConsistency(t *testing.T) {
|
||||||
assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
|
assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNebulaCertificate_Copy(t *testing.T) {
|
||||||
|
ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||||
|
assert.Nil(t, err)
|
||||||
|
cc := c.Copy()
|
||||||
|
|
||||||
|
util.AssertDeepCopyEqual(t, c, cc)
|
||||||
|
}
|
||||||
|
|
||||||
func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
|
func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
|
||||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
if before.IsZero() {
|
if before.IsZero() {
|
||||||
|
@ -498,11 +510,12 @@ func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
|
||||||
|
|
||||||
nc := &NebulaCertificate{
|
nc := &NebulaCertificate{
|
||||||
Details: NebulaCertificateDetails{
|
Details: NebulaCertificateDetails{
|
||||||
Name: "test ca",
|
Name: "test ca",
|
||||||
NotBefore: before,
|
NotBefore: time.Unix(before.Unix(), 0),
|
||||||
NotAfter: after,
|
NotAfter: time.Unix(after.Unix(), 0),
|
||||||
PublicKey: pub,
|
PublicKey: pub,
|
||||||
IsCA: true,
|
IsCA: true,
|
||||||
|
InvertedGroups: make(map[string]struct{}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -544,17 +557,17 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
|
||||||
|
|
||||||
if len(ips) == 0 {
|
if len(ips) == 0 {
|
||||||
ips = []*net.IPNet{
|
ips = []*net.IPNet{
|
||||||
{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
|
{IP: net.ParseIP("10.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
|
||||||
{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
|
{IP: net.ParseIP("10.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
|
||||||
{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
|
{IP: net.ParseIP("10.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(subnets) == 0 {
|
if len(subnets) == 0 {
|
||||||
subnets = []*net.IPNet{
|
subnets = []*net.IPNet{
|
||||||
{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
|
{IP: net.ParseIP("9.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
|
||||||
{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
|
{IP: net.ParseIP("9.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
|
||||||
{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
|
{IP: net.ParseIP("9.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -562,15 +575,16 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
|
||||||
|
|
||||||
nc := &NebulaCertificate{
|
nc := &NebulaCertificate{
|
||||||
Details: NebulaCertificateDetails{
|
Details: NebulaCertificateDetails{
|
||||||
Name: "testing",
|
Name: "testing",
|
||||||
Ips: ips,
|
Ips: ips,
|
||||||
Subnets: subnets,
|
Subnets: subnets,
|
||||||
Groups: groups,
|
Groups: groups,
|
||||||
NotBefore: before,
|
NotBefore: time.Unix(before.Unix(), 0),
|
||||||
NotAfter: after,
|
NotAfter: time.Unix(after.Unix(), 0),
|
||||||
PublicKey: pub,
|
PublicKey: pub,
|
||||||
IsCA: false,
|
IsCA: false,
|
||||||
Issuer: issuer,
|
Issuer: issuer,
|
||||||
|
InvertedGroups: make(map[string]struct{}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -55,7 +55,7 @@ func main() {
|
||||||
|
|
||||||
l := logrus.New()
|
l := logrus.New()
|
||||||
l.Out = os.Stdout
|
l.Out = os.Stdout
|
||||||
err = nebula.Main(config, *configTest, true, Build, l, nil, nil)
|
c, err := nebula.Main(config, *configTest, Build, l, nil)
|
||||||
|
|
||||||
switch v := err.(type) {
|
switch v := err.(type) {
|
||||||
case nebula.ContextualError:
|
case nebula.ContextualError:
|
||||||
|
@ -66,5 +66,10 @@ func main() {
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !*configTest {
|
||||||
|
c.Start()
|
||||||
|
c.ShutdownBlock()
|
||||||
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,21 +14,16 @@ import (
|
||||||
var logger service.Logger
|
var logger service.Logger
|
||||||
|
|
||||||
type program struct {
|
type program struct {
|
||||||
exit chan struct{}
|
|
||||||
configPath *string
|
configPath *string
|
||||||
configTest *bool
|
configTest *bool
|
||||||
build string
|
build string
|
||||||
|
control *nebula.Control
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *program) Start(s service.Service) error {
|
func (p *program) Start(s service.Service) error {
|
||||||
logger.Info("Nebula service starting.")
|
|
||||||
p.exit = make(chan struct{})
|
|
||||||
// Start should not block.
|
// Start should not block.
|
||||||
go p.run()
|
logger.Info("Nebula service starting.")
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *program) run() error {
|
|
||||||
config := nebula.NewConfig()
|
config := nebula.NewConfig()
|
||||||
err := config.Load(*p.configPath)
|
err := config.Load(*p.configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -37,17 +32,22 @@ func (p *program) run() error {
|
||||||
|
|
||||||
l := logrus.New()
|
l := logrus.New()
|
||||||
l.Out = os.Stdout
|
l.Out = os.Stdout
|
||||||
return nebula.Main(config, *p.configTest, true, Build, l, nil, nil)
|
p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.control.Start()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *program) Stop(s service.Service) error {
|
func (p *program) Stop(s service.Service) error {
|
||||||
logger.Info("Nebula service stopping.")
|
logger.Info("Nebula service stopping.")
|
||||||
close(p.exit)
|
p.control.Stop()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {
|
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {
|
||||||
|
|
||||||
if *configPath == "" {
|
if *configPath == "" {
|
||||||
ex, err := os.Executable()
|
ex, err := os.Executable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -49,7 +49,7 @@ func main() {
|
||||||
|
|
||||||
l := logrus.New()
|
l := logrus.New()
|
||||||
l.Out = os.Stdout
|
l.Out = os.Stdout
|
||||||
err = nebula.Main(config, *configTest, true, Build, l, nil, nil)
|
c, err := nebula.Main(config, *configTest, Build, l, nil)
|
||||||
|
|
||||||
switch v := err.(type) {
|
switch v := err.(type) {
|
||||||
case nebula.ContextualError:
|
case nebula.ContextualError:
|
||||||
|
@ -60,5 +60,10 @@ func main() {
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !*configTest {
|
||||||
|
c.Start()
|
||||||
|
c.ShutdownBlock()
|
||||||
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,169 @@
|
||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
||||||
|
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
||||||
|
|
||||||
|
type Control struct {
|
||||||
|
f *Interface
|
||||||
|
l *logrus.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
type ControlHostInfo struct {
|
||||||
|
VpnIP net.IP `json:"vpnIp"`
|
||||||
|
LocalIndex uint32 `json:"localIndex"`
|
||||||
|
RemoteIndex uint32 `json:"remoteIndex"`
|
||||||
|
RemoteAddrs []udpAddr `json:"remoteAddrs"`
|
||||||
|
CachedPackets int `json:"cachedPackets"`
|
||||||
|
Cert *cert.NebulaCertificate `json:"cert"`
|
||||||
|
MessageCounter uint64 `json:"messageCounter"`
|
||||||
|
CurrentRemote udpAddr `json:"currentRemote"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
|
||||||
|
func (c *Control) Start() {
|
||||||
|
c.f.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop signals nebula to shutdown, returns after the shutdown is complete
|
||||||
|
func (c *Control) Stop() {
|
||||||
|
//TODO: stop tun and udp routines, the lock on hostMap effectively does that though
|
||||||
|
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
|
||||||
|
c.f.hostMap.Lock()
|
||||||
|
for _, h := range c.f.hostMap.Hosts {
|
||||||
|
if h.ConnectionState.ready {
|
||||||
|
c.f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
||||||
|
c.l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
|
||||||
|
Debug("Sending close tunnel message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.f.hostMap.Unlock()
|
||||||
|
c.l.Info("Goodbye")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
||||||
|
func (c *Control) ShutdownBlock() {
|
||||||
|
sigChan := make(chan os.Signal)
|
||||||
|
signal.Notify(sigChan, syscall.SIGTERM)
|
||||||
|
signal.Notify(sigChan, syscall.SIGINT)
|
||||||
|
|
||||||
|
rawSig := <-sigChan
|
||||||
|
sig := rawSig.String()
|
||||||
|
c.l.WithField("signal", sig).Info("Caught signal, shutting down")
|
||||||
|
c.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RebindUDPServer asks the UDP listener to rebind it's listener. Mainly used on mobile clients when interfaces change
|
||||||
|
func (c *Control) RebindUDPServer() {
|
||||||
|
_ = c.f.outside.Rebind()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListHostmap returns details about the actual or pending (handshaking) hostmap
|
||||||
|
func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
|
||||||
|
var hm *HostMap
|
||||||
|
if pendingMap {
|
||||||
|
hm = c.f.handshakeManager.pendingHostMap
|
||||||
|
} else {
|
||||||
|
hm = c.f.hostMap
|
||||||
|
}
|
||||||
|
|
||||||
|
hm.RLock()
|
||||||
|
hosts := make([]ControlHostInfo, len(hm.Hosts))
|
||||||
|
i := 0
|
||||||
|
for _, v := range hm.Hosts {
|
||||||
|
hosts[i] = copyHostInfo(v)
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
hm.RUnlock()
|
||||||
|
|
||||||
|
return hosts
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found
|
||||||
|
func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInfo {
|
||||||
|
var hm *HostMap
|
||||||
|
if pending {
|
||||||
|
hm = c.f.handshakeManager.pendingHostMap
|
||||||
|
} else {
|
||||||
|
hm = c.f.hostMap
|
||||||
|
}
|
||||||
|
|
||||||
|
h, err := hm.QueryVpnIP(vpnIP)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := copyHostInfo(h)
|
||||||
|
return &ch
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRemoteForTunnel forces a tunnel to use a specific remote
|
||||||
|
func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInfo {
|
||||||
|
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hostInfo.SetRemote(addr.Copy())
|
||||||
|
ch := copyHostInfo(hostInfo)
|
||||||
|
return &ch
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
|
||||||
|
func (c *Control) CloseTunnel(vpnIP uint32, localOnly bool) bool {
|
||||||
|
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !localOnly {
|
||||||
|
c.f.send(
|
||||||
|
closeTunnel,
|
||||||
|
0,
|
||||||
|
hostInfo.ConnectionState,
|
||||||
|
hostInfo,
|
||||||
|
hostInfo.remote,
|
||||||
|
[]byte{},
|
||||||
|
make([]byte, 12, 12),
|
||||||
|
make([]byte, mtu),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.f.closeTunnel(hostInfo)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyHostInfo(h *HostInfo) ControlHostInfo {
|
||||||
|
addrs := h.RemoteUDPAddrs()
|
||||||
|
chi := ControlHostInfo{
|
||||||
|
VpnIP: int2ip(h.hostId),
|
||||||
|
LocalIndex: h.localIndexId,
|
||||||
|
RemoteIndex: h.remoteIndexId,
|
||||||
|
RemoteAddrs: make([]udpAddr, len(addrs), len(addrs)),
|
||||||
|
CachedPackets: len(h.packetStore),
|
||||||
|
MessageCounter: *h.ConnectionState.messageCounter,
|
||||||
|
}
|
||||||
|
|
||||||
|
if c := h.GetCert(); c != nil {
|
||||||
|
chi.Cert = c.Copy()
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.remote != nil {
|
||||||
|
chi.CurrentRemote = *h.remote
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, addr := range addrs {
|
||||||
|
chi.RemoteAddrs[i] = addr.Copy()
|
||||||
|
}
|
||||||
|
|
||||||
|
return chi
|
||||||
|
}
|
|
@ -0,0 +1,111 @@
|
||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/slackhq/nebula/util"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestControl_GetHostInfoByVpnIP(t *testing.T) {
|
||||||
|
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
|
||||||
|
// To properly ensure we are not exposing core memory to the caller
|
||||||
|
hm := NewHostMap("test", &net.IPNet{}, make([]*net.IPNet, 0))
|
||||||
|
remote1 := NewUDPAddr(100, 4444)
|
||||||
|
remote2 := NewUDPAddr(101, 4444)
|
||||||
|
ipNet := net.IPNet{
|
||||||
|
IP: net.IPv4(1, 2, 3, 4),
|
||||||
|
Mask: net.IPMask{255, 255, 255, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
ipNet2 := net.IPNet{
|
||||||
|
IP: net.IPv4(1, 2, 3, 5),
|
||||||
|
Mask: net.IPMask{255, 255, 255, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
crt := &cert.NebulaCertificate{
|
||||||
|
Details: cert.NebulaCertificateDetails{
|
||||||
|
Name: "test",
|
||||||
|
Ips: []*net.IPNet{&ipNet},
|
||||||
|
Subnets: []*net.IPNet{},
|
||||||
|
Groups: []string{"default-group"},
|
||||||
|
NotBefore: time.Unix(1, 0),
|
||||||
|
NotAfter: time.Unix(2, 0),
|
||||||
|
PublicKey: []byte{5, 6, 7, 8},
|
||||||
|
IsCA: false,
|
||||||
|
Issuer: "the-issuer",
|
||||||
|
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||||
|
},
|
||||||
|
Signature: []byte{1, 2, 1, 2, 1, 3},
|
||||||
|
}
|
||||||
|
counter := uint64(0)
|
||||||
|
|
||||||
|
remotes := []*HostInfoDest{NewHostInfoDest(remote1), NewHostInfoDest(remote2)}
|
||||||
|
hm.Add(ip2int(ipNet.IP), &HostInfo{
|
||||||
|
remote: remote1,
|
||||||
|
Remotes: remotes,
|
||||||
|
ConnectionState: &ConnectionState{
|
||||||
|
peerCert: crt,
|
||||||
|
messageCounter: &counter,
|
||||||
|
},
|
||||||
|
remoteIndexId: 200,
|
||||||
|
localIndexId: 201,
|
||||||
|
hostId: ip2int(ipNet.IP),
|
||||||
|
})
|
||||||
|
|
||||||
|
hm.Add(ip2int(ipNet2.IP), &HostInfo{
|
||||||
|
remote: remote1,
|
||||||
|
Remotes: remotes,
|
||||||
|
ConnectionState: &ConnectionState{
|
||||||
|
peerCert: nil,
|
||||||
|
messageCounter: &counter,
|
||||||
|
},
|
||||||
|
remoteIndexId: 200,
|
||||||
|
localIndexId: 201,
|
||||||
|
hostId: ip2int(ipNet2.IP),
|
||||||
|
})
|
||||||
|
|
||||||
|
c := Control{
|
||||||
|
f: &Interface{
|
||||||
|
hostMap: hm,
|
||||||
|
},
|
||||||
|
l: logrus.New(),
|
||||||
|
}
|
||||||
|
|
||||||
|
thi := c.GetHostInfoByVpnIP(ip2int(ipNet.IP), false)
|
||||||
|
|
||||||
|
expectedInfo := ControlHostInfo{
|
||||||
|
VpnIP: net.IPv4(1, 2, 3, 4).To4(),
|
||||||
|
LocalIndex: 201,
|
||||||
|
RemoteIndex: 200,
|
||||||
|
RemoteAddrs: []udpAddr{*remote1, *remote2},
|
||||||
|
CachedPackets: 0,
|
||||||
|
Cert: crt.Copy(),
|
||||||
|
MessageCounter: 0,
|
||||||
|
CurrentRemote: *NewUDPAddr(100, 4444),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure we don't have any unexpected fields
|
||||||
|
assertFields(t, []string{"VpnIP", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
|
||||||
|
util.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
||||||
|
|
||||||
|
// Make sure we don't panic if the host info doesn't have a cert yet
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
thi = c.GetHostInfoByVpnIP(ip2int(ipNet2.IP), false)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
|
||||||
|
val := reflect.ValueOf(actualStruct).Elem()
|
||||||
|
fields := make([]string, val.NumField())
|
||||||
|
for i := 0; i < val.NumField(); i++ {
|
||||||
|
fields[i] = val.Type().Field(i).Name
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, expected, fields)
|
||||||
|
}
|
10
firewall.go
10
firewall.go
|
@ -221,11 +221,17 @@ func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, er
|
||||||
|
|
||||||
// AddRule properly creates the in memory rule structure for a firewall table.
|
// AddRule properly creates the in memory rule structure for a firewall table.
|
||||||
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
|
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
|
||||||
|
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
|
||||||
|
// https://github.com/golang/go/issues/14131
|
||||||
|
sIp := ""
|
||||||
|
if ip != nil {
|
||||||
|
sIp = ip.String()
|
||||||
|
}
|
||||||
|
|
||||||
// We need this rule string because we generate a hash. Removing this will break firewall reload.
|
// We need this rule string because we generate a hash. Removing this will break firewall reload.
|
||||||
ruleString := fmt.Sprintf(
|
ruleString := fmt.Sprintf(
|
||||||
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s",
|
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s",
|
||||||
incoming, proto, startPort, endPort, groups, host, ip, caName, caSha,
|
incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha,
|
||||||
)
|
)
|
||||||
f.rules += ruleString + "\n"
|
f.rules += ruleString + "\n"
|
||||||
|
|
||||||
|
@ -233,7 +239,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
||||||
if !incoming {
|
if !incoming {
|
||||||
direction = "outgoing"
|
direction = "outgoing"
|
||||||
}
|
}
|
||||||
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": ip, "caName": caName, "caSha": caSha}).
|
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
|
||||||
Info("Firewall rule added")
|
Info("Firewall rule added")
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -22,7 +22,7 @@ require (
|
||||||
github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563
|
github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563
|
||||||
github.com/sirupsen/logrus v1.4.2
|
github.com/sirupsen/logrus v1.4.2
|
||||||
github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b
|
github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b
|
||||||
github.com/stretchr/testify v1.4.0
|
github.com/stretchr/testify v1.6.1
|
||||||
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a
|
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a
|
||||||
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
|
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
|
||||||
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975
|
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975
|
||||||
|
|
8
go.sum
8
go.sum
|
@ -103,8 +103,8 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
|
||||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
|
||||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a h1:Bt1IVPhiCDMqwGrc2nnbIN4QKvJGx6SK2NzWBmW00ao=
|
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a h1:Bt1IVPhiCDMqwGrc2nnbIN4QKvJGx6SK2NzWBmW00ao=
|
||||||
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
|
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
|
||||||
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
|
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
|
||||||
|
@ -112,8 +112,6 @@ github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17
|
||||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY=
|
golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY=
|
||||||
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 h1:ULYEB3JvPRE/IfO+9uO7vKV/xzVTO7XPAwm8xbf4w2g=
|
|
||||||
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
|
||||||
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975 h1:/Tl7pH94bvbAAHBdZJT947M/+gp0+CqQXDtMRC0fseo=
|
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975 h1:/Tl7pH94bvbAAHBdZJT947M/+gp0+CqQXDtMRC0fseo=
|
||||||
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
|
@ -154,3 +152,5 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
|
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
|
||||||
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
|
17
interface.go
17
interface.go
|
@ -35,7 +35,10 @@ type InterfaceConfig struct {
|
||||||
DropLocalBroadcast bool
|
DropLocalBroadcast bool
|
||||||
DropMulticast bool
|
DropMulticast bool
|
||||||
UDPBatchSize int
|
UDPBatchSize int
|
||||||
|
udpQueues int
|
||||||
|
tunQueues int
|
||||||
MessageMetrics *MessageMetrics
|
MessageMetrics *MessageMetrics
|
||||||
|
version string
|
||||||
}
|
}
|
||||||
|
|
||||||
type Interface struct {
|
type Interface struct {
|
||||||
|
@ -54,6 +57,8 @@ type Interface struct {
|
||||||
dropLocalBroadcast bool
|
dropLocalBroadcast bool
|
||||||
dropMulticast bool
|
dropMulticast bool
|
||||||
udpBatchSize int
|
udpBatchSize int
|
||||||
|
udpQueues int
|
||||||
|
tunQueues int
|
||||||
version string
|
version string
|
||||||
|
|
||||||
metricHandshakes metrics.Histogram
|
metricHandshakes metrics.Histogram
|
||||||
|
@ -89,6 +94,9 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
||||||
dropLocalBroadcast: c.DropLocalBroadcast,
|
dropLocalBroadcast: c.DropLocalBroadcast,
|
||||||
dropMulticast: c.DropMulticast,
|
dropMulticast: c.DropMulticast,
|
||||||
udpBatchSize: c.UDPBatchSize,
|
udpBatchSize: c.UDPBatchSize,
|
||||||
|
udpQueues: c.udpQueues,
|
||||||
|
tunQueues: c.tunQueues,
|
||||||
|
version: c.version,
|
||||||
|
|
||||||
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||||
messageMetrics: c.MessageMetrics,
|
messageMetrics: c.MessageMetrics,
|
||||||
|
@ -99,29 +107,28 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
||||||
return ifce, nil
|
return ifce, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) Run(tunRoutines, udpRoutines int, buildVersion string) {
|
func (f *Interface) run() {
|
||||||
// actually turn on tun dev
|
// actually turn on tun dev
|
||||||
if err := f.inside.Activate(); err != nil {
|
if err := f.inside.Activate(); err != nil {
|
||||||
l.Fatal(err)
|
l.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.version = buildVersion
|
|
||||||
addr, err := f.outside.LocalAddr()
|
addr, err := f.outside.LocalAddr()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("Failed to get udp listen address")
|
l.WithError(err).Error("Failed to get udp listen address")
|
||||||
}
|
}
|
||||||
|
|
||||||
l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
|
l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
|
||||||
WithField("build", buildVersion).WithField("udpAddr", addr).
|
WithField("build", f.version).WithField("udpAddr", addr).
|
||||||
Info("Nebula interface is active")
|
Info("Nebula interface is active")
|
||||||
|
|
||||||
// Launch n queues to read packets from udp
|
// Launch n queues to read packets from udp
|
||||||
for i := 0; i < udpRoutines; i++ {
|
for i := 0; i < f.udpQueues; i++ {
|
||||||
go f.listenOut(i)
|
go f.listenOut(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launch n queues to read packets from tun dev
|
// Launch n queues to read packets from tun dev
|
||||||
for i := 0; i < tunRoutines; i++ {
|
for i := 0; i < f.tunQueues; i++ {
|
||||||
go f.listenIn(i)
|
go f.listenIn(i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -15,10 +17,16 @@ func NewContextualError(msg string, fields map[string]interface{}, realError err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ce ContextualError) Error() string {
|
func (ce ContextualError) Error() string {
|
||||||
|
if ce.RealError == nil {
|
||||||
|
return ce.Context
|
||||||
|
}
|
||||||
return ce.RealError.Error()
|
return ce.RealError.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ce ContextualError) Unwrap() error {
|
func (ce ContextualError) Unwrap() error {
|
||||||
|
if ce.RealError == nil {
|
||||||
|
return errors.New(ce.Context)
|
||||||
|
}
|
||||||
return ce.RealError
|
return ce.RealError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
123
main.go
123
main.go
|
@ -4,11 +4,8 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
@ -21,12 +18,7 @@ var l = logrus.New()
|
||||||
|
|
||||||
type m map[string]interface{}
|
type m map[string]interface{}
|
||||||
|
|
||||||
type CommandRequest struct {
|
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) {
|
||||||
Command string
|
|
||||||
Callback chan error
|
|
||||||
}
|
|
||||||
|
|
||||||
func Main(config *Config, configTest bool, block bool, buildVersion string, logger *logrus.Logger, tunFd *int, commandChan <-chan CommandRequest) error {
|
|
||||||
l = logger
|
l = logger
|
||||||
l.Formatter = &logrus.TextFormatter{
|
l.Formatter = &logrus.TextFormatter{
|
||||||
FullTimestamp: true,
|
FullTimestamp: true,
|
||||||
|
@ -36,7 +28,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
||||||
if configTest {
|
if configTest {
|
||||||
b, err := yaml.Marshal(config.Settings)
|
b, err := yaml.Marshal(config.Settings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Print the final config
|
// Print the final config
|
||||||
|
@ -45,7 +37,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
||||||
|
|
||||||
err := configLogger(config)
|
err := configLogger(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewContextualError("Failed to configure the logger", nil, err)
|
return nil, NewContextualError("Failed to configure the logger", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config.RegisterReloadCallback(func(c *Config) {
|
config.RegisterReloadCallback(func(c *Config) {
|
||||||
|
@ -59,20 +51,20 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
||||||
trustedCAs, err = loadCAFromConfig(config)
|
trustedCAs, err = loadCAFromConfig(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//The errors coming out of loadCA are already nicely formatted
|
//The errors coming out of loadCA are already nicely formatted
|
||||||
return NewContextualError("Failed to load ca from config", nil, err)
|
return nil, NewContextualError("Failed to load ca from config", nil, err)
|
||||||
}
|
}
|
||||||
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
|
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
|
||||||
|
|
||||||
cs, err := NewCertStateFromConfig(config)
|
cs, err := NewCertStateFromConfig(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//The errors coming out of NewCertStateFromConfig are already nicely formatted
|
//The errors coming out of NewCertStateFromConfig are already nicely formatted
|
||||||
return NewContextualError("Failed to load certificate from config", nil, err)
|
return nil, NewContextualError("Failed to load certificate from config", nil, err)
|
||||||
}
|
}
|
||||||
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
|
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
|
||||||
|
|
||||||
fw, err := NewFirewallFromConfig(cs.certificate, config)
|
fw, err := NewFirewallFromConfig(cs.certificate, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewContextualError("Error while loading firewall rules", nil, err)
|
return nil, NewContextualError("Error while loading firewall rules", nil, err)
|
||||||
}
|
}
|
||||||
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
|
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
|
||||||
|
|
||||||
|
@ -80,11 +72,11 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
||||||
tunCidr := cs.certificate.Details.Ips[0]
|
tunCidr := cs.certificate.Details.Ips[0]
|
||||||
routes, err := parseRoutes(config, tunCidr)
|
routes, err := parseRoutes(config, tunCidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewContextualError("Could not parse tun.routes", nil, err)
|
return nil, NewContextualError("Could not parse tun.routes", nil, err)
|
||||||
}
|
}
|
||||||
unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
|
unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
||||||
|
@ -92,7 +84,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
||||||
if config.GetBool("sshd.enabled", false) {
|
if config.GetBool("sshd.enabled", false) {
|
||||||
err = configSSH(ssh, config)
|
err = configSSH(ssh, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewContextualError("Error while configuring the sshd", nil, err)
|
return nil, NewContextualError("Error while configuring the sshd", nil, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -129,7 +121,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewContextualError("Failed to get a tun/tap device", nil, err)
|
return nil, NewContextualError("Failed to get a tun/tap device", nil, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -140,28 +132,11 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
||||||
if !configTest {
|
if !configTest {
|
||||||
udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1)
|
udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewContextualError("Failed to open udp listener", nil, err)
|
return nil, NewContextualError("Failed to open udp listener", nil, err)
|
||||||
}
|
}
|
||||||
udpServer.reloadConfig(config)
|
udpServer.reloadConfig(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
sigChan := make(chan os.Signal)
|
|
||||||
killChan := make(chan CommandRequest)
|
|
||||||
if commandChan != nil {
|
|
||||||
go func() {
|
|
||||||
cmd := CommandRequest{}
|
|
||||||
for {
|
|
||||||
cmd = <-commandChan
|
|
||||||
switch cmd.Command {
|
|
||||||
case "rebind":
|
|
||||||
udpServer.Rebind()
|
|
||||||
case "exit":
|
|
||||||
killChan <- cmd
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up my internal host map
|
// Set up my internal host map
|
||||||
var preferredRanges []*net.IPNet
|
var preferredRanges []*net.IPNet
|
||||||
rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{})
|
rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{})
|
||||||
|
@ -170,7 +145,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
||||||
for _, rawPreferredRange := range rawPreferredRanges {
|
for _, rawPreferredRange := range rawPreferredRanges {
|
||||||
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
|
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewContextualError("Failed to parse preferred ranges", nil, err)
|
return nil, NewContextualError("Failed to parse preferred ranges", nil, err)
|
||||||
}
|
}
|
||||||
preferredRanges = append(preferredRanges, preferredRange)
|
preferredRanges = append(preferredRanges, preferredRange)
|
||||||
}
|
}
|
||||||
|
@ -183,7 +158,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
||||||
if rawLocalRange != "" {
|
if rawLocalRange != "" {
|
||||||
_, localRange, err := net.ParseCIDR(rawLocalRange)
|
_, localRange, err := net.ParseCIDR(rawLocalRange)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewContextualError("Failed to parse local_range", nil, err)
|
return nil, NewContextualError("Failed to parse local_range", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the entry for local_range was already specified in
|
// Check if the entry for local_range was already specified in
|
||||||
|
@ -223,7 +198,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
||||||
if port == 0 && !configTest {
|
if port == 0 && !configTest {
|
||||||
uPort, err := udpServer.LocalAddr()
|
uPort, err := udpServer.LocalAddr()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewContextualError("Failed to get listening port", nil, err)
|
return nil, NewContextualError("Failed to get listening port", nil, err)
|
||||||
}
|
}
|
||||||
port = int(uPort.Port)
|
port = int(uPort.Port)
|
||||||
}
|
}
|
||||||
|
@ -240,10 +215,10 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
||||||
for i, host := range rawLighthouseHosts {
|
for i, host := range rawLighthouseHosts {
|
||||||
ip := net.ParseIP(host)
|
ip := net.ParseIP(host)
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
return NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
|
return nil, NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
|
||||||
}
|
}
|
||||||
if !tunCidr.Contains(ip) {
|
if !tunCidr.Contains(ip) {
|
||||||
return NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
|
return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
|
||||||
}
|
}
|
||||||
lighthouseHosts[i] = ip2int(ip)
|
lighthouseHosts[i] = ip2int(ip)
|
||||||
}
|
}
|
||||||
|
@ -263,13 +238,13 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
||||||
|
|
||||||
remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false)
|
remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
|
return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
|
||||||
}
|
}
|
||||||
lightHouse.SetRemoteAllowList(remoteAllowList)
|
lightHouse.SetRemoteAllowList(remoteAllowList)
|
||||||
|
|
||||||
localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true)
|
localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
|
return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
|
||||||
}
|
}
|
||||||
lightHouse.SetLocalAllowList(localAllowList)
|
lightHouse.SetLocalAllowList(localAllowList)
|
||||||
|
|
||||||
|
@ -277,7 +252,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
||||||
for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
|
for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
|
||||||
vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
|
vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
|
||||||
if !tunCidr.Contains(vpnIp) {
|
if !tunCidr.Contains(vpnIp) {
|
||||||
return NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
|
return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
|
||||||
}
|
}
|
||||||
vals, ok := v.([]interface{})
|
vals, ok := v.([]interface{})
|
||||||
if ok {
|
if ok {
|
||||||
|
@ -288,7 +263,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
||||||
ip := addr.IP
|
ip := addr.IP
|
||||||
port, err := strconv.Atoi(parts[1])
|
port, err := strconv.Atoi(parts[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||||
}
|
}
|
||||||
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
|
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
|
||||||
}
|
}
|
||||||
|
@ -301,7 +276,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
||||||
ip := addr.IP
|
ip := addr.IP
|
||||||
port, err := strconv.Atoi(parts[1])
|
port, err := strconv.Atoi(parts[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||||
}
|
}
|
||||||
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
|
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
|
||||||
}
|
}
|
||||||
|
@ -354,7 +329,10 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
||||||
DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false),
|
DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false),
|
||||||
DropMulticast: config.GetBool("tun.drop_multicast", false),
|
DropMulticast: config.GetBool("tun.drop_multicast", false),
|
||||||
UDPBatchSize: config.GetInt("listen.batch", 64),
|
UDPBatchSize: config.GetInt("listen.batch", 64),
|
||||||
|
udpQueues: udpQueues,
|
||||||
|
tunQueues: config.GetInt("tun.routines", 1),
|
||||||
MessageMetrics: messageMetrics,
|
MessageMetrics: messageMetrics,
|
||||||
|
version: buildVersion,
|
||||||
}
|
}
|
||||||
|
|
||||||
switch ifConfig.Cipher {
|
switch ifConfig.Cipher {
|
||||||
|
@ -363,14 +341,14 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
||||||
case "chachapoly":
|
case "chachapoly":
|
||||||
noiseEndianness = binary.LittleEndian
|
noiseEndianness = binary.LittleEndian
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
|
return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
|
||||||
}
|
}
|
||||||
|
|
||||||
var ifce *Interface
|
var ifce *Interface
|
||||||
if !configTest {
|
if !configTest {
|
||||||
ifce, err = NewInterface(ifConfig)
|
ifce, err = NewInterface(ifConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize interface: %s", err)
|
return nil, fmt.Errorf("failed to initialize interface: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ifce.RegisterConfigChangeCallbacks(config)
|
ifce.RegisterConfigChangeCallbacks(config)
|
||||||
|
@ -381,18 +359,17 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
||||||
|
|
||||||
err = startStats(config, configTest)
|
err = startStats(config, configTest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewContextualError("Failed to start stats emitter", nil, err)
|
return nil, NewContextualError("Failed to start stats emitter", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if configTest {
|
if configTest {
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: check if we _should_ be emitting stats
|
//TODO: check if we _should_ be emitting stats
|
||||||
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
|
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
|
||||||
|
|
||||||
attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
|
attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
|
||||||
ifce.Run(config.GetInt("tun.routines", 1), udpQueues, buildVersion)
|
|
||||||
|
|
||||||
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
|
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
|
||||||
if amLighthouse && serveDns {
|
if amLighthouse && serveDns {
|
||||||
|
@ -400,47 +377,5 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
||||||
go dnsMain(hostMap, config)
|
go dnsMain(hostMap, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
if block {
|
return &Control{ifce, l}, nil
|
||||||
// Just sit here and be friendly, main thread.
|
|
||||||
shutdownBlock(ifce, sigChan, killChan)
|
|
||||||
} else {
|
|
||||||
// Even though we aren't blocking we still want to shutdown gracefully
|
|
||||||
go shutdownBlock(ifce, sigChan, killChan)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func shutdownBlock(ifce *Interface, sigChan chan os.Signal, killChan chan CommandRequest) {
|
|
||||||
var cmd CommandRequest
|
|
||||||
var sig string
|
|
||||||
|
|
||||||
signal.Notify(sigChan, syscall.SIGTERM)
|
|
||||||
signal.Notify(sigChan, syscall.SIGINT)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case rawSig := <-sigChan:
|
|
||||||
sig = rawSig.String()
|
|
||||||
case cmd = <-killChan:
|
|
||||||
sig = "controlling app"
|
|
||||||
}
|
|
||||||
|
|
||||||
l.WithField("signal", sig).Info("Caught signal, shutting down")
|
|
||||||
|
|
||||||
//TODO: stop tun and udp routines, the lock on hostMap effectively does that though
|
|
||||||
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
|
|
||||||
ifce.hostMap.Lock()
|
|
||||||
for _, h := range ifce.hostMap.Hosts {
|
|
||||||
if h.ConnectionState.ready {
|
|
||||||
ifce.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
|
||||||
l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
|
|
||||||
Debug("Sending close tunnel message")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ifce.hostMap.Unlock()
|
|
||||||
|
|
||||||
l.WithField("signal", sig).Info("Goodbye")
|
|
||||||
select {
|
|
||||||
case cmd.Callback <- nil:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,6 +31,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) Rebind() {
|
func (u *udpConn) Rebind() error {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,6 +33,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) Rebind() {
|
func (u *udpConn) Rebind() error {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,6 +65,17 @@ func (ua *udpAddr) Equals(t *udpAddr) bool {
|
||||||
return ua.IP.Equal(t.IP) && ua.Port == t.Port
|
return ua.IP.Equal(t.IP) && ua.Port == t.Port
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ua *udpAddr) Copy() udpAddr {
|
||||||
|
nu := udpAddr{net.UDPAddr{
|
||||||
|
Port: ua.Port,
|
||||||
|
Zone: ua.Zone,
|
||||||
|
IP: make(net.IP, len(ua.IP)),
|
||||||
|
}}
|
||||||
|
|
||||||
|
copy(nu.IP, ua.IP)
|
||||||
|
return nu
|
||||||
|
}
|
||||||
|
|
||||||
func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error {
|
func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error {
|
||||||
_, err := uc.UDPConn.WriteToUDP(b, &addr.UDPAddr)
|
_, err := uc.UDPConn.WriteToUDP(b, &addr.UDPAddr)
|
||||||
return err
|
return err
|
||||||
|
|
15
udp_linux.go
15
udp_linux.go
|
@ -89,8 +89,12 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
|
||||||
return &udpConn{sysFd: fd}, err
|
return &udpConn{sysFd: fd}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) Rebind() {
|
func (u *udpConn) Rebind() error {
|
||||||
return
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ua *udpAddr) Copy() udpAddr {
|
||||||
|
return *ua
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) SetRecvBuffer(n int) error {
|
func (u *udpConn) SetRecvBuffer(n int) error {
|
||||||
|
@ -282,13 +286,6 @@ func (ua *udpAddr) Equals(t *udpAddr) bool {
|
||||||
return ua.IP == t.IP && ua.Port == t.Port
|
return ua.IP == t.IP && ua.Port == t.Port
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ua *udpAddr) Copy() *udpAddr {
|
|
||||||
return &udpAddr{
|
|
||||||
Port: ua.Port,
|
|
||||||
IP: ua.IP,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ua *udpAddr) String() string {
|
func (ua *udpAddr) String() string {
|
||||||
return fmt.Sprintf("%s:%v", int2ip(ua.IP), ua.Port)
|
return fmt.Sprintf("%s:%v", int2ip(ua.IP), ua.Port)
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) Rebind() {
|
func (u *udpConn) Rebind() error {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,130 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory
|
||||||
|
// There is currently a special case for `time.loc` (as this code traverses into unexported fields)
|
||||||
|
func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) {
|
||||||
|
v1 := reflect.ValueOf(a)
|
||||||
|
v2 := reflect.ValueOf(b)
|
||||||
|
|
||||||
|
if !assert.Equal(t, v1.Type(), v2.Type()) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
traverseDeepCopy(t, v1, v2, v1.Type().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func traverseDeepCopy(t *testing.T, v1 reflect.Value, v2 reflect.Value, name string) bool {
|
||||||
|
switch v1.Kind() {
|
||||||
|
case reflect.Array:
|
||||||
|
for i := 0; i < v1.Len(); i++ {
|
||||||
|
if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
|
||||||
|
case reflect.Slice:
|
||||||
|
if v1.IsNil() || v2.IsNil() {
|
||||||
|
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil %+v, %+v", name, v1, v2)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !assert.Equal(t, v1.Len(), v2.Len(), "%s did not have the same length", name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// A slice with cap 0
|
||||||
|
if v1.Cap() != 0 && !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same slice %v == %v", name, v1.Pointer(), v2.Pointer()) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
v1c := v1.Cap()
|
||||||
|
v2c := v2.Cap()
|
||||||
|
if v1c > 0 && v2c > 0 && v1.Slice(0, v1c).Slice(v1c-1, v1c-1).Pointer() == v2.Slice(0, v2c).Slice(v2c-1, v2c-1).Pointer() {
|
||||||
|
return assert.Fail(t, "", "%s share some underlying memory", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < v1.Len(); i++ {
|
||||||
|
if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
|
||||||
|
case reflect.Interface:
|
||||||
|
if v1.IsNil() || v2.IsNil() {
|
||||||
|
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
|
||||||
|
}
|
||||||
|
return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
|
||||||
|
|
||||||
|
case reflect.Ptr:
|
||||||
|
local := reflect.ValueOf(time.Local).Pointer()
|
||||||
|
if local == v1.Pointer() && local == v2.Pointer() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s points to the same memory", name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
|
||||||
|
|
||||||
|
case reflect.Struct:
|
||||||
|
for i, n := 0, v1.NumField(); i < n; i++ {
|
||||||
|
if !traverseDeepCopy(t, v1.Field(i), v2.Field(i), name+"."+v1.Type().Field(i).Name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
|
||||||
|
case reflect.Map:
|
||||||
|
if v1.IsNil() || v2.IsNil() {
|
||||||
|
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !assert.Equal(t, v1.Len(), v2.Len(), "%s are not the same length", name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same memory", name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, k := range v1.MapKeys() {
|
||||||
|
val1 := v1.MapIndex(k)
|
||||||
|
val2 := v2.MapIndex(k)
|
||||||
|
if !assert.True(t, val1.IsValid(), "%s is an invalid key in %s", k, name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !assert.True(t, val2.IsValid(), "%s is an invalid key in %s", k, name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !traverseDeepCopy(t, val1, val2, name+fmt.Sprintf("%s[%s]", name, k)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
|
||||||
|
default:
|
||||||
|
if v1.CanInterface() && v2.CanInterface() {
|
||||||
|
return assert.Equal(t, v1.Interface(), v2.Interface(), "%s was not equal", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
e1 := reflect.NewAt(v1.Type(), unsafe.Pointer(v1.UnsafeAddr())).Elem().Interface()
|
||||||
|
e2 := reflect.NewAt(v2.Type(), unsafe.Pointer(v2.UnsafeAddr())).Elem().Interface()
|
||||||
|
|
||||||
|
return assert.Equal(t, e1, e2, "%s (unexported) was not equal", name)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue