Move util to test, contextual errors to util (#575)
This commit is contained in:
		| @@ -7,12 +7,12 @@ import ( | ||||
|  | ||||
| 	"github.com/slackhq/nebula/cidr" | ||||
| 	"github.com/slackhq/nebula/config" | ||||
| 	"github.com/slackhq/nebula/util" | ||||
| 	"github.com/slackhq/nebula/test" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
|  | ||||
| func TestNewAllowListFromConfig(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	c := config.NewC(l) | ||||
| 	c.Settings["allowlist"] = map[interface{}]interface{}{ | ||||
| 		"192.168.0.0": true, | ||||
|   | ||||
							
								
								
									
										10
									
								
								bits_test.go
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								bits_test.go
									
									
									
									
									
								
							| @@ -3,12 +3,12 @@ package nebula | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/slackhq/nebula/util" | ||||
| 	"github.com/slackhq/nebula/test" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
|  | ||||
| func TestBits(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	b := NewBits(10) | ||||
|  | ||||
| 	// make sure it is the right size | ||||
| @@ -76,7 +76,7 @@ func TestBits(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestBitsDupeCounter(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	b := NewBits(10) | ||||
| 	b.lostCounter.Clear() | ||||
| 	b.dupeCounter.Clear() | ||||
| @@ -101,7 +101,7 @@ func TestBitsDupeCounter(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestBitsOutOfWindowCounter(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	b := NewBits(10) | ||||
| 	b.lostCounter.Clear() | ||||
| 	b.dupeCounter.Clear() | ||||
| @@ -131,7 +131,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestBitsLostCounter(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	b := NewBits(10) | ||||
| 	b.lostCounter.Clear() | ||||
| 	b.dupeCounter.Clear() | ||||
|   | ||||
| @@ -9,7 +9,7 @@ import ( | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/golang/protobuf/proto" | ||||
| 	"github.com/slackhq/nebula/util" | ||||
| 	"github.com/slackhq/nebula/test" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"golang.org/x/crypto/curve25519" | ||||
| 	"golang.org/x/crypto/ed25519" | ||||
| @@ -752,7 +752,7 @@ func TestNebulaCertificate_Copy(t *testing.T) { | ||||
| 	assert.Nil(t, err) | ||||
| 	cc := c.Copy() | ||||
|  | ||||
| 	util.AssertDeepCopyEqual(t, c, cc) | ||||
| 	test.AssertDeepCopyEqual(t, c, cc) | ||||
| } | ||||
|  | ||||
| func TestUnmarshalNebulaCertificate(t *testing.T) { | ||||
|   | ||||
| @@ -8,6 +8,7 @@ import ( | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"github.com/slackhq/nebula" | ||||
| 	"github.com/slackhq/nebula/config" | ||||
| 	"github.com/slackhq/nebula/util" | ||||
| ) | ||||
|  | ||||
| // A version string that can be set with | ||||
| @@ -60,7 +61,7 @@ func main() { | ||||
| 	ctrl, err := nebula.Main(c, *configTest, Build, l, nil) | ||||
|  | ||||
| 	switch v := err.(type) { | ||||
| 	case nebula.ContextualError: | ||||
| 	case util.ContextualError: | ||||
| 		v.Log(l) | ||||
| 		os.Exit(1) | ||||
| 	case error: | ||||
|   | ||||
| @@ -8,6 +8,7 @@ import ( | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"github.com/slackhq/nebula" | ||||
| 	"github.com/slackhq/nebula/config" | ||||
| 	"github.com/slackhq/nebula/util" | ||||
| ) | ||||
|  | ||||
| // A version string that can be set with | ||||
| @@ -54,7 +55,7 @@ func main() { | ||||
| 	ctrl, err := nebula.Main(c, *configTest, Build, l, nil) | ||||
|  | ||||
| 	switch v := err.(type) { | ||||
| 	case nebula.ContextualError: | ||||
| 	case util.ContextualError: | ||||
| 		v.Log(l) | ||||
| 		os.Exit(1) | ||||
| 	case error: | ||||
|   | ||||
| @@ -7,12 +7,12 @@ import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/slackhq/nebula/util" | ||||
| 	"github.com/slackhq/nebula/test" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
|  | ||||
| func TestConfig_Load(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	dir, err := ioutil.TempDir("", "config-test") | ||||
| 	// invalid yaml | ||||
| 	c := NewC(l) | ||||
| @@ -42,7 +42,7 @@ func TestConfig_Load(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestConfig_Get(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	// test simple type | ||||
| 	c := NewC(l) | ||||
| 	c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"} | ||||
| @@ -58,14 +58,14 @@ func TestConfig_Get(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestConfig_GetStringSlice(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	c := NewC(l) | ||||
| 	c.Settings["slice"] = []interface{}{"one", "two"} | ||||
| 	assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{})) | ||||
| } | ||||
|  | ||||
| func TestConfig_GetBool(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	c := NewC(l) | ||||
| 	c.Settings["bool"] = true | ||||
| 	assert.Equal(t, true, c.GetBool("bool", false)) | ||||
| @@ -93,7 +93,7 @@ func TestConfig_GetBool(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestConfig_HasChanged(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	// No reload has occurred, return false | ||||
| 	c := NewC(l) | ||||
| 	c.Settings["test"] = "hi" | ||||
| @@ -115,7 +115,7 @@ func TestConfig_HasChanged(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestConfig_ReloadConfig(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	done := make(chan bool, 1) | ||||
| 	dir, err := ioutil.TempDir("", "config-test") | ||||
| 	assert.Nil(t, err) | ||||
|   | ||||
| @@ -11,15 +11,15 @@ import ( | ||||
| 	"github.com/flynn/noise" | ||||
| 	"github.com/slackhq/nebula/cert" | ||||
| 	"github.com/slackhq/nebula/iputil" | ||||
| 	"github.com/slackhq/nebula/test" | ||||
| 	"github.com/slackhq/nebula/udp" | ||||
| 	"github.com/slackhq/nebula/util" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
|  | ||||
| var vpnIp iputil.VpnIp | ||||
|  | ||||
| func Test_NewConnectionManagerTest(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") | ||||
| 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") | ||||
| 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24") | ||||
| @@ -89,7 +89,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func Test_NewConnectionManagerTest2(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") | ||||
| 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") | ||||
| 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24") | ||||
| @@ -164,7 +164,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { | ||||
| // Disconnect only if disconnectInvalid: true is set. | ||||
| func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { | ||||
| 	now := time.Now() | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	ipNet := net.IPNet{ | ||||
| 		IP:   net.IPv4(172, 1, 1, 2), | ||||
| 		Mask: net.IPMask{255, 255, 255, 0}, | ||||
|   | ||||
| @@ -9,13 +9,13 @@ import ( | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"github.com/slackhq/nebula/cert" | ||||
| 	"github.com/slackhq/nebula/iputil" | ||||
| 	"github.com/slackhq/nebula/test" | ||||
| 	"github.com/slackhq/nebula/udp" | ||||
| 	"github.com/slackhq/nebula/util" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
|  | ||||
| func TestControl_GetHostInfoByVpnIp(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object | ||||
| 	// To properly ensure we are not exposing core memory to the caller | ||||
| 	hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0)) | ||||
| @@ -94,7 +94,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { | ||||
|  | ||||
| 	// Make sure we don't have any unexpected fields | ||||
| 	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi) | ||||
| 	util.AssertDeepCopyEqual(t, &expectedInfo, thi) | ||||
| 	test.AssertDeepCopyEqual(t, &expectedInfo, thi) | ||||
|  | ||||
| 	// Make sure we don't panic if the host info doesn't have a cert yet | ||||
| 	assert.NotPanics(t, func() { | ||||
|   | ||||
| @@ -14,12 +14,12 @@ import ( | ||||
| 	"github.com/slackhq/nebula/config" | ||||
| 	"github.com/slackhq/nebula/firewall" | ||||
| 	"github.com/slackhq/nebula/iputil" | ||||
| 	"github.com/slackhq/nebula/util" | ||||
| 	"github.com/slackhq/nebula/test" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
|  | ||||
| func TestNewFirewall(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	c := &cert.NebulaCertificate{} | ||||
| 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) | ||||
| 	conntrack := fw.Conntrack | ||||
| @@ -58,7 +58,7 @@ func TestNewFirewall(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestFirewall_AddRule(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	ob := &bytes.Buffer{} | ||||
| 	l.SetOutput(ob) | ||||
|  | ||||
| @@ -133,7 +133,7 @@ func TestFirewall_AddRule(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestFirewall_Drop(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	ob := &bytes.Buffer{} | ||||
| 	l.SetOutput(ob) | ||||
|  | ||||
| @@ -308,7 +308,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { | ||||
| } | ||||
|  | ||||
| func TestFirewall_Drop2(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	ob := &bytes.Buffer{} | ||||
| 	l.SetOutput(ob) | ||||
|  | ||||
| @@ -367,7 +367,7 @@ func TestFirewall_Drop2(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestFirewall_Drop3(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	ob := &bytes.Buffer{} | ||||
| 	l.SetOutput(ob) | ||||
|  | ||||
| @@ -453,7 +453,7 @@ func TestFirewall_Drop3(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestFirewall_DropConntrackReload(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	ob := &bytes.Buffer{} | ||||
| 	l.SetOutput(ob) | ||||
|  | ||||
| @@ -635,7 +635,7 @@ func Test_parsePort(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestNewFirewallFromConfig(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	// Test a bad rule definition | ||||
| 	c := &cert.NebulaCertificate{} | ||||
| 	conf := config.NewC(l) | ||||
| @@ -685,7 +685,7 @@ func TestNewFirewallFromConfig(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestAddFirewallRulesFromConfig(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	// Test adding tcp rule | ||||
| 	conf := config.NewC(l) | ||||
| 	mf := &mockFirewall{} | ||||
| @@ -849,7 +849,7 @@ func TestTCPRTTTracking(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestFirewall_convertRule(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	ob := &bytes.Buffer{} | ||||
| 	l.SetOutput(ob) | ||||
|  | ||||
|   | ||||
| @@ -7,13 +7,13 @@ import ( | ||||
|  | ||||
| 	"github.com/slackhq/nebula/header" | ||||
| 	"github.com/slackhq/nebula/iputil" | ||||
| 	"github.com/slackhq/nebula/test" | ||||
| 	"github.com/slackhq/nebula/udp" | ||||
| 	"github.com/slackhq/nebula/util" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
|  | ||||
| func Test_NewHandshakeManagerVpnIp(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") | ||||
| 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") | ||||
| 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24") | ||||
| @@ -66,7 +66,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func Test_NewHandshakeManagerTrigger(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") | ||||
| 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") | ||||
| 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24") | ||||
|   | ||||
| @@ -8,8 +8,8 @@ import ( | ||||
| 	"github.com/golang/protobuf/proto" | ||||
| 	"github.com/slackhq/nebula/header" | ||||
| 	"github.com/slackhq/nebula/iputil" | ||||
| 	"github.com/slackhq/nebula/test" | ||||
| 	"github.com/slackhq/nebula/udp" | ||||
| 	"github.com/slackhq/nebula/util" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
|  | ||||
| @@ -46,7 +46,7 @@ func TestNewLhQuery(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func Test_lhStaticMapping(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	lh1 := "10.128.0.2" | ||||
| 	lh1IP := net.ParseIP(lh1) | ||||
|  | ||||
| @@ -67,7 +67,7 @@ func Test_lhStaticMapping(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func BenchmarkLighthouseHandleRequest(b *testing.B) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	lh1 := "10.128.0.2" | ||||
| 	lh1IP := net.ParseIP(lh1) | ||||
|  | ||||
| @@ -137,7 +137,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { | ||||
| } | ||||
|  | ||||
| func TestLighthouse_Memory(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
|  | ||||
| 	myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242} | ||||
| 	myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242} | ||||
| @@ -266,7 +266,7 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, | ||||
|  | ||||
| //TODO: this is a RemoteList test | ||||
| //func Test_lhRemoteAllowList(t *testing.T) { | ||||
| //	l := NewTestLogger() | ||||
| //	l := NewLogger() | ||||
| //	c := NewConfig(l) | ||||
| //	c.Settings["remoteallowlist"] = map[interface{}]interface{}{ | ||||
| //		"10.20.0.0/12": false, | ||||
|   | ||||
							
								
								
									
										33
									
								
								logger.go
									
									
									
									
									
								
							
							
						
						
									
										33
									
								
								logger.go
									
									
									
									
									
								
							| @@ -1,7 +1,6 @@ | ||||
| package nebula | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| @@ -10,38 +9,6 @@ import ( | ||||
| 	"github.com/slackhq/nebula/config" | ||||
| ) | ||||
|  | ||||
| type ContextualError struct { | ||||
| 	RealError error | ||||
| 	Fields    map[string]interface{} | ||||
| 	Context   string | ||||
| } | ||||
|  | ||||
| func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError { | ||||
| 	return ContextualError{Context: msg, Fields: fields, RealError: realError} | ||||
| } | ||||
|  | ||||
| func (ce ContextualError) Error() string { | ||||
| 	if ce.RealError == nil { | ||||
| 		return ce.Context | ||||
| 	} | ||||
| 	return ce.RealError.Error() | ||||
| } | ||||
|  | ||||
| func (ce ContextualError) Unwrap() error { | ||||
| 	if ce.RealError == nil { | ||||
| 		return errors.New(ce.Context) | ||||
| 	} | ||||
| 	return ce.RealError | ||||
| } | ||||
|  | ||||
| func (ce *ContextualError) Log(lr *logrus.Logger) { | ||||
| 	if ce.RealError != nil { | ||||
| 		lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context) | ||||
| 	} else { | ||||
| 		lr.WithFields(ce.Fields).Error(ce.Context) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func configLogger(l *logrus.Logger, c *config.C) error { | ||||
| 	// set up our logging level | ||||
| 	logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info"))) | ||||
|   | ||||
							
								
								
									
										43
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										43
									
								
								main.go
									
									
									
									
									
								
							| @@ -12,6 +12,7 @@ import ( | ||||
| 	"github.com/slackhq/nebula/iputil" | ||||
| 	"github.com/slackhq/nebula/sshd" | ||||
| 	"github.com/slackhq/nebula/udp" | ||||
| 	"github.com/slackhq/nebula/util" | ||||
| 	"gopkg.in/yaml.v2" | ||||
| ) | ||||
|  | ||||
| @@ -44,7 +45,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | ||||
|  | ||||
| 	err := configLogger(l, c) | ||||
| 	if err != nil { | ||||
| 		return nil, NewContextualError("Failed to configure the logger", nil, err) | ||||
| 		return nil, util.NewContextualError("Failed to configure the logger", nil, err) | ||||
| 	} | ||||
|  | ||||
| 	c.RegisterReloadCallback(func(c *config.C) { | ||||
| @@ -57,20 +58,20 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | ||||
| 	caPool, err := loadCAFromConfig(l, c) | ||||
| 	if err != nil { | ||||
| 		//The errors coming out of loadCA are already nicely formatted | ||||
| 		return nil, NewContextualError("Failed to load ca from config", nil, err) | ||||
| 		return nil, util.NewContextualError("Failed to load ca from config", nil, err) | ||||
| 	} | ||||
| 	l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") | ||||
|  | ||||
| 	cs, err := NewCertStateFromConfig(c) | ||||
| 	if err != nil { | ||||
| 		//The errors coming out of NewCertStateFromConfig are already nicely formatted | ||||
| 		return nil, NewContextualError("Failed to load certificate from config", nil, err) | ||||
| 		return nil, util.NewContextualError("Failed to load certificate from config", nil, err) | ||||
| 	} | ||||
| 	l.WithField("cert", cs.certificate).Debug("Client nebula certificate") | ||||
|  | ||||
| 	fw, err := NewFirewallFromConfig(l, cs.certificate, c) | ||||
| 	if err != nil { | ||||
| 		return nil, NewContextualError("Error while loading firewall rules", nil, err) | ||||
| 		return nil, util.NewContextualError("Error while loading firewall rules", nil, err) | ||||
| 	} | ||||
| 	l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started") | ||||
|  | ||||
| @@ -78,11 +79,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | ||||
| 	tunCidr := cs.certificate.Details.Ips[0] | ||||
| 	routes, err := parseRoutes(c, tunCidr) | ||||
| 	if err != nil { | ||||
| 		return nil, NewContextualError("Could not parse tun.routes", nil, err) | ||||
| 		return nil, util.NewContextualError("Could not parse tun.routes", nil, err) | ||||
| 	} | ||||
| 	unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr) | ||||
| 	if err != nil { | ||||
| 		return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err) | ||||
| 		return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) | ||||
| 	} | ||||
|  | ||||
| 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) | ||||
| @@ -91,7 +92,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | ||||
| 	if c.GetBool("sshd.enabled", false) { | ||||
| 		sshStart, err = configSSH(l, ssh, c) | ||||
| 		if err != nil { | ||||
| 			return nil, NewContextualError("Error while configuring the sshd", nil, err) | ||||
| 			return nil, util.NewContextualError("Error while configuring the sshd", nil, err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @@ -167,7 +168,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | ||||
| 		} | ||||
|  | ||||
| 		if err != nil { | ||||
| 			return nil, NewContextualError("Failed to get a tun/tap device", nil, err) | ||||
| 			return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @@ -185,7 +186,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | ||||
| 		for i := 0; i < routines; i++ { | ||||
| 			udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64)) | ||||
| 			if err != nil { | ||||
| 				return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err) | ||||
| 				return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) | ||||
| 			} | ||||
| 			udpServer.ReloadConfig(c) | ||||
| 			udpConns[i] = udpServer | ||||
| @@ -194,7 +195,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | ||||
| 			if port == 0 { | ||||
| 				uPort, err := udpServer.LocalAddr() | ||||
| 				if err != nil { | ||||
| 					return nil, NewContextualError("Failed to get listening port", nil, err) | ||||
| 					return nil, util.NewContextualError("Failed to get listening port", nil, err) | ||||
| 				} | ||||
| 				port = int(uPort.Port) | ||||
| 			} | ||||
| @@ -209,7 +210,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | ||||
| 		for _, rawPreferredRange := range rawPreferredRanges { | ||||
| 			_, preferredRange, err := net.ParseCIDR(rawPreferredRange) | ||||
| 			if err != nil { | ||||
| 				return nil, NewContextualError("Failed to parse preferred ranges", nil, err) | ||||
| 				return nil, util.NewContextualError("Failed to parse preferred ranges", nil, err) | ||||
| 			} | ||||
| 			preferredRanges = append(preferredRanges, preferredRange) | ||||
| 		} | ||||
| @@ -222,7 +223,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | ||||
| 	if rawLocalRange != "" { | ||||
| 		_, localRange, err := net.ParseCIDR(rawLocalRange) | ||||
| 		if err != nil { | ||||
| 			return nil, NewContextualError("Failed to parse local_range", nil, err) | ||||
| 			return nil, util.NewContextualError("Failed to parse local_range", nil, err) | ||||
| 		} | ||||
|  | ||||
| 		// Check if the entry for local_range was already specified in | ||||
| @@ -261,7 +262,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | ||||
|  | ||||
| 	// fatal if am_lighthouse is enabled but we are using an ephemeral port | ||||
| 	if amLighthouse && (c.GetInt("listen.port", 0) == 0) { | ||||
| 		return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil) | ||||
| 		return nil, util.NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil) | ||||
| 	} | ||||
|  | ||||
| 	// warn if am_lighthouse is enabled but upstream lighthouses exists | ||||
| @@ -274,10 +275,10 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | ||||
| 	for i, host := range rawLighthouseHosts { | ||||
| 		ip := net.ParseIP(host) | ||||
| 		if ip == nil { | ||||
| 			return nil, NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil) | ||||
| 			return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil) | ||||
| 		} | ||||
| 		if !tunCidr.Contains(ip) { | ||||
| 			return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) | ||||
| 			return nil, util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) | ||||
| 		} | ||||
| 		lighthouseHosts[i] = iputil.Ip2VpnIp(ip) | ||||
| 	} | ||||
| @@ -298,13 +299,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | ||||
|  | ||||
| 	remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges") | ||||
| 	if err != nil { | ||||
| 		return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) | ||||
| 		return nil, util.NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) | ||||
| 	} | ||||
| 	lightHouse.SetRemoteAllowList(remoteAllowList) | ||||
|  | ||||
| 	localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list") | ||||
| 	if err != nil { | ||||
| 		return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err) | ||||
| 		return nil, util.NewContextualError("Invalid lighthouse.local_allow_list", nil, err) | ||||
| 	} | ||||
| 	lightHouse.SetLocalAllowList(localAllowList) | ||||
|  | ||||
| @@ -313,21 +314,21 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | ||||
| 		ip := net.ParseIP(fmt.Sprintf("%v", k)) | ||||
| 		vpnIp := iputil.Ip2VpnIp(ip) | ||||
| 		if !tunCidr.Contains(ip) { | ||||
| 			return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil) | ||||
| 			return nil, util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil) | ||||
| 		} | ||||
| 		vals, ok := v.([]interface{}) | ||||
| 		if ok { | ||||
| 			for _, v := range vals { | ||||
| 				ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) | ||||
| 				if err != nil { | ||||
| 					return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) | ||||
| 					return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) | ||||
| 				} | ||||
| 				lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port)) | ||||
| 			} | ||||
| 		} else { | ||||
| 			ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) | ||||
| 			if err != nil { | ||||
| 				return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) | ||||
| 				return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) | ||||
| 			} | ||||
| 			lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port)) | ||||
| 		} | ||||
| @@ -426,7 +427,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | ||||
| 	statsStart, err := startStats(l, c, buildVersion, configTest) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return nil, NewContextualError("Failed to start stats emitter", nil, err) | ||||
| 		return nil, util.NewContextualError("Failed to start stats emitter", nil, err) | ||||
| 	} | ||||
|  | ||||
| 	if configTest { | ||||
|   | ||||
| @@ -5,12 +5,12 @@ import ( | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/slackhq/nebula/config" | ||||
| 	"github.com/slackhq/nebula/util" | ||||
| 	"github.com/slackhq/nebula/test" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
|  | ||||
| func TestNewPunchyFromConfig(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	c := config.NewC(l) | ||||
|  | ||||
| 	// Test defaults | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| package util | ||||
| package test | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| @@ -1,4 +1,4 @@ | ||||
| package util | ||||
| package test | ||||
| 
 | ||||
| import ( | ||||
| 	"io/ioutil" | ||||
| @@ -7,7 +7,7 @@ import ( | ||||
| 	"github.com/sirupsen/logrus" | ||||
| ) | ||||
| 
 | ||||
| func NewTestLogger() *logrus.Logger { | ||||
| func NewLogger() *logrus.Logger { | ||||
| 	l := logrus.New() | ||||
| 
 | ||||
| 	v := os.Getenv("TEST_LOGS") | ||||
| @@ -6,12 +6,12 @@ import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/slackhq/nebula/config" | ||||
| 	"github.com/slackhq/nebula/util" | ||||
| 	"github.com/slackhq/nebula/test" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
|  | ||||
| func Test_parseRoutes(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	c := config.NewC(l) | ||||
| 	_, n, _ := net.ParseCIDR("10.0.0.0/24") | ||||
|  | ||||
| @@ -107,7 +107,7 @@ func Test_parseRoutes(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func Test_parseUnsafeRoutes(t *testing.T) { | ||||
| 	l := util.NewTestLogger() | ||||
| 	l := test.NewLogger() | ||||
| 	c := config.NewC(l) | ||||
| 	_, n, _ := net.ParseCIDR("10.0.0.0/24") | ||||
|  | ||||
|   | ||||
							
								
								
									
										39
									
								
								util/error.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								util/error.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,39 @@ | ||||
| package util | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
|  | ||||
| 	"github.com/sirupsen/logrus" | ||||
| ) | ||||
|  | ||||
| type ContextualError struct { | ||||
| 	RealError error | ||||
| 	Fields    map[string]interface{} | ||||
| 	Context   string | ||||
| } | ||||
|  | ||||
| func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError { | ||||
| 	return ContextualError{Context: msg, Fields: fields, RealError: realError} | ||||
| } | ||||
|  | ||||
| func (ce ContextualError) Error() string { | ||||
| 	if ce.RealError == nil { | ||||
| 		return ce.Context | ||||
| 	} | ||||
| 	return ce.RealError.Error() | ||||
| } | ||||
|  | ||||
| func (ce ContextualError) Unwrap() error { | ||||
| 	if ce.RealError == nil { | ||||
| 		return errors.New(ce.Context) | ||||
| 	} | ||||
| 	return ce.RealError | ||||
| } | ||||
|  | ||||
| func (ce *ContextualError) Log(lr *logrus.Logger) { | ||||
| 	if ce.RealError != nil { | ||||
| 		lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context) | ||||
| 	} else { | ||||
| 		lr.WithFields(ce.Fields).Error(ce.Context) | ||||
| 	} | ||||
| } | ||||
| @@ -1,4 +1,4 @@ | ||||
| package nebula | ||||
| package util | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| @@ -8,6 +8,8 @@ import ( | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
| type m map[string]interface{} | ||||
| 
 | ||||
| type TestLogWriter struct { | ||||
| 	Logs []string | ||||
| } | ||||
		Reference in New Issue
	
	Block a user
	 Nate Brown
					Nate Brown