Move util to test, contextual errors to util (#575)
This commit is contained in:
		| @@ -7,12 +7,12 @@ import ( | |||||||
|  |  | ||||||
| 	"github.com/slackhq/nebula/cidr" | 	"github.com/slackhq/nebula/cidr" | ||||||
| 	"github.com/slackhq/nebula/config" | 	"github.com/slackhq/nebula/config" | ||||||
| 	"github.com/slackhq/nebula/util" | 	"github.com/slackhq/nebula/test" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestNewAllowListFromConfig(t *testing.T) { | func TestNewAllowListFromConfig(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	c := config.NewC(l) | 	c := config.NewC(l) | ||||||
| 	c.Settings["allowlist"] = map[interface{}]interface{}{ | 	c.Settings["allowlist"] = map[interface{}]interface{}{ | ||||||
| 		"192.168.0.0": true, | 		"192.168.0.0": true, | ||||||
|   | |||||||
							
								
								
									
										10
									
								
								bits_test.go
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								bits_test.go
									
									
									
									
									
								
							| @@ -3,12 +3,12 @@ package nebula | |||||||
| import ( | import ( | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/slackhq/nebula/util" | 	"github.com/slackhq/nebula/test" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestBits(t *testing.T) { | func TestBits(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	b := NewBits(10) | 	b := NewBits(10) | ||||||
|  |  | ||||||
| 	// make sure it is the right size | 	// make sure it is the right size | ||||||
| @@ -76,7 +76,7 @@ func TestBits(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestBitsDupeCounter(t *testing.T) { | func TestBitsDupeCounter(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	b := NewBits(10) | 	b := NewBits(10) | ||||||
| 	b.lostCounter.Clear() | 	b.lostCounter.Clear() | ||||||
| 	b.dupeCounter.Clear() | 	b.dupeCounter.Clear() | ||||||
| @@ -101,7 +101,7 @@ func TestBitsDupeCounter(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestBitsOutOfWindowCounter(t *testing.T) { | func TestBitsOutOfWindowCounter(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	b := NewBits(10) | 	b := NewBits(10) | ||||||
| 	b.lostCounter.Clear() | 	b.lostCounter.Clear() | ||||||
| 	b.dupeCounter.Clear() | 	b.dupeCounter.Clear() | ||||||
| @@ -131,7 +131,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestBitsLostCounter(t *testing.T) { | func TestBitsLostCounter(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	b := NewBits(10) | 	b := NewBits(10) | ||||||
| 	b.lostCounter.Clear() | 	b.lostCounter.Clear() | ||||||
| 	b.dupeCounter.Clear() | 	b.dupeCounter.Clear() | ||||||
|   | |||||||
| @@ -9,7 +9,7 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/golang/protobuf/proto" | 	"github.com/golang/protobuf/proto" | ||||||
| 	"github.com/slackhq/nebula/util" | 	"github.com/slackhq/nebula/test" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| 	"golang.org/x/crypto/curve25519" | 	"golang.org/x/crypto/curve25519" | ||||||
| 	"golang.org/x/crypto/ed25519" | 	"golang.org/x/crypto/ed25519" | ||||||
| @@ -752,7 +752,7 @@ func TestNebulaCertificate_Copy(t *testing.T) { | |||||||
| 	assert.Nil(t, err) | 	assert.Nil(t, err) | ||||||
| 	cc := c.Copy() | 	cc := c.Copy() | ||||||
|  |  | ||||||
| 	util.AssertDeepCopyEqual(t, c, cc) | 	test.AssertDeepCopyEqual(t, c, cc) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestUnmarshalNebulaCertificate(t *testing.T) { | func TestUnmarshalNebulaCertificate(t *testing.T) { | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ import ( | |||||||
| 	"github.com/sirupsen/logrus" | 	"github.com/sirupsen/logrus" | ||||||
| 	"github.com/slackhq/nebula" | 	"github.com/slackhq/nebula" | ||||||
| 	"github.com/slackhq/nebula/config" | 	"github.com/slackhq/nebula/config" | ||||||
|  | 	"github.com/slackhq/nebula/util" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // A version string that can be set with | // A version string that can be set with | ||||||
| @@ -60,7 +61,7 @@ func main() { | |||||||
| 	ctrl, err := nebula.Main(c, *configTest, Build, l, nil) | 	ctrl, err := nebula.Main(c, *configTest, Build, l, nil) | ||||||
|  |  | ||||||
| 	switch v := err.(type) { | 	switch v := err.(type) { | ||||||
| 	case nebula.ContextualError: | 	case util.ContextualError: | ||||||
| 		v.Log(l) | 		v.Log(l) | ||||||
| 		os.Exit(1) | 		os.Exit(1) | ||||||
| 	case error: | 	case error: | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ import ( | |||||||
| 	"github.com/sirupsen/logrus" | 	"github.com/sirupsen/logrus" | ||||||
| 	"github.com/slackhq/nebula" | 	"github.com/slackhq/nebula" | ||||||
| 	"github.com/slackhq/nebula/config" | 	"github.com/slackhq/nebula/config" | ||||||
|  | 	"github.com/slackhq/nebula/util" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // A version string that can be set with | // A version string that can be set with | ||||||
| @@ -54,7 +55,7 @@ func main() { | |||||||
| 	ctrl, err := nebula.Main(c, *configTest, Build, l, nil) | 	ctrl, err := nebula.Main(c, *configTest, Build, l, nil) | ||||||
|  |  | ||||||
| 	switch v := err.(type) { | 	switch v := err.(type) { | ||||||
| 	case nebula.ContextualError: | 	case util.ContextualError: | ||||||
| 		v.Log(l) | 		v.Log(l) | ||||||
| 		os.Exit(1) | 		os.Exit(1) | ||||||
| 	case error: | 	case error: | ||||||
|   | |||||||
| @@ -7,12 +7,12 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/slackhq/nebula/util" | 	"github.com/slackhq/nebula/test" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestConfig_Load(t *testing.T) { | func TestConfig_Load(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	dir, err := ioutil.TempDir("", "config-test") | 	dir, err := ioutil.TempDir("", "config-test") | ||||||
| 	// invalid yaml | 	// invalid yaml | ||||||
| 	c := NewC(l) | 	c := NewC(l) | ||||||
| @@ -42,7 +42,7 @@ func TestConfig_Load(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestConfig_Get(t *testing.T) { | func TestConfig_Get(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	// test simple type | 	// test simple type | ||||||
| 	c := NewC(l) | 	c := NewC(l) | ||||||
| 	c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"} | 	c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"} | ||||||
| @@ -58,14 +58,14 @@ func TestConfig_Get(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestConfig_GetStringSlice(t *testing.T) { | func TestConfig_GetStringSlice(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	c := NewC(l) | 	c := NewC(l) | ||||||
| 	c.Settings["slice"] = []interface{}{"one", "two"} | 	c.Settings["slice"] = []interface{}{"one", "two"} | ||||||
| 	assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{})) | 	assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{})) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestConfig_GetBool(t *testing.T) { | func TestConfig_GetBool(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	c := NewC(l) | 	c := NewC(l) | ||||||
| 	c.Settings["bool"] = true | 	c.Settings["bool"] = true | ||||||
| 	assert.Equal(t, true, c.GetBool("bool", false)) | 	assert.Equal(t, true, c.GetBool("bool", false)) | ||||||
| @@ -93,7 +93,7 @@ func TestConfig_GetBool(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestConfig_HasChanged(t *testing.T) { | func TestConfig_HasChanged(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	// No reload has occurred, return false | 	// No reload has occurred, return false | ||||||
| 	c := NewC(l) | 	c := NewC(l) | ||||||
| 	c.Settings["test"] = "hi" | 	c.Settings["test"] = "hi" | ||||||
| @@ -115,7 +115,7 @@ func TestConfig_HasChanged(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestConfig_ReloadConfig(t *testing.T) { | func TestConfig_ReloadConfig(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	done := make(chan bool, 1) | 	done := make(chan bool, 1) | ||||||
| 	dir, err := ioutil.TempDir("", "config-test") | 	dir, err := ioutil.TempDir("", "config-test") | ||||||
| 	assert.Nil(t, err) | 	assert.Nil(t, err) | ||||||
|   | |||||||
| @@ -11,15 +11,15 @@ import ( | |||||||
| 	"github.com/flynn/noise" | 	"github.com/flynn/noise" | ||||||
| 	"github.com/slackhq/nebula/cert" | 	"github.com/slackhq/nebula/cert" | ||||||
| 	"github.com/slackhq/nebula/iputil" | 	"github.com/slackhq/nebula/iputil" | ||||||
|  | 	"github.com/slackhq/nebula/test" | ||||||
| 	"github.com/slackhq/nebula/udp" | 	"github.com/slackhq/nebula/udp" | ||||||
| 	"github.com/slackhq/nebula/util" |  | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var vpnIp iputil.VpnIp | var vpnIp iputil.VpnIp | ||||||
|  |  | ||||||
| func Test_NewConnectionManagerTest(t *testing.T) { | func Test_NewConnectionManagerTest(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") | 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") | ||||||
| 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") | 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") | ||||||
| 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24") | 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24") | ||||||
| @@ -89,7 +89,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func Test_NewConnectionManagerTest2(t *testing.T) { | func Test_NewConnectionManagerTest2(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") | 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") | ||||||
| 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") | 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") | ||||||
| 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24") | 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24") | ||||||
| @@ -164,7 +164,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { | |||||||
| // Disconnect only if disconnectInvalid: true is set. | // Disconnect only if disconnectInvalid: true is set. | ||||||
| func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { | func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { | ||||||
| 	now := time.Now() | 	now := time.Now() | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	ipNet := net.IPNet{ | 	ipNet := net.IPNet{ | ||||||
| 		IP:   net.IPv4(172, 1, 1, 2), | 		IP:   net.IPv4(172, 1, 1, 2), | ||||||
| 		Mask: net.IPMask{255, 255, 255, 0}, | 		Mask: net.IPMask{255, 255, 255, 0}, | ||||||
|   | |||||||
| @@ -9,13 +9,13 @@ import ( | |||||||
| 	"github.com/sirupsen/logrus" | 	"github.com/sirupsen/logrus" | ||||||
| 	"github.com/slackhq/nebula/cert" | 	"github.com/slackhq/nebula/cert" | ||||||
| 	"github.com/slackhq/nebula/iputil" | 	"github.com/slackhq/nebula/iputil" | ||||||
|  | 	"github.com/slackhq/nebula/test" | ||||||
| 	"github.com/slackhq/nebula/udp" | 	"github.com/slackhq/nebula/udp" | ||||||
| 	"github.com/slackhq/nebula/util" |  | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestControl_GetHostInfoByVpnIp(t *testing.T) { | func TestControl_GetHostInfoByVpnIp(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object | 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object | ||||||
| 	// To properly ensure we are not exposing core memory to the caller | 	// To properly ensure we are not exposing core memory to the caller | ||||||
| 	hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0)) | 	hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0)) | ||||||
| @@ -94,7 +94,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { | |||||||
|  |  | ||||||
| 	// Make sure we don't have any unexpected fields | 	// Make sure we don't have any unexpected fields | ||||||
| 	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi) | 	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi) | ||||||
| 	util.AssertDeepCopyEqual(t, &expectedInfo, thi) | 	test.AssertDeepCopyEqual(t, &expectedInfo, thi) | ||||||
|  |  | ||||||
| 	// Make sure we don't panic if the host info doesn't have a cert yet | 	// Make sure we don't panic if the host info doesn't have a cert yet | ||||||
| 	assert.NotPanics(t, func() { | 	assert.NotPanics(t, func() { | ||||||
|   | |||||||
| @@ -14,12 +14,12 @@ import ( | |||||||
| 	"github.com/slackhq/nebula/config" | 	"github.com/slackhq/nebula/config" | ||||||
| 	"github.com/slackhq/nebula/firewall" | 	"github.com/slackhq/nebula/firewall" | ||||||
| 	"github.com/slackhq/nebula/iputil" | 	"github.com/slackhq/nebula/iputil" | ||||||
| 	"github.com/slackhq/nebula/util" | 	"github.com/slackhq/nebula/test" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestNewFirewall(t *testing.T) { | func TestNewFirewall(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	c := &cert.NebulaCertificate{} | 	c := &cert.NebulaCertificate{} | ||||||
| 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) | 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) | ||||||
| 	conntrack := fw.Conntrack | 	conntrack := fw.Conntrack | ||||||
| @@ -58,7 +58,7 @@ func TestNewFirewall(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestFirewall_AddRule(t *testing.T) { | func TestFirewall_AddRule(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	ob := &bytes.Buffer{} | 	ob := &bytes.Buffer{} | ||||||
| 	l.SetOutput(ob) | 	l.SetOutput(ob) | ||||||
|  |  | ||||||
| @@ -133,7 +133,7 @@ func TestFirewall_AddRule(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestFirewall_Drop(t *testing.T) { | func TestFirewall_Drop(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	ob := &bytes.Buffer{} | 	ob := &bytes.Buffer{} | ||||||
| 	l.SetOutput(ob) | 	l.SetOutput(ob) | ||||||
|  |  | ||||||
| @@ -308,7 +308,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestFirewall_Drop2(t *testing.T) { | func TestFirewall_Drop2(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	ob := &bytes.Buffer{} | 	ob := &bytes.Buffer{} | ||||||
| 	l.SetOutput(ob) | 	l.SetOutput(ob) | ||||||
|  |  | ||||||
| @@ -367,7 +367,7 @@ func TestFirewall_Drop2(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestFirewall_Drop3(t *testing.T) { | func TestFirewall_Drop3(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	ob := &bytes.Buffer{} | 	ob := &bytes.Buffer{} | ||||||
| 	l.SetOutput(ob) | 	l.SetOutput(ob) | ||||||
|  |  | ||||||
| @@ -453,7 +453,7 @@ func TestFirewall_Drop3(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestFirewall_DropConntrackReload(t *testing.T) { | func TestFirewall_DropConntrackReload(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	ob := &bytes.Buffer{} | 	ob := &bytes.Buffer{} | ||||||
| 	l.SetOutput(ob) | 	l.SetOutput(ob) | ||||||
|  |  | ||||||
| @@ -635,7 +635,7 @@ func Test_parsePort(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestNewFirewallFromConfig(t *testing.T) { | func TestNewFirewallFromConfig(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	// Test a bad rule definition | 	// Test a bad rule definition | ||||||
| 	c := &cert.NebulaCertificate{} | 	c := &cert.NebulaCertificate{} | ||||||
| 	conf := config.NewC(l) | 	conf := config.NewC(l) | ||||||
| @@ -685,7 +685,7 @@ func TestNewFirewallFromConfig(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestAddFirewallRulesFromConfig(t *testing.T) { | func TestAddFirewallRulesFromConfig(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	// Test adding tcp rule | 	// Test adding tcp rule | ||||||
| 	conf := config.NewC(l) | 	conf := config.NewC(l) | ||||||
| 	mf := &mockFirewall{} | 	mf := &mockFirewall{} | ||||||
| @@ -849,7 +849,7 @@ func TestTCPRTTTracking(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestFirewall_convertRule(t *testing.T) { | func TestFirewall_convertRule(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	ob := &bytes.Buffer{} | 	ob := &bytes.Buffer{} | ||||||
| 	l.SetOutput(ob) | 	l.SetOutput(ob) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -7,13 +7,13 @@ import ( | |||||||
|  |  | ||||||
| 	"github.com/slackhq/nebula/header" | 	"github.com/slackhq/nebula/header" | ||||||
| 	"github.com/slackhq/nebula/iputil" | 	"github.com/slackhq/nebula/iputil" | ||||||
|  | 	"github.com/slackhq/nebula/test" | ||||||
| 	"github.com/slackhq/nebula/udp" | 	"github.com/slackhq/nebula/udp" | ||||||
| 	"github.com/slackhq/nebula/util" |  | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func Test_NewHandshakeManagerVpnIp(t *testing.T) { | func Test_NewHandshakeManagerVpnIp(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") | 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") | ||||||
| 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") | 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") | ||||||
| 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24") | 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24") | ||||||
| @@ -66,7 +66,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func Test_NewHandshakeManagerTrigger(t *testing.T) { | func Test_NewHandshakeManagerTrigger(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") | 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") | ||||||
| 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") | 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") | ||||||
| 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24") | 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24") | ||||||
|   | |||||||
| @@ -8,8 +8,8 @@ import ( | |||||||
| 	"github.com/golang/protobuf/proto" | 	"github.com/golang/protobuf/proto" | ||||||
| 	"github.com/slackhq/nebula/header" | 	"github.com/slackhq/nebula/header" | ||||||
| 	"github.com/slackhq/nebula/iputil" | 	"github.com/slackhq/nebula/iputil" | ||||||
|  | 	"github.com/slackhq/nebula/test" | ||||||
| 	"github.com/slackhq/nebula/udp" | 	"github.com/slackhq/nebula/udp" | ||||||
| 	"github.com/slackhq/nebula/util" |  | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -46,7 +46,7 @@ func TestNewLhQuery(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func Test_lhStaticMapping(t *testing.T) { | func Test_lhStaticMapping(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	lh1 := "10.128.0.2" | 	lh1 := "10.128.0.2" | ||||||
| 	lh1IP := net.ParseIP(lh1) | 	lh1IP := net.ParseIP(lh1) | ||||||
|  |  | ||||||
| @@ -67,7 +67,7 @@ func Test_lhStaticMapping(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func BenchmarkLighthouseHandleRequest(b *testing.B) { | func BenchmarkLighthouseHandleRequest(b *testing.B) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	lh1 := "10.128.0.2" | 	lh1 := "10.128.0.2" | ||||||
| 	lh1IP := net.ParseIP(lh1) | 	lh1IP := net.ParseIP(lh1) | ||||||
|  |  | ||||||
| @@ -137,7 +137,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestLighthouse_Memory(t *testing.T) { | func TestLighthouse_Memory(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
|  |  | ||||||
| 	myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242} | 	myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242} | ||||||
| 	myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242} | 	myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242} | ||||||
| @@ -266,7 +266,7 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, | |||||||
|  |  | ||||||
| //TODO: this is a RemoteList test | //TODO: this is a RemoteList test | ||||||
| //func Test_lhRemoteAllowList(t *testing.T) { | //func Test_lhRemoteAllowList(t *testing.T) { | ||||||
| //	l := NewTestLogger() | //	l := NewLogger() | ||||||
| //	c := NewConfig(l) | //	c := NewConfig(l) | ||||||
| //	c.Settings["remoteallowlist"] = map[interface{}]interface{}{ | //	c.Settings["remoteallowlist"] = map[interface{}]interface{}{ | ||||||
| //		"10.20.0.0/12": false, | //		"10.20.0.0/12": false, | ||||||
|   | |||||||
							
								
								
									
										33
									
								
								logger.go
									
									
									
									
									
								
							
							
						
						
									
										33
									
								
								logger.go
									
									
									
									
									
								
							| @@ -1,7 +1,6 @@ | |||||||
| package nebula | package nebula | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"errors" |  | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -10,38 +9,6 @@ import ( | |||||||
| 	"github.com/slackhq/nebula/config" | 	"github.com/slackhq/nebula/config" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type ContextualError struct { |  | ||||||
| 	RealError error |  | ||||||
| 	Fields    map[string]interface{} |  | ||||||
| 	Context   string |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError { |  | ||||||
| 	return ContextualError{Context: msg, Fields: fields, RealError: realError} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (ce ContextualError) Error() string { |  | ||||||
| 	if ce.RealError == nil { |  | ||||||
| 		return ce.Context |  | ||||||
| 	} |  | ||||||
| 	return ce.RealError.Error() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (ce ContextualError) Unwrap() error { |  | ||||||
| 	if ce.RealError == nil { |  | ||||||
| 		return errors.New(ce.Context) |  | ||||||
| 	} |  | ||||||
| 	return ce.RealError |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (ce *ContextualError) Log(lr *logrus.Logger) { |  | ||||||
| 	if ce.RealError != nil { |  | ||||||
| 		lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context) |  | ||||||
| 	} else { |  | ||||||
| 		lr.WithFields(ce.Fields).Error(ce.Context) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func configLogger(l *logrus.Logger, c *config.C) error { | func configLogger(l *logrus.Logger, c *config.C) error { | ||||||
| 	// set up our logging level | 	// set up our logging level | ||||||
| 	logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info"))) | 	logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info"))) | ||||||
|   | |||||||
							
								
								
									
										43
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										43
									
								
								main.go
									
									
									
									
									
								
							| @@ -12,6 +12,7 @@ import ( | |||||||
| 	"github.com/slackhq/nebula/iputil" | 	"github.com/slackhq/nebula/iputil" | ||||||
| 	"github.com/slackhq/nebula/sshd" | 	"github.com/slackhq/nebula/sshd" | ||||||
| 	"github.com/slackhq/nebula/udp" | 	"github.com/slackhq/nebula/udp" | ||||||
|  | 	"github.com/slackhq/nebula/util" | ||||||
| 	"gopkg.in/yaml.v2" | 	"gopkg.in/yaml.v2" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -44,7 +45,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | |||||||
|  |  | ||||||
| 	err := configLogger(l, c) | 	err := configLogger(l, c) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, NewContextualError("Failed to configure the logger", nil, err) | 		return nil, util.NewContextualError("Failed to configure the logger", nil, err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	c.RegisterReloadCallback(func(c *config.C) { | 	c.RegisterReloadCallback(func(c *config.C) { | ||||||
| @@ -57,20 +58,20 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | |||||||
| 	caPool, err := loadCAFromConfig(l, c) | 	caPool, err := loadCAFromConfig(l, c) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		//The errors coming out of loadCA are already nicely formatted | 		//The errors coming out of loadCA are already nicely formatted | ||||||
| 		return nil, NewContextualError("Failed to load ca from config", nil, err) | 		return nil, util.NewContextualError("Failed to load ca from config", nil, err) | ||||||
| 	} | 	} | ||||||
| 	l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") | 	l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") | ||||||
|  |  | ||||||
| 	cs, err := NewCertStateFromConfig(c) | 	cs, err := NewCertStateFromConfig(c) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		//The errors coming out of NewCertStateFromConfig are already nicely formatted | 		//The errors coming out of NewCertStateFromConfig are already nicely formatted | ||||||
| 		return nil, NewContextualError("Failed to load certificate from config", nil, err) | 		return nil, util.NewContextualError("Failed to load certificate from config", nil, err) | ||||||
| 	} | 	} | ||||||
| 	l.WithField("cert", cs.certificate).Debug("Client nebula certificate") | 	l.WithField("cert", cs.certificate).Debug("Client nebula certificate") | ||||||
|  |  | ||||||
| 	fw, err := NewFirewallFromConfig(l, cs.certificate, c) | 	fw, err := NewFirewallFromConfig(l, cs.certificate, c) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, NewContextualError("Error while loading firewall rules", nil, err) | 		return nil, util.NewContextualError("Error while loading firewall rules", nil, err) | ||||||
| 	} | 	} | ||||||
| 	l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started") | 	l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started") | ||||||
|  |  | ||||||
| @@ -78,11 +79,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | |||||||
| 	tunCidr := cs.certificate.Details.Ips[0] | 	tunCidr := cs.certificate.Details.Ips[0] | ||||||
| 	routes, err := parseRoutes(c, tunCidr) | 	routes, err := parseRoutes(c, tunCidr) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, NewContextualError("Could not parse tun.routes", nil, err) | 		return nil, util.NewContextualError("Could not parse tun.routes", nil, err) | ||||||
| 	} | 	} | ||||||
| 	unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr) | 	unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err) | 		return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) | 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) | ||||||
| @@ -91,7 +92,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | |||||||
| 	if c.GetBool("sshd.enabled", false) { | 	if c.GetBool("sshd.enabled", false) { | ||||||
| 		sshStart, err = configSSH(l, ssh, c) | 		sshStart, err = configSSH(l, ssh, c) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, NewContextualError("Error while configuring the sshd", nil, err) | 			return nil, util.NewContextualError("Error while configuring the sshd", nil, err) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -167,7 +168,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, NewContextualError("Failed to get a tun/tap device", nil, err) | 			return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -185,7 +186,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | |||||||
| 		for i := 0; i < routines; i++ { | 		for i := 0; i < routines; i++ { | ||||||
| 			udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64)) | 			udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64)) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err) | 				return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) | ||||||
| 			} | 			} | ||||||
| 			udpServer.ReloadConfig(c) | 			udpServer.ReloadConfig(c) | ||||||
| 			udpConns[i] = udpServer | 			udpConns[i] = udpServer | ||||||
| @@ -194,7 +195,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | |||||||
| 			if port == 0 { | 			if port == 0 { | ||||||
| 				uPort, err := udpServer.LocalAddr() | 				uPort, err := udpServer.LocalAddr() | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					return nil, NewContextualError("Failed to get listening port", nil, err) | 					return nil, util.NewContextualError("Failed to get listening port", nil, err) | ||||||
| 				} | 				} | ||||||
| 				port = int(uPort.Port) | 				port = int(uPort.Port) | ||||||
| 			} | 			} | ||||||
| @@ -209,7 +210,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | |||||||
| 		for _, rawPreferredRange := range rawPreferredRanges { | 		for _, rawPreferredRange := range rawPreferredRanges { | ||||||
| 			_, preferredRange, err := net.ParseCIDR(rawPreferredRange) | 			_, preferredRange, err := net.ParseCIDR(rawPreferredRange) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return nil, NewContextualError("Failed to parse preferred ranges", nil, err) | 				return nil, util.NewContextualError("Failed to parse preferred ranges", nil, err) | ||||||
| 			} | 			} | ||||||
| 			preferredRanges = append(preferredRanges, preferredRange) | 			preferredRanges = append(preferredRanges, preferredRange) | ||||||
| 		} | 		} | ||||||
| @@ -222,7 +223,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | |||||||
| 	if rawLocalRange != "" { | 	if rawLocalRange != "" { | ||||||
| 		_, localRange, err := net.ParseCIDR(rawLocalRange) | 		_, localRange, err := net.ParseCIDR(rawLocalRange) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, NewContextualError("Failed to parse local_range", nil, err) | 			return nil, util.NewContextualError("Failed to parse local_range", nil, err) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Check if the entry for local_range was already specified in | 		// Check if the entry for local_range was already specified in | ||||||
| @@ -261,7 +262,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | |||||||
|  |  | ||||||
| 	// fatal if am_lighthouse is enabled but we are using an ephemeral port | 	// fatal if am_lighthouse is enabled but we are using an ephemeral port | ||||||
| 	if amLighthouse && (c.GetInt("listen.port", 0) == 0) { | 	if amLighthouse && (c.GetInt("listen.port", 0) == 0) { | ||||||
| 		return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil) | 		return nil, util.NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// warn if am_lighthouse is enabled but upstream lighthouses exists | 	// warn if am_lighthouse is enabled but upstream lighthouses exists | ||||||
| @@ -274,10 +275,10 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | |||||||
| 	for i, host := range rawLighthouseHosts { | 	for i, host := range rawLighthouseHosts { | ||||||
| 		ip := net.ParseIP(host) | 		ip := net.ParseIP(host) | ||||||
| 		if ip == nil { | 		if ip == nil { | ||||||
| 			return nil, NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil) | 			return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil) | ||||||
| 		} | 		} | ||||||
| 		if !tunCidr.Contains(ip) { | 		if !tunCidr.Contains(ip) { | ||||||
| 			return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) | 			return nil, util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) | ||||||
| 		} | 		} | ||||||
| 		lighthouseHosts[i] = iputil.Ip2VpnIp(ip) | 		lighthouseHosts[i] = iputil.Ip2VpnIp(ip) | ||||||
| 	} | 	} | ||||||
| @@ -298,13 +299,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | |||||||
|  |  | ||||||
| 	remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges") | 	remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) | 		return nil, util.NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) | ||||||
| 	} | 	} | ||||||
| 	lightHouse.SetRemoteAllowList(remoteAllowList) | 	lightHouse.SetRemoteAllowList(remoteAllowList) | ||||||
|  |  | ||||||
| 	localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list") | 	localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err) | 		return nil, util.NewContextualError("Invalid lighthouse.local_allow_list", nil, err) | ||||||
| 	} | 	} | ||||||
| 	lightHouse.SetLocalAllowList(localAllowList) | 	lightHouse.SetLocalAllowList(localAllowList) | ||||||
|  |  | ||||||
| @@ -313,21 +314,21 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | |||||||
| 		ip := net.ParseIP(fmt.Sprintf("%v", k)) | 		ip := net.ParseIP(fmt.Sprintf("%v", k)) | ||||||
| 		vpnIp := iputil.Ip2VpnIp(ip) | 		vpnIp := iputil.Ip2VpnIp(ip) | ||||||
| 		if !tunCidr.Contains(ip) { | 		if !tunCidr.Contains(ip) { | ||||||
| 			return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil) | 			return nil, util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil) | ||||||
| 		} | 		} | ||||||
| 		vals, ok := v.([]interface{}) | 		vals, ok := v.([]interface{}) | ||||||
| 		if ok { | 		if ok { | ||||||
| 			for _, v := range vals { | 			for _, v := range vals { | ||||||
| 				ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) | 				ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) | 					return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) | ||||||
| 				} | 				} | ||||||
| 				lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port)) | 				lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port)) | ||||||
| 			} | 			} | ||||||
| 		} else { | 		} else { | ||||||
| 			ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) | 			ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) | 				return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) | ||||||
| 			} | 			} | ||||||
| 			lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port)) | 			lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port)) | ||||||
| 		} | 		} | ||||||
| @@ -426,7 +427,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg | |||||||
| 	statsStart, err := startStats(l, c, buildVersion, configTest) | 	statsStart, err := startStats(l, c, buildVersion, configTest) | ||||||
|  |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, NewContextualError("Failed to start stats emitter", nil, err) | 		return nil, util.NewContextualError("Failed to start stats emitter", nil, err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if configTest { | 	if configTest { | ||||||
|   | |||||||
| @@ -5,12 +5,12 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/slackhq/nebula/config" | 	"github.com/slackhq/nebula/config" | ||||||
| 	"github.com/slackhq/nebula/util" | 	"github.com/slackhq/nebula/test" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestNewPunchyFromConfig(t *testing.T) { | func TestNewPunchyFromConfig(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	c := config.NewC(l) | 	c := config.NewC(l) | ||||||
|  |  | ||||||
| 	// Test defaults | 	// Test defaults | ||||||
|   | |||||||
| @@ -1,4 +1,4 @@ | |||||||
| package util | package test | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| @@ -1,4 +1,4 @@ | |||||||
| package util | package test | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| @@ -7,7 +7,7 @@ import ( | |||||||
| 	"github.com/sirupsen/logrus" | 	"github.com/sirupsen/logrus" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func NewTestLogger() *logrus.Logger { | func NewLogger() *logrus.Logger { | ||||||
| 	l := logrus.New() | 	l := logrus.New() | ||||||
| 
 | 
 | ||||||
| 	v := os.Getenv("TEST_LOGS") | 	v := os.Getenv("TEST_LOGS") | ||||||
| @@ -6,12 +6,12 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/slackhq/nebula/config" | 	"github.com/slackhq/nebula/config" | ||||||
| 	"github.com/slackhq/nebula/util" | 	"github.com/slackhq/nebula/test" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func Test_parseRoutes(t *testing.T) { | func Test_parseRoutes(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	c := config.NewC(l) | 	c := config.NewC(l) | ||||||
| 	_, n, _ := net.ParseCIDR("10.0.0.0/24") | 	_, n, _ := net.ParseCIDR("10.0.0.0/24") | ||||||
|  |  | ||||||
| @@ -107,7 +107,7 @@ func Test_parseRoutes(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func Test_parseUnsafeRoutes(t *testing.T) { | func Test_parseUnsafeRoutes(t *testing.T) { | ||||||
| 	l := util.NewTestLogger() | 	l := test.NewLogger() | ||||||
| 	c := config.NewC(l) | 	c := config.NewC(l) | ||||||
| 	_, n, _ := net.ParseCIDR("10.0.0.0/24") | 	_, n, _ := net.ParseCIDR("10.0.0.0/24") | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										39
									
								
								util/error.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								util/error.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,39 @@ | |||||||
|  | package util | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  |  | ||||||
|  | 	"github.com/sirupsen/logrus" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type ContextualError struct { | ||||||
|  | 	RealError error | ||||||
|  | 	Fields    map[string]interface{} | ||||||
|  | 	Context   string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError { | ||||||
|  | 	return ContextualError{Context: msg, Fields: fields, RealError: realError} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (ce ContextualError) Error() string { | ||||||
|  | 	if ce.RealError == nil { | ||||||
|  | 		return ce.Context | ||||||
|  | 	} | ||||||
|  | 	return ce.RealError.Error() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (ce ContextualError) Unwrap() error { | ||||||
|  | 	if ce.RealError == nil { | ||||||
|  | 		return errors.New(ce.Context) | ||||||
|  | 	} | ||||||
|  | 	return ce.RealError | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (ce *ContextualError) Log(lr *logrus.Logger) { | ||||||
|  | 	if ce.RealError != nil { | ||||||
|  | 		lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context) | ||||||
|  | 	} else { | ||||||
|  | 		lr.WithFields(ce.Fields).Error(ce.Context) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -1,4 +1,4 @@ | |||||||
| package nebula | package util | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| @@ -8,6 +8,8 @@ import ( | |||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | type m map[string]interface{} | ||||||
|  | 
 | ||||||
| type TestLogWriter struct { | type TestLogWriter struct { | ||||||
| 	Logs []string | 	Logs []string | ||||||
| } | } | ||||||
		Reference in New Issue
	
	Block a user
	 Nate Brown
					Nate Brown