Don't use a global logger (#423)
This commit is contained in:
parent
7a9f9dbded
commit
3ea7e1b75f
4
bits.go
4
bits.go
|
@ -26,7 +26,7 @@ func NewBits(bits uint64) *Bits {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bits) Check(i uint64) bool {
|
func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
|
||||||
// If i is the next number, return true.
|
// If i is the next number, return true.
|
||||||
if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
|
if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
|
||||||
return true
|
return true
|
||||||
|
@ -47,7 +47,7 @@ func (b *Bits) Check(i uint64) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bits) Update(i uint64) bool {
|
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
||||||
// If i is the next number, return true and update current.
|
// If i is the next number, return true and update current.
|
||||||
if i == b.current+1 {
|
if i == b.current+1 {
|
||||||
// Report missed packets, we can only understand what was missed after the first window has been gone through
|
// Report missed packets, we can only understand what was missed after the first window has been gone through
|
||||||
|
|
154
bits_test.go
154
bits_test.go
|
@ -7,6 +7,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBits(t *testing.T) {
|
func TestBits(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
b := NewBits(10)
|
b := NewBits(10)
|
||||||
|
|
||||||
// make sure it is the right size
|
// make sure it is the right size
|
||||||
|
@ -14,46 +15,46 @@ func TestBits(t *testing.T) {
|
||||||
|
|
||||||
// This is initialized to zero - receive one. This should work.
|
// This is initialized to zero - receive one. This should work.
|
||||||
|
|
||||||
assert.True(t, b.Check(1))
|
assert.True(t, b.Check(l, 1))
|
||||||
u := b.Update(1)
|
u := b.Update(l, 1)
|
||||||
assert.True(t, u)
|
assert.True(t, u)
|
||||||
assert.EqualValues(t, 1, b.current)
|
assert.EqualValues(t, 1, b.current)
|
||||||
g := []bool{false, true, false, false, false, false, false, false, false, false}
|
g := []bool{false, true, false, false, false, false, false, false, false, false}
|
||||||
assert.Equal(t, g, b.bits)
|
assert.Equal(t, g, b.bits)
|
||||||
|
|
||||||
// Receive two
|
// Receive two
|
||||||
assert.True(t, b.Check(2))
|
assert.True(t, b.Check(l, 2))
|
||||||
u = b.Update(2)
|
u = b.Update(l, 2)
|
||||||
assert.True(t, u)
|
assert.True(t, u)
|
||||||
assert.EqualValues(t, 2, b.current)
|
assert.EqualValues(t, 2, b.current)
|
||||||
g = []bool{false, true, true, false, false, false, false, false, false, false}
|
g = []bool{false, true, true, false, false, false, false, false, false, false}
|
||||||
assert.Equal(t, g, b.bits)
|
assert.Equal(t, g, b.bits)
|
||||||
|
|
||||||
// Receive two again - it will fail
|
// Receive two again - it will fail
|
||||||
assert.False(t, b.Check(2))
|
assert.False(t, b.Check(l, 2))
|
||||||
u = b.Update(2)
|
u = b.Update(l, 2)
|
||||||
assert.False(t, u)
|
assert.False(t, u)
|
||||||
assert.EqualValues(t, 2, b.current)
|
assert.EqualValues(t, 2, b.current)
|
||||||
|
|
||||||
// Jump ahead to 15, which should clear everything and set the 6th element
|
// Jump ahead to 15, which should clear everything and set the 6th element
|
||||||
assert.True(t, b.Check(15))
|
assert.True(t, b.Check(l, 15))
|
||||||
u = b.Update(15)
|
u = b.Update(l, 15)
|
||||||
assert.True(t, u)
|
assert.True(t, u)
|
||||||
assert.EqualValues(t, 15, b.current)
|
assert.EqualValues(t, 15, b.current)
|
||||||
g = []bool{false, false, false, false, false, true, false, false, false, false}
|
g = []bool{false, false, false, false, false, true, false, false, false, false}
|
||||||
assert.Equal(t, g, b.bits)
|
assert.Equal(t, g, b.bits)
|
||||||
|
|
||||||
// Mark 14, which is allowed because it is in the window
|
// Mark 14, which is allowed because it is in the window
|
||||||
assert.True(t, b.Check(14))
|
assert.True(t, b.Check(l, 14))
|
||||||
u = b.Update(14)
|
u = b.Update(l, 14)
|
||||||
assert.True(t, u)
|
assert.True(t, u)
|
||||||
assert.EqualValues(t, 15, b.current)
|
assert.EqualValues(t, 15, b.current)
|
||||||
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
||||||
assert.Equal(t, g, b.bits)
|
assert.Equal(t, g, b.bits)
|
||||||
|
|
||||||
// Mark 5, which is not allowed because it is not in the window
|
// Mark 5, which is not allowed because it is not in the window
|
||||||
assert.False(t, b.Check(5))
|
assert.False(t, b.Check(l, 5))
|
||||||
u = b.Update(5)
|
u = b.Update(l, 5)
|
||||||
assert.False(t, u)
|
assert.False(t, u)
|
||||||
assert.EqualValues(t, 15, b.current)
|
assert.EqualValues(t, 15, b.current)
|
||||||
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
||||||
|
@ -61,63 +62,65 @@ func TestBits(t *testing.T) {
|
||||||
|
|
||||||
// make sure we handle wrapping around once to the current position
|
// make sure we handle wrapping around once to the current position
|
||||||
b = NewBits(10)
|
b = NewBits(10)
|
||||||
assert.True(t, b.Update(1))
|
assert.True(t, b.Update(l, 1))
|
||||||
assert.True(t, b.Update(11))
|
assert.True(t, b.Update(l, 11))
|
||||||
assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false}, b.bits)
|
assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false}, b.bits)
|
||||||
|
|
||||||
// Walk through a few windows in order
|
// Walk through a few windows in order
|
||||||
b = NewBits(10)
|
b = NewBits(10)
|
||||||
for i := uint64(0); i <= 100; i++ {
|
for i := uint64(0); i <= 100; i++ {
|
||||||
assert.True(t, b.Check(i), "Error while checking %v", i)
|
assert.True(t, b.Check(l, i), "Error while checking %v", i)
|
||||||
assert.True(t, b.Update(i), "Error while updating %v", i)
|
assert.True(t, b.Update(l, i), "Error while updating %v", i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBitsDupeCounter(t *testing.T) {
|
func TestBitsDupeCounter(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
b := NewBits(10)
|
b := NewBits(10)
|
||||||
b.lostCounter.Clear()
|
b.lostCounter.Clear()
|
||||||
b.dupeCounter.Clear()
|
b.dupeCounter.Clear()
|
||||||
b.outOfWindowCounter.Clear()
|
b.outOfWindowCounter.Clear()
|
||||||
|
|
||||||
assert.True(t, b.Update(1))
|
assert.True(t, b.Update(l, 1))
|
||||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||||
|
|
||||||
assert.False(t, b.Update(1))
|
assert.False(t, b.Update(l, 1))
|
||||||
assert.Equal(t, int64(1), b.dupeCounter.Count())
|
assert.Equal(t, int64(1), b.dupeCounter.Count())
|
||||||
|
|
||||||
assert.True(t, b.Update(2))
|
assert.True(t, b.Update(l, 2))
|
||||||
assert.Equal(t, int64(1), b.dupeCounter.Count())
|
assert.Equal(t, int64(1), b.dupeCounter.Count())
|
||||||
|
|
||||||
assert.True(t, b.Update(3))
|
assert.True(t, b.Update(l, 3))
|
||||||
assert.Equal(t, int64(1), b.dupeCounter.Count())
|
assert.Equal(t, int64(1), b.dupeCounter.Count())
|
||||||
|
|
||||||
assert.False(t, b.Update(1))
|
assert.False(t, b.Update(l, 1))
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
assert.Equal(t, int64(2), b.dupeCounter.Count())
|
assert.Equal(t, int64(2), b.dupeCounter.Count())
|
||||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBitsOutOfWindowCounter(t *testing.T) {
|
func TestBitsOutOfWindowCounter(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
b := NewBits(10)
|
b := NewBits(10)
|
||||||
b.lostCounter.Clear()
|
b.lostCounter.Clear()
|
||||||
b.dupeCounter.Clear()
|
b.dupeCounter.Clear()
|
||||||
b.outOfWindowCounter.Clear()
|
b.outOfWindowCounter.Clear()
|
||||||
|
|
||||||
assert.True(t, b.Update(20))
|
assert.True(t, b.Update(l, 20))
|
||||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||||
|
|
||||||
assert.True(t, b.Update(21))
|
assert.True(t, b.Update(l, 21))
|
||||||
assert.True(t, b.Update(22))
|
assert.True(t, b.Update(l, 22))
|
||||||
assert.True(t, b.Update(23))
|
assert.True(t, b.Update(l, 23))
|
||||||
assert.True(t, b.Update(24))
|
assert.True(t, b.Update(l, 24))
|
||||||
assert.True(t, b.Update(25))
|
assert.True(t, b.Update(l, 25))
|
||||||
assert.True(t, b.Update(26))
|
assert.True(t, b.Update(l, 26))
|
||||||
assert.True(t, b.Update(27))
|
assert.True(t, b.Update(l, 27))
|
||||||
assert.True(t, b.Update(28))
|
assert.True(t, b.Update(l, 28))
|
||||||
assert.True(t, b.Update(29))
|
assert.True(t, b.Update(l, 29))
|
||||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||||
|
|
||||||
assert.False(t, b.Update(0))
|
assert.False(t, b.Update(l, 0))
|
||||||
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
|
||||||
|
|
||||||
//tODO: make sure lostcounter doesn't increase in orderly increment
|
//tODO: make sure lostcounter doesn't increase in orderly increment
|
||||||
|
@ -127,23 +130,24 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBitsLostCounter(t *testing.T) {
|
func TestBitsLostCounter(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
b := NewBits(10)
|
b := NewBits(10)
|
||||||
b.lostCounter.Clear()
|
b.lostCounter.Clear()
|
||||||
b.dupeCounter.Clear()
|
b.dupeCounter.Clear()
|
||||||
b.outOfWindowCounter.Clear()
|
b.outOfWindowCounter.Clear()
|
||||||
|
|
||||||
//assert.True(t, b.Update(0))
|
//assert.True(t, b.Update(0))
|
||||||
assert.True(t, b.Update(0))
|
assert.True(t, b.Update(l, 0))
|
||||||
assert.True(t, b.Update(20))
|
assert.True(t, b.Update(l, 20))
|
||||||
assert.True(t, b.Update(21))
|
assert.True(t, b.Update(l, 21))
|
||||||
assert.True(t, b.Update(22))
|
assert.True(t, b.Update(l, 22))
|
||||||
assert.True(t, b.Update(23))
|
assert.True(t, b.Update(l, 23))
|
||||||
assert.True(t, b.Update(24))
|
assert.True(t, b.Update(l, 24))
|
||||||
assert.True(t, b.Update(25))
|
assert.True(t, b.Update(l, 25))
|
||||||
assert.True(t, b.Update(26))
|
assert.True(t, b.Update(l, 26))
|
||||||
assert.True(t, b.Update(27))
|
assert.True(t, b.Update(l, 27))
|
||||||
assert.True(t, b.Update(28))
|
assert.True(t, b.Update(l, 28))
|
||||||
assert.True(t, b.Update(29))
|
assert.True(t, b.Update(l, 29))
|
||||||
assert.Equal(t, int64(20), b.lostCounter.Count())
|
assert.Equal(t, int64(20), b.lostCounter.Count())
|
||||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||||
|
@ -153,56 +157,56 @@ func TestBitsLostCounter(t *testing.T) {
|
||||||
b.dupeCounter.Clear()
|
b.dupeCounter.Clear()
|
||||||
b.outOfWindowCounter.Clear()
|
b.outOfWindowCounter.Clear()
|
||||||
|
|
||||||
assert.True(t, b.Update(0))
|
assert.True(t, b.Update(l, 0))
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
assert.True(t, b.Update(9))
|
assert.True(t, b.Update(l, 9))
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
// 10 will set 0 index, 0 was already set, no lost packets
|
// 10 will set 0 index, 0 was already set, no lost packets
|
||||||
assert.True(t, b.Update(10))
|
assert.True(t, b.Update(l, 10))
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
// 11 will set 1 index, 1 was missed, we should see 1 packet lost
|
// 11 will set 1 index, 1 was missed, we should see 1 packet lost
|
||||||
assert.True(t, b.Update(11))
|
assert.True(t, b.Update(l, 11))
|
||||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||||
// Now let's fill in the window, should end up with 8 lost packets
|
// Now let's fill in the window, should end up with 8 lost packets
|
||||||
assert.True(t, b.Update(12))
|
assert.True(t, b.Update(l, 12))
|
||||||
assert.True(t, b.Update(13))
|
assert.True(t, b.Update(l, 13))
|
||||||
assert.True(t, b.Update(14))
|
assert.True(t, b.Update(l, 14))
|
||||||
assert.True(t, b.Update(15))
|
assert.True(t, b.Update(l, 15))
|
||||||
assert.True(t, b.Update(16))
|
assert.True(t, b.Update(l, 16))
|
||||||
assert.True(t, b.Update(17))
|
assert.True(t, b.Update(l, 17))
|
||||||
assert.True(t, b.Update(18))
|
assert.True(t, b.Update(l, 18))
|
||||||
assert.True(t, b.Update(19))
|
assert.True(t, b.Update(l, 19))
|
||||||
assert.Equal(t, int64(8), b.lostCounter.Count())
|
assert.Equal(t, int64(8), b.lostCounter.Count())
|
||||||
|
|
||||||
// Jump ahead by a window size
|
// Jump ahead by a window size
|
||||||
assert.True(t, b.Update(29))
|
assert.True(t, b.Update(l, 29))
|
||||||
assert.Equal(t, int64(8), b.lostCounter.Count())
|
assert.Equal(t, int64(8), b.lostCounter.Count())
|
||||||
// Now lets walk ahead normally through the window, the missed packets should fill in
|
// Now lets walk ahead normally through the window, the missed packets should fill in
|
||||||
assert.True(t, b.Update(30))
|
assert.True(t, b.Update(l, 30))
|
||||||
assert.True(t, b.Update(31))
|
assert.True(t, b.Update(l, 31))
|
||||||
assert.True(t, b.Update(32))
|
assert.True(t, b.Update(l, 32))
|
||||||
assert.True(t, b.Update(33))
|
assert.True(t, b.Update(l, 33))
|
||||||
assert.True(t, b.Update(34))
|
assert.True(t, b.Update(l, 34))
|
||||||
assert.True(t, b.Update(35))
|
assert.True(t, b.Update(l, 35))
|
||||||
assert.True(t, b.Update(36))
|
assert.True(t, b.Update(l, 36))
|
||||||
assert.True(t, b.Update(37))
|
assert.True(t, b.Update(l, 37))
|
||||||
assert.True(t, b.Update(38))
|
assert.True(t, b.Update(l, 38))
|
||||||
// 39 packets tracked, 22 seen, 17 lost
|
// 39 packets tracked, 22 seen, 17 lost
|
||||||
assert.Equal(t, int64(17), b.lostCounter.Count())
|
assert.Equal(t, int64(17), b.lostCounter.Count())
|
||||||
|
|
||||||
// Jump ahead by 2 windows, should have recording 1 full window missing
|
// Jump ahead by 2 windows, should have recording 1 full window missing
|
||||||
assert.True(t, b.Update(58))
|
assert.True(t, b.Update(l, 58))
|
||||||
assert.Equal(t, int64(27), b.lostCounter.Count())
|
assert.Equal(t, int64(27), b.lostCounter.Count())
|
||||||
// Now lets walk ahead normally through the window, the missed packets should fill in from this window
|
// Now lets walk ahead normally through the window, the missed packets should fill in from this window
|
||||||
assert.True(t, b.Update(59))
|
assert.True(t, b.Update(l, 59))
|
||||||
assert.True(t, b.Update(60))
|
assert.True(t, b.Update(l, 60))
|
||||||
assert.True(t, b.Update(61))
|
assert.True(t, b.Update(l, 61))
|
||||||
assert.True(t, b.Update(62))
|
assert.True(t, b.Update(l, 62))
|
||||||
assert.True(t, b.Update(63))
|
assert.True(t, b.Update(l, 63))
|
||||||
assert.True(t, b.Update(64))
|
assert.True(t, b.Update(l, 64))
|
||||||
assert.True(t, b.Update(65))
|
assert.True(t, b.Update(l, 65))
|
||||||
assert.True(t, b.Update(66))
|
assert.True(t, b.Update(l, 66))
|
||||||
assert.True(t, b.Update(67))
|
assert.True(t, b.Update(l, 67))
|
||||||
// 68 packets tracked, 32 seen, 36 missed
|
// 68 packets tracked, 32 seen, 36 missed
|
||||||
assert.Equal(t, int64(36), b.lostCounter.Count())
|
assert.Equal(t, int64(36), b.lostCounter.Count())
|
||||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||||
|
|
3
cert.go
3
cert.go
|
@ -7,6 +7,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -119,7 +120,7 @@ func NewCertStateFromConfig(c *Config) (*CertState, error) {
|
||||||
return NewCertState(nebulaCert, rawKey)
|
return NewCertState(nebulaCert, rawKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadCAFromConfig(c *Config) (*cert.NebulaCAPool, error) {
|
func loadCAFromConfig(l *logrus.Logger, c *Config) (*cert.NebulaCAPool, error) {
|
||||||
var rawCA []byte
|
var rawCA []byte
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
|
|
@ -46,15 +46,16 @@ func main() {
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
config := nebula.NewConfig()
|
l := logrus.New()
|
||||||
|
l.Out = os.Stdout
|
||||||
|
|
||||||
|
config := nebula.NewConfig(l)
|
||||||
err := config.Load(*configPath)
|
err := config.Load(*configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("failed to load config: %s", err)
|
fmt.Printf("failed to load config: %s", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := logrus.New()
|
|
||||||
l.Out = os.Stdout
|
|
||||||
c, err := nebula.Main(config, *configTest, Build, l, nil)
|
c, err := nebula.Main(config, *configTest, Build, l, nil)
|
||||||
|
|
||||||
switch v := err.(type) {
|
switch v := err.(type) {
|
||||||
|
|
|
@ -24,14 +24,15 @@ func (p *program) Start(s service.Service) error {
|
||||||
// Start should not block.
|
// Start should not block.
|
||||||
logger.Info("Nebula service starting.")
|
logger.Info("Nebula service starting.")
|
||||||
|
|
||||||
config := nebula.NewConfig()
|
l := logrus.New()
|
||||||
|
l.Out = os.Stdout
|
||||||
|
|
||||||
|
config := nebula.NewConfig(l)
|
||||||
err := config.Load(*p.configPath)
|
err := config.Load(*p.configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to load config: %s", err)
|
return fmt.Errorf("failed to load config: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := logrus.New()
|
|
||||||
l.Out = os.Stdout
|
|
||||||
p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
|
p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -40,15 +40,16 @@ func main() {
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
config := nebula.NewConfig()
|
l := logrus.New()
|
||||||
|
l.Out = os.Stdout
|
||||||
|
|
||||||
|
config := nebula.NewConfig(l)
|
||||||
err := config.Load(*configPath)
|
err := config.Load(*configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("failed to load config: %s", err)
|
fmt.Printf("failed to load config: %s", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := logrus.New()
|
|
||||||
l.Out = os.Stdout
|
|
||||||
c, err := nebula.Main(config, *configTest, Build, l, nil)
|
c, err := nebula.Main(config, *configTest, Build, l, nil)
|
||||||
|
|
||||||
switch v := err.(type) {
|
switch v := err.(type) {
|
||||||
|
|
18
config.go
18
config.go
|
@ -26,11 +26,13 @@ type Config struct {
|
||||||
Settings map[interface{}]interface{}
|
Settings map[interface{}]interface{}
|
||||||
oldSettings map[interface{}]interface{}
|
oldSettings map[interface{}]interface{}
|
||||||
callbacks []func(*Config)
|
callbacks []func(*Config)
|
||||||
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConfig() *Config {
|
func NewConfig(l *logrus.Logger) *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
Settings: make(map[interface{}]interface{}),
|
Settings: make(map[interface{}]interface{}),
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,12 +101,12 @@ func (c *Config) HasChanged(k string) bool {
|
||||||
|
|
||||||
newVals, err := yaml.Marshal(nv)
|
newVals, err := yaml.Marshal(nv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
|
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
|
||||||
}
|
}
|
||||||
|
|
||||||
oldVals, err := yaml.Marshal(ov)
|
oldVals, err := yaml.Marshal(ov)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
|
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
|
||||||
}
|
}
|
||||||
|
|
||||||
return string(newVals) != string(oldVals)
|
return string(newVals) != string(oldVals)
|
||||||
|
@ -118,7 +120,7 @@ func (c *Config) CatchHUP() {
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for range ch {
|
for range ch {
|
||||||
l.Info("Caught HUP, reloading config")
|
c.l.Info("Caught HUP, reloading config")
|
||||||
c.ReloadConfig()
|
c.ReloadConfig()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -132,7 +134,7 @@ func (c *Config) ReloadConfig() {
|
||||||
|
|
||||||
err := c.Load(c.path)
|
err := c.Load(c.path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
|
c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -500,7 +502,7 @@ func configLogger(c *Config) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
|
return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
|
||||||
}
|
}
|
||||||
l.SetLevel(logLevel)
|
c.l.SetLevel(logLevel)
|
||||||
|
|
||||||
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
|
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
|
||||||
timestampFormat := c.GetString("logging.timestamp_format", "")
|
timestampFormat := c.GetString("logging.timestamp_format", "")
|
||||||
|
@ -512,13 +514,13 @@ func configLogger(c *Config) error {
|
||||||
logFormat := strings.ToLower(c.GetString("logging.format", "text"))
|
logFormat := strings.ToLower(c.GetString("logging.format", "text"))
|
||||||
switch logFormat {
|
switch logFormat {
|
||||||
case "text":
|
case "text":
|
||||||
l.Formatter = &logrus.TextFormatter{
|
c.l.Formatter = &logrus.TextFormatter{
|
||||||
TimestampFormat: timestampFormat,
|
TimestampFormat: timestampFormat,
|
||||||
FullTimestamp: fullTimestamp,
|
FullTimestamp: fullTimestamp,
|
||||||
DisableTimestamp: disableTimestamp,
|
DisableTimestamp: disableTimestamp,
|
||||||
}
|
}
|
||||||
case "json":
|
case "json":
|
||||||
l.Formatter = &logrus.JSONFormatter{
|
c.l.Formatter = &logrus.JSONFormatter{
|
||||||
TimestampFormat: timestampFormat,
|
TimestampFormat: timestampFormat,
|
||||||
DisableTimestamp: disableTimestamp,
|
DisableTimestamp: disableTimestamp,
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,14 +11,15 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConfig_Load(t *testing.T) {
|
func TestConfig_Load(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
dir, err := ioutil.TempDir("", "config-test")
|
dir, err := ioutil.TempDir("", "config-test")
|
||||||
// invalid yaml
|
// invalid yaml
|
||||||
c := NewConfig()
|
c := NewConfig(l)
|
||||||
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
|
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
|
||||||
assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
|
assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
|
||||||
|
|
||||||
// simple multi config merge
|
// simple multi config merge
|
||||||
c = NewConfig()
|
c = NewConfig(l)
|
||||||
os.RemoveAll(dir)
|
os.RemoveAll(dir)
|
||||||
os.Mkdir(dir, 0755)
|
os.Mkdir(dir, 0755)
|
||||||
|
|
||||||
|
@ -40,8 +41,9 @@ func TestConfig_Load(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_Get(t *testing.T) {
|
func TestConfig_Get(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
// test simple type
|
// test simple type
|
||||||
c := NewConfig()
|
c := NewConfig(l)
|
||||||
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
|
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
|
||||||
assert.Equal(t, "hi", c.Get("firewall.outbound"))
|
assert.Equal(t, "hi", c.Get("firewall.outbound"))
|
||||||
|
|
||||||
|
@ -55,13 +57,15 @@ func TestConfig_Get(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_GetStringSlice(t *testing.T) {
|
func TestConfig_GetStringSlice(t *testing.T) {
|
||||||
c := NewConfig()
|
l := NewTestLogger()
|
||||||
|
c := NewConfig(l)
|
||||||
c.Settings["slice"] = []interface{}{"one", "two"}
|
c.Settings["slice"] = []interface{}{"one", "two"}
|
||||||
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
|
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_GetBool(t *testing.T) {
|
func TestConfig_GetBool(t *testing.T) {
|
||||||
c := NewConfig()
|
l := NewTestLogger()
|
||||||
|
c := NewConfig(l)
|
||||||
c.Settings["bool"] = true
|
c.Settings["bool"] = true
|
||||||
assert.Equal(t, true, c.GetBool("bool", false))
|
assert.Equal(t, true, c.GetBool("bool", false))
|
||||||
|
|
||||||
|
@ -88,7 +92,8 @@ func TestConfig_GetBool(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_GetAllowList(t *testing.T) {
|
func TestConfig_GetAllowList(t *testing.T) {
|
||||||
c := NewConfig()
|
l := NewTestLogger()
|
||||||
|
c := NewConfig(l)
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
"192.168.0.0": true,
|
"192.168.0.0": true,
|
||||||
}
|
}
|
||||||
|
@ -181,20 +186,21 @@ func TestConfig_GetAllowList(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_HasChanged(t *testing.T) {
|
func TestConfig_HasChanged(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
// No reload has occurred, return false
|
// No reload has occurred, return false
|
||||||
c := NewConfig()
|
c := NewConfig(l)
|
||||||
c.Settings["test"] = "hi"
|
c.Settings["test"] = "hi"
|
||||||
assert.False(t, c.HasChanged(""))
|
assert.False(t, c.HasChanged(""))
|
||||||
|
|
||||||
// Test key change
|
// Test key change
|
||||||
c = NewConfig()
|
c = NewConfig(l)
|
||||||
c.Settings["test"] = "hi"
|
c.Settings["test"] = "hi"
|
||||||
c.oldSettings = map[interface{}]interface{}{"test": "no"}
|
c.oldSettings = map[interface{}]interface{}{"test": "no"}
|
||||||
assert.True(t, c.HasChanged("test"))
|
assert.True(t, c.HasChanged("test"))
|
||||||
assert.True(t, c.HasChanged(""))
|
assert.True(t, c.HasChanged(""))
|
||||||
|
|
||||||
// No key change
|
// No key change
|
||||||
c = NewConfig()
|
c = NewConfig(l)
|
||||||
c.Settings["test"] = "hi"
|
c.Settings["test"] = "hi"
|
||||||
c.oldSettings = map[interface{}]interface{}{"test": "hi"}
|
c.oldSettings = map[interface{}]interface{}{"test": "hi"}
|
||||||
assert.False(t, c.HasChanged("test"))
|
assert.False(t, c.HasChanged("test"))
|
||||||
|
@ -202,12 +208,13 @@ func TestConfig_HasChanged(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_ReloadConfig(t *testing.T) {
|
func TestConfig_ReloadConfig(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
done := make(chan bool, 1)
|
done := make(chan bool, 1)
|
||||||
dir, err := ioutil.TempDir("", "config-test")
|
dir, err := ioutil.TempDir("", "config-test")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
|
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
|
||||||
|
|
||||||
c := NewConfig()
|
c := NewConfig(l)
|
||||||
assert.Nil(t, c.Load(dir))
|
assert.Nil(t, c.Load(dir))
|
||||||
|
|
||||||
assert.False(t, c.HasChanged("outer.inner"))
|
assert.False(t, c.HasChanged("outer.inner"))
|
||||||
|
|
|
@ -28,10 +28,11 @@ type connectionManager struct {
|
||||||
checkInterval int
|
checkInterval int
|
||||||
pendingDeletionInterval int
|
pendingDeletionInterval int
|
||||||
|
|
||||||
|
l *logrus.Logger
|
||||||
// I wanted to call one matLock
|
// I wanted to call one matLock
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConnectionManager(intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
|
func newConnectionManager(l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
|
||||||
nc := &connectionManager{
|
nc := &connectionManager{
|
||||||
hostMap: intf.hostMap,
|
hostMap: intf.hostMap,
|
||||||
in: make(map[uint32]struct{}),
|
in: make(map[uint32]struct{}),
|
||||||
|
@ -47,6 +48,7 @@ func newConnectionManager(intf *Interface, checkInterval, pendingDeletionInterva
|
||||||
pendingDeletionTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
|
pendingDeletionTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
|
||||||
checkInterval: checkInterval,
|
checkInterval: checkInterval,
|
||||||
pendingDeletionInterval: pendingDeletionInterval,
|
pendingDeletionInterval: pendingDeletionInterval,
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
nc.Start()
|
nc.Start()
|
||||||
return nc
|
return nc
|
||||||
|
@ -166,8 +168,8 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
|
||||||
|
|
||||||
// If we saw incoming packets from this ip, just return
|
// If we saw incoming packets from this ip, just return
|
||||||
if traf {
|
if traf {
|
||||||
if l.Level >= logrus.DebugLevel {
|
if n.l.Level >= logrus.DebugLevel {
|
||||||
l.WithField("vpnIp", IntIp(vpnIP)).
|
n.l.WithField("vpnIp", IntIp(vpnIP)).
|
||||||
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
|
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
|
||||||
Debug("Tunnel status")
|
Debug("Tunnel status")
|
||||||
}
|
}
|
||||||
|
@ -179,13 +181,13 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
|
||||||
// If we didn't we may need to probe or destroy the conn
|
// If we didn't we may need to probe or destroy the conn
|
||||||
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
|
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
|
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
|
||||||
n.ClearIP(vpnIP)
|
n.ClearIP(vpnIP)
|
||||||
n.ClearPendingDeletion(vpnIP)
|
n.ClearPendingDeletion(vpnIP)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.logger().
|
hostinfo.logger(n.l).
|
||||||
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
|
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
|
||||||
Debug("Tunnel status")
|
Debug("Tunnel status")
|
||||||
|
|
||||||
|
@ -194,7 +196,7 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
|
||||||
n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, p, nb, out)
|
n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, p, nb, out)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
hostinfo.logger().Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
|
hostinfo.logger(n.l).Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
|
||||||
}
|
}
|
||||||
n.AddPendingDeletion(vpnIP)
|
n.AddPendingDeletion(vpnIP)
|
||||||
}
|
}
|
||||||
|
@ -214,7 +216,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
|
||||||
// If we saw incoming packets from this ip, just return
|
// If we saw incoming packets from this ip, just return
|
||||||
traf := n.CheckIn(vpnIP)
|
traf := n.CheckIn(vpnIP)
|
||||||
if traf {
|
if traf {
|
||||||
l.WithField("vpnIp", IntIp(vpnIP)).
|
n.l.WithField("vpnIp", IntIp(vpnIP)).
|
||||||
WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
|
WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
|
||||||
Debug("Tunnel status")
|
Debug("Tunnel status")
|
||||||
n.ClearIP(vpnIP)
|
n.ClearIP(vpnIP)
|
||||||
|
@ -226,7 +228,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
n.ClearIP(vpnIP)
|
n.ClearIP(vpnIP)
|
||||||
n.ClearPendingDeletion(vpnIP)
|
n.ClearPendingDeletion(vpnIP)
|
||||||
l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
|
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -236,7 +238,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
|
||||||
if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
|
if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
|
||||||
cn = hostinfo.ConnectionState.peerCert.Details.Name
|
cn = hostinfo.ConnectionState.peerCert.Details.Name
|
||||||
}
|
}
|
||||||
hostinfo.logger().
|
hostinfo.logger(n.l).
|
||||||
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
|
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
|
||||||
WithField("certName", cn).
|
WithField("certName", cn).
|
||||||
Info("Tunnel status")
|
Info("Tunnel status")
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
var vpnIP uint32
|
var vpnIP uint32
|
||||||
|
|
||||||
func Test_NewConnectionManagerTest(t *testing.T) {
|
func Test_NewConnectionManagerTest(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
|
@ -20,7 +21,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||||
preferredRanges := []*net.IPNet{localrange}
|
preferredRanges := []*net.IPNet{localrange}
|
||||||
|
|
||||||
// Very incomplete mock objects
|
// Very incomplete mock objects
|
||||||
hostMap := NewHostMap("test", vpncidr, preferredRanges)
|
hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
rawCertificate: []byte{},
|
rawCertificate: []byte{},
|
||||||
privateKey: []byte{},
|
privateKey: []byte{},
|
||||||
|
@ -28,7 +29,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||||
rawCertificateNoKey: []byte{},
|
rawCertificateNoKey: []byte{},
|
||||||
}
|
}
|
||||||
|
|
||||||
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
|
lh := NewLightHouse(l, false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &Tun{},
|
inside: &Tun{},
|
||||||
|
@ -36,12 +37,13 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||||
certState: cs,
|
certState: cs,
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
|
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
nc := newConnectionManager(ifce, 5, 10)
|
nc := newConnectionManager(l, ifce, 5, 10)
|
||||||
p := []byte("")
|
p := []byte("")
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
|
@ -79,13 +81,14 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_NewConnectionManagerTest2(t *testing.T) {
|
func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
preferredRanges := []*net.IPNet{localrange}
|
preferredRanges := []*net.IPNet{localrange}
|
||||||
|
|
||||||
// Very incomplete mock objects
|
// Very incomplete mock objects
|
||||||
hostMap := NewHostMap("test", vpncidr, preferredRanges)
|
hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
rawCertificate: []byte{},
|
rawCertificate: []byte{},
|
||||||
privateKey: []byte{},
|
privateKey: []byte{},
|
||||||
|
@ -93,7 +96,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||||
rawCertificateNoKey: []byte{},
|
rawCertificateNoKey: []byte{},
|
||||||
}
|
}
|
||||||
|
|
||||||
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
|
lh := NewLightHouse(l, false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &Tun{},
|
inside: &Tun{},
|
||||||
|
@ -101,12 +104,13 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||||
certState: cs,
|
certState: cs,
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
|
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
nc := newConnectionManager(ifce, 5, 10)
|
nc := newConnectionManager(l, ifce, 5, 10)
|
||||||
p := []byte("")
|
p := []byte("")
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/flynn/noise"
|
"github.com/flynn/noise"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -26,7 +27,7 @@ type ConnectionState struct {
|
||||||
ready bool
|
ready bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
|
func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
|
||||||
cs := noise.NewCipherSuite(noise.DH25519, noise.CipherAESGCM, noise.HashSHA256)
|
cs := noise.NewCipherSuite(noise.DH25519, noise.CipherAESGCM, noise.HashSHA256)
|
||||||
if f.cipher == "chachapoly" {
|
if f.cipher == "chachapoly" {
|
||||||
cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||||
|
@ -37,7 +38,7 @@ func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePa
|
||||||
|
|
||||||
b := NewBits(ReplayWindow)
|
b := NewBits(ReplayWindow)
|
||||||
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
|
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
|
||||||
b.Update(0)
|
b.Update(l, 0)
|
||||||
|
|
||||||
hs, err := noise.NewHandshakeState(noise.Config{
|
hs, err := noise.NewHandshakeState(noise.Config{
|
||||||
CipherSuite: cs,
|
CipherSuite: cs,
|
||||||
|
|
|
@ -13,9 +13,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestControl_GetHostInfoByVpnIP(t *testing.T) {
|
func TestControl_GetHostInfoByVpnIP(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
|
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
|
||||||
// To properly ensure we are not exposing core memory to the caller
|
// To properly ensure we are not exposing core memory to the caller
|
||||||
hm := NewHostMap("test", &net.IPNet{}, make([]*net.IPNet, 0))
|
hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
|
||||||
remote1 := NewUDPAddr(int2ip(100), 4444)
|
remote1 := NewUDPAddr(int2ip(100), 4444)
|
||||||
remote2 := NewUDPAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
|
remote2 := NewUDPAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
|
||||||
ipNet := net.IPNet{
|
ipNet := net.IPNet{
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// This whole thing should be rewritten to use context
|
// This whole thing should be rewritten to use context
|
||||||
|
@ -63,7 +64,7 @@ func (d *dnsRecords) Add(host, data string) {
|
||||||
d.Unlock()
|
d.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
|
||||||
for _, q := range m.Question {
|
for _, q := range m.Question {
|
||||||
switch q.Qtype {
|
switch q.Qtype {
|
||||||
case dns.TypeA:
|
case dns.TypeA:
|
||||||
|
@ -95,34 +96,38 @@ func parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetReply(r)
|
m.SetReply(r)
|
||||||
m.Compress = false
|
m.Compress = false
|
||||||
|
|
||||||
switch r.Opcode {
|
switch r.Opcode {
|
||||||
case dns.OpcodeQuery:
|
case dns.OpcodeQuery:
|
||||||
parseQuery(m, w)
|
parseQuery(l, m, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteMsg(m)
|
w.WriteMsg(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
func dnsMain(hostMap *HostMap, c *Config) {
|
func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) {
|
||||||
dnsR = newDnsRecords(hostMap)
|
dnsR = newDnsRecords(hostMap)
|
||||||
|
|
||||||
// attach request handler func
|
// attach request handler func
|
||||||
dns.HandleFunc(".", handleDnsRequest)
|
dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
handleDnsRequest(l, w, r)
|
||||||
|
})
|
||||||
|
|
||||||
c.RegisterReloadCallback(reloadDns)
|
c.RegisterReloadCallback(func(c *Config) {
|
||||||
startDns(c)
|
reloadDns(l, c)
|
||||||
|
})
|
||||||
|
startDns(l, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDnsServerAddr(c *Config) string {
|
func getDnsServerAddr(c *Config) string {
|
||||||
return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))
|
return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))
|
||||||
}
|
}
|
||||||
|
|
||||||
func startDns(c *Config) {
|
func startDns(l *logrus.Logger, c *Config) {
|
||||||
dnsAddr = getDnsServerAddr(c)
|
dnsAddr = getDnsServerAddr(c)
|
||||||
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
|
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
|
||||||
l.Debugf("Starting DNS responder at %s\n", dnsAddr)
|
l.Debugf("Starting DNS responder at %s\n", dnsAddr)
|
||||||
|
@ -133,7 +138,7 @@ func startDns(c *Config) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func reloadDns(c *Config) {
|
func reloadDns(l *logrus.Logger, c *Config) {
|
||||||
if dnsAddr == getDnsServerAddr(c) {
|
if dnsAddr == getDnsServerAddr(c) {
|
||||||
l.Debug("No DNS server config change detected")
|
l.Debug("No DNS server config change detected")
|
||||||
return
|
return
|
||||||
|
@ -141,5 +146,5 @@ func reloadDns(c *Config) {
|
||||||
|
|
||||||
l.Debug("Restarting DNS server")
|
l.Debug("Restarting DNS server")
|
||||||
dnsServer.Shutdown()
|
dnsServer.Shutdown()
|
||||||
go startDns(c)
|
go startDns(l, c)
|
||||||
}
|
}
|
||||||
|
|
31
firewall.go
31
firewall.go
|
@ -70,6 +70,7 @@ type Firewall struct {
|
||||||
|
|
||||||
trackTCPRTT bool
|
trackTCPRTT bool
|
||||||
metricTCPRTT metrics.Histogram
|
metricTCPRTT metrics.Histogram
|
||||||
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type FirewallConntrack struct {
|
type FirewallConntrack struct {
|
||||||
|
@ -156,7 +157,7 @@ func (fp FirewallPacket) MarshalJSON() ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||||
func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
|
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
|
||||||
//TODO: error on 0 duration
|
//TODO: error on 0 duration
|
||||||
var min, max time.Duration
|
var min, max time.Duration
|
||||||
|
|
||||||
|
@ -195,11 +196,13 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
|
||||||
DefaultTimeout: defaultTimeout,
|
DefaultTimeout: defaultTimeout,
|
||||||
localIps: localIps,
|
localIps: localIps,
|
||||||
metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
|
func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
|
||||||
fw := NewFirewall(
|
fw := NewFirewall(
|
||||||
|
l,
|
||||||
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
|
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
|
||||||
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
|
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
|
||||||
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
|
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
|
||||||
|
@ -207,12 +210,12 @@ func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, er
|
||||||
//TODO: max_connections
|
//TODO: max_connections
|
||||||
)
|
)
|
||||||
|
|
||||||
err := AddFirewallRulesFromConfig(false, c, fw)
|
err := AddFirewallRulesFromConfig(l, false, c, fw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = AddFirewallRulesFromConfig(true, c, fw)
|
err = AddFirewallRulesFromConfig(l, true, c, fw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -240,7 +243,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
||||||
if !incoming {
|
if !incoming {
|
||||||
direction = "outgoing"
|
direction = "outgoing"
|
||||||
}
|
}
|
||||||
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
|
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
|
||||||
Info("Firewall rule added")
|
Info("Firewall rule added")
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -276,7 +279,7 @@ func (f *Firewall) GetRuleHash() string {
|
||||||
return hex.EncodeToString(sum[:])
|
return hex.EncodeToString(sum[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterface) error {
|
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, fw FirewallInterface) error {
|
||||||
var table string
|
var table string
|
||||||
if inbound {
|
if inbound {
|
||||||
table = "firewall.inbound"
|
table = "firewall.inbound"
|
||||||
|
@ -296,7 +299,7 @@ func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterfa
|
||||||
|
|
||||||
for i, t := range rs {
|
for i, t := range rs {
|
||||||
var groups []string
|
var groups []string
|
||||||
r, err := convertRule(t, table, i)
|
r, err := convertRule(l, t, table, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s rule #%v; %s", table, i, err)
|
return fmt.Errorf("%s rule #%v; %s", table, i, err)
|
||||||
}
|
}
|
||||||
|
@ -459,8 +462,8 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
|
||||||
|
|
||||||
// We now know which firewall table to check against
|
// We now know which firewall table to check against
|
||||||
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
|
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
|
||||||
if l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
h.logger().
|
h.logger(f.l).
|
||||||
WithField("fwPacket", fp).
|
WithField("fwPacket", fp).
|
||||||
WithField("incoming", c.incoming).
|
WithField("incoming", c.incoming).
|
||||||
WithField("rulesVersion", f.rulesVersion).
|
WithField("rulesVersion", f.rulesVersion).
|
||||||
|
@ -472,8 +475,8 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
h.logger().
|
h.logger(f.l).
|
||||||
WithField("fwPacket", fp).
|
WithField("fwPacket", fp).
|
||||||
WithField("incoming", c.incoming).
|
WithField("incoming", c.incoming).
|
||||||
WithField("rulesVersion", f.rulesVersion).
|
WithField("rulesVersion", f.rulesVersion).
|
||||||
|
@ -795,7 +798,7 @@ type rule struct {
|
||||||
CASha string
|
CASha string
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertRule(p interface{}, table string, i int) (rule, error) {
|
func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
|
||||||
r := rule{}
|
r := rule{}
|
||||||
|
|
||||||
m, ok := p.(map[interface{}]interface{})
|
m, ok := p.(map[interface{}]interface{})
|
||||||
|
@ -968,14 +971,14 @@ func (c *ConntrackCacheTicker) tick(d time.Duration) {
|
||||||
|
|
||||||
// Get checks if the cache ticker has moved to the next version before returning
|
// Get checks if the cache ticker has moved to the next version before returning
|
||||||
// the map. If it has moved, we reset the map.
|
// the map. If it has moved, we reset the map.
|
||||||
func (c *ConntrackCacheTicker) Get() ConntrackCache {
|
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
|
if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
|
||||||
c.cacheV = tick
|
c.cacheV = tick
|
||||||
if ll := len(c.cache); ll > 0 {
|
if ll := len(c.cache); ll > 0 {
|
||||||
if l.GetLevel() == logrus.DebugLevel {
|
if l.Level == logrus.DebugLevel {
|
||||||
l.WithField("len", ll).Debug("resetting conntrack cache")
|
l.WithField("len", ll).Debug("resetting conntrack cache")
|
||||||
}
|
}
|
||||||
c.cache = make(ConntrackCache, ll)
|
c.cache = make(ConntrackCache, ll)
|
||||||
|
|
149
firewall_test.go
149
firewall_test.go
|
@ -15,8 +15,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewFirewall(t *testing.T) {
|
func TestNewFirewall(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
c := &cert.NebulaCertificate{}
|
c := &cert.NebulaCertificate{}
|
||||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
conntrack := fw.Conntrack
|
conntrack := fw.Conntrack
|
||||||
assert.NotNil(t, conntrack)
|
assert.NotNil(t, conntrack)
|
||||||
assert.NotNil(t, conntrack.Conns)
|
assert.NotNil(t, conntrack.Conns)
|
||||||
|
@ -31,35 +32,34 @@ func TestNewFirewall(t *testing.T) {
|
||||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||||
|
|
||||||
fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
|
fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
|
||||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||||
|
|
||||||
fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
|
fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
|
||||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||||
|
|
||||||
fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
|
fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
|
||||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||||
|
|
||||||
fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
|
fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
|
||||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||||
|
|
||||||
fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
|
fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
|
||||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_AddRule(t *testing.T) {
|
func TestFirewall_AddRule(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
out := l.Out
|
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
defer l.SetOutput(out)
|
|
||||||
|
|
||||||
c := &cert.NebulaCertificate{}
|
c := &cert.NebulaCertificate{}
|
||||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.NotNil(t, fw.InRules)
|
assert.NotNil(t, fw.InRules)
|
||||||
assert.NotNil(t, fw.OutRules)
|
assert.NotNil(t, fw.OutRules)
|
||||||
|
|
||||||
|
@ -74,7 +74,7 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||||
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
|
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
|
||||||
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
|
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
|
||||||
|
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
|
||||||
assert.False(t, fw.InRules.UDP[1].Any.Any)
|
assert.False(t, fw.InRules.UDP[1].Any.Any)
|
||||||
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
|
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
|
||||||
|
@ -83,7 +83,7 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||||
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
|
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
|
||||||
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
|
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
|
||||||
|
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
|
||||||
assert.False(t, fw.InRules.ICMP[1].Any.Any)
|
assert.False(t, fw.InRules.ICMP[1].Any.Any)
|
||||||
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
||||||
|
@ -92,23 +92,23 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||||
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
|
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
|
||||||
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
|
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
|
||||||
|
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
|
assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
|
||||||
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
|
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
|
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
|
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
|
||||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP)))
|
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP)))
|
||||||
|
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
|
||||||
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
|
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
|
||||||
|
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
|
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
|
||||||
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
|
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
|
||||||
|
|
||||||
// Set any and clear fields
|
// Set any and clear fields
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
|
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
|
||||||
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
|
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
|
||||||
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
|
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
|
||||||
|
@ -125,26 +125,25 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||||
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
|
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
|
||||||
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
|
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
|
||||||
|
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
|
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
||||||
|
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
|
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
|
||||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
|
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
||||||
|
|
||||||
// Test error conditions
|
// Test error conditions
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
|
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
|
||||||
assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", ""))
|
assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", ""))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop(t *testing.T) {
|
func TestFirewall_Drop(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
out := l.Out
|
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
defer l.SetOutput(out)
|
|
||||||
|
|
||||||
p := FirewallPacket{
|
p := FirewallPacket{
|
||||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||||
|
@ -177,7 +176,7 @@ func TestFirewall_Drop(t *testing.T) {
|
||||||
}
|
}
|
||||||
h.CreateRemoteCIDR(&c)
|
h.CreateRemoteCIDR(&c)
|
||||||
|
|
||||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
|
@ -196,27 +195,27 @@ func TestFirewall_Drop(t *testing.T) {
|
||||||
p.RemoteIP = oldRemote
|
p.RemoteIP = oldRemote
|
||||||
|
|
||||||
// ensure signer doesn't get in the way of group checks
|
// ensure signer doesn't get in the way of group checks
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
|
||||||
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caSha doesn't drop on match
|
// test caSha doesn't drop on match
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
|
||||||
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
|
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
|
||||||
|
|
||||||
// ensure ca name doesn't get in the way of group checks
|
// ensure ca name doesn't get in the way of group checks
|
||||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
|
||||||
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caName doesn't drop on match
|
// test caName doesn't drop on match
|
||||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
|
||||||
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
|
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
|
||||||
|
@ -317,10 +316,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop2(t *testing.T) {
|
func TestFirewall_Drop2(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
out := l.Out
|
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
defer l.SetOutput(out)
|
|
||||||
|
|
||||||
p := FirewallPacket{
|
p := FirewallPacket{
|
||||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||||
|
@ -365,7 +363,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||||
}
|
}
|
||||||
h1.CreateRemoteCIDR(&c1)
|
h1.CreateRemoteCIDR(&c1)
|
||||||
|
|
||||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
|
@ -377,10 +375,9 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop3(t *testing.T) {
|
func TestFirewall_Drop3(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
out := l.Out
|
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
defer l.SetOutput(out)
|
|
||||||
|
|
||||||
p := FirewallPacket{
|
p := FirewallPacket{
|
||||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||||
|
@ -448,7 +445,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||||
}
|
}
|
||||||
h3.CreateRemoteCIDR(&c3)
|
h3.CreateRemoteCIDR(&c3)
|
||||||
|
|
||||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
@ -464,10 +461,9 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
out := l.Out
|
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
defer l.SetOutput(out)
|
|
||||||
|
|
||||||
p := FirewallPacket{
|
p := FirewallPacket{
|
||||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||||
|
@ -500,7 +496,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||||
}
|
}
|
||||||
h.CreateRemoteCIDR(&c)
|
h.CreateRemoteCIDR(&c)
|
||||||
|
|
||||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
|
@ -513,7 +509,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||||
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
|
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
|
||||||
|
|
||||||
oldFw := fw
|
oldFw := fw
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
|
||||||
fw.Conntrack = oldFw.Conntrack
|
fw.Conntrack = oldFw.Conntrack
|
||||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||||
|
@ -522,7 +518,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||||
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
|
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
|
||||||
|
|
||||||
oldFw = fw
|
oldFw = fw
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
|
||||||
fw.Conntrack = oldFw.Conntrack
|
fw.Conntrack = oldFw.Conntrack
|
||||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||||
|
@ -647,124 +643,126 @@ func Test_parsePort(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewFirewallFromConfig(t *testing.T) {
|
func TestNewFirewallFromConfig(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
// Test a bad rule definition
|
// Test a bad rule definition
|
||||||
c := &cert.NebulaCertificate{}
|
c := &cert.NebulaCertificate{}
|
||||||
conf := NewConfig()
|
conf := NewConfig(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
|
||||||
_, err := NewFirewallFromConfig(c, conf)
|
_, err := NewFirewallFromConfig(l, c, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
||||||
|
|
||||||
// Test both port and code
|
// Test both port and code
|
||||||
conf = NewConfig()
|
conf = NewConfig(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
|
||||||
_, err = NewFirewallFromConfig(c, conf)
|
_, err = NewFirewallFromConfig(l, c, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
||||||
|
|
||||||
// Test missing host, group, cidr, ca_name and ca_sha
|
// Test missing host, group, cidr, ca_name and ca_sha
|
||||||
conf = NewConfig()
|
conf = NewConfig(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
|
||||||
_, err = NewFirewallFromConfig(c, conf)
|
_, err = NewFirewallFromConfig(l, c, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
|
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
|
||||||
|
|
||||||
// Test code/port error
|
// Test code/port error
|
||||||
conf = NewConfig()
|
conf = NewConfig(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
|
||||||
_, err = NewFirewallFromConfig(c, conf)
|
_, err = NewFirewallFromConfig(l, c, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
||||||
|
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
|
||||||
_, err = NewFirewallFromConfig(c, conf)
|
_, err = NewFirewallFromConfig(l, c, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
||||||
|
|
||||||
// Test proto error
|
// Test proto error
|
||||||
conf = NewConfig()
|
conf = NewConfig(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
|
||||||
_, err = NewFirewallFromConfig(c, conf)
|
_, err = NewFirewallFromConfig(l, c, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
||||||
|
|
||||||
// Test cidr parse error
|
// Test cidr parse error
|
||||||
conf = NewConfig()
|
conf = NewConfig(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
|
||||||
_, err = NewFirewallFromConfig(c, conf)
|
_, err = NewFirewallFromConfig(l, c, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
|
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
|
||||||
|
|
||||||
// Test both group and groups
|
// Test both group and groups
|
||||||
conf = NewConfig()
|
conf = NewConfig(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
||||||
_, err = NewFirewallFromConfig(c, conf)
|
_, err = NewFirewallFromConfig(l, c, conf)
|
||||||
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddFirewallRulesFromConfig(t *testing.T) {
|
func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
// Test adding tcp rule
|
// Test adding tcp rule
|
||||||
conf := NewConfig()
|
conf := NewConfig(l)
|
||||||
mf := &mockFirewall{}
|
mf := &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
|
||||||
assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding udp rule
|
// Test adding udp rule
|
||||||
conf = NewConfig()
|
conf = NewConfig(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
|
||||||
assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding icmp rule
|
// Test adding icmp rule
|
||||||
conf = NewConfig()
|
conf = NewConfig(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
|
||||||
assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding any rule
|
// Test adding any rule
|
||||||
conf = NewConfig()
|
conf = NewConfig(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
|
||||||
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with ca_sha
|
// Test adding rule with ca_sha
|
||||||
conf = NewConfig()
|
conf = NewConfig(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
||||||
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with ca_name
|
// Test adding rule with ca_name
|
||||||
conf = NewConfig()
|
conf = NewConfig(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
||||||
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
|
||||||
|
|
||||||
// Test single group
|
// Test single group
|
||||||
conf = NewConfig()
|
conf = NewConfig(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
|
||||||
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
|
||||||
|
|
||||||
// Test single groups
|
// Test single groups
|
||||||
conf = NewConfig()
|
conf = NewConfig(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
|
||||||
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
|
||||||
|
|
||||||
// Test multiple AND groups
|
// Test multiple AND groups
|
||||||
conf = NewConfig()
|
conf = NewConfig(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
||||||
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
|
||||||
|
|
||||||
// Test Add error
|
// Test Add error
|
||||||
conf = NewConfig()
|
conf = NewConfig(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
mf.nextCallReturn = errors.New("test error")
|
mf.nextCallReturn = errors.New("test error")
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
|
||||||
assert.EqualError(t, AddFirewallRulesFromConfig(true, conf, mf), "firewall.inbound rule #0; `test error`")
|
assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTCPRTTTracking(t *testing.T) {
|
func TestTCPRTTTracking(t *testing.T) {
|
||||||
|
@ -859,17 +857,16 @@ func TestTCPRTTTracking(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_convertRule(t *testing.T) {
|
func TestFirewall_convertRule(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
out := l.Out
|
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
defer l.SetOutput(out)
|
|
||||||
|
|
||||||
// Ensure group array of 1 is converted and a warning is printed
|
// Ensure group array of 1 is converted and a warning is printed
|
||||||
c := map[interface{}]interface{}{
|
c := map[interface{}]interface{}{
|
||||||
"group": []interface{}{"group1"},
|
"group": []interface{}{"group1"},
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := convertRule(c, "test", 1)
|
r, err := convertRule(l, c, "test", 1)
|
||||||
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
|
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "group1", r.Group)
|
assert.Equal(t, "group1", r.Group)
|
||||||
|
@ -880,7 +877,7 @@ func TestFirewall_convertRule(t *testing.T) {
|
||||||
"group": []interface{}{"group1", "group2"},
|
"group": []interface{}{"group1", "group2"},
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err = convertRule(c, "test", 1)
|
r, err = convertRule(l, c, "test", 1)
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
|
assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
|
||||||
|
|
||||||
|
@ -890,7 +887,7 @@ func TestFirewall_convertRule(t *testing.T) {
|
||||||
"group": "group1",
|
"group": "group1",
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err = convertRule(c, "test", 1)
|
r, err = convertRule(l, c, "test", 1)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "group1", r.Group)
|
assert.Equal(t, "group1", r.Group)
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ const (
|
||||||
|
|
||||||
func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) {
|
func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) {
|
||||||
if !f.lightHouse.remoteAllowList.Allow(addr.IP) {
|
if !f.lightHouse.remoteAllowList.Allow(addr.IP) {
|
||||||
l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
|
||||||
|
|
||||||
err := f.handshakeManager.AddIndexHostInfo(hostinfo)
|
err := f.handshakeManager.AddIndexHostInfo(hostinfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
|
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -48,7 +48,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
|
||||||
hsBytes, err = proto.Marshal(hs)
|
hsBytes, err = proto.Marshal(hs)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
|
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -58,14 +58,14 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
|
||||||
|
|
||||||
msg, _, _, err := ci.H.WriteMessage(header, hsBytes)
|
msg, _, _, err := ci.H.WriteMessage(header, hsBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
|
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// We are sending handshake packet 1, so we don't expect to receive
|
// We are sending handshake packet 1, so we don't expect to receive
|
||||||
// handshake packet 1 from the responder
|
// handshake packet 1 from the responder
|
||||||
ci.window.Update(1)
|
ci.window.Update(f.l, 1)
|
||||||
|
|
||||||
hostinfo.HandshakePacket[0] = msg
|
hostinfo.HandshakePacket[0] = msg
|
||||||
hostinfo.HandshakeReady = true
|
hostinfo.HandshakeReady = true
|
||||||
|
@ -74,13 +74,13 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
ci := f.newConnectionState(false, noise.HandshakeIX, []byte{}, 0)
|
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
|
||||||
// Mark packet 1 as seen so it doesn't show up as missed
|
// Mark packet 1 as seen so it doesn't show up as missed
|
||||||
ci.window.Update(1)
|
ci.window.Update(f.l, 1)
|
||||||
|
|
||||||
msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
|
msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -91,14 +91,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
|
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
|
||||||
*/
|
*/
|
||||||
if err != nil || hs.Details == nil {
|
if err != nil || hs.Details == nil {
|
||||||
l.WithError(err).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
|
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
||||||
Info("Invalid certificate from host")
|
Info("Invalid certificate from host")
|
||||||
return
|
return
|
||||||
|
@ -108,16 +108,16 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
fingerprint, _ := remoteCert.Sha256Sum()
|
fingerprint, _ := remoteCert.Sha256Sum()
|
||||||
|
|
||||||
if vpnIP == ip2int(f.certState.certificate.Details.Ips[0].IP) {
|
if vpnIP == ip2int(f.certState.certificate.Details.Ips[0].IP) {
|
||||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
myIndex, err := generateIndex()
|
myIndex, err := generateIndex(f.l)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
|
||||||
|
@ -133,7 +133,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
HandshakePacket: make(map[uint8][]byte, 0),
|
HandshakePacket: make(map[uint8][]byte, 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
|
@ -145,7 +145,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
|
|
||||||
hsBytes, err := proto.Marshal(hs)
|
hsBytes, err := proto.Marshal(hs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
||||||
|
@ -155,13 +155,13 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2)
|
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2)
|
||||||
msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes)
|
msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
||||||
return
|
return
|
||||||
} else if dKey == nil || eKey == nil {
|
} else if dKey == nil || eKey == nil {
|
||||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
|
||||||
|
@ -178,7 +178,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
|
|
||||||
// We are sending handshake packet 2, so we don't expect to receive
|
// We are sending handshake packet 2, so we don't expect to receive
|
||||||
// handshake packet 2 from the initiator.
|
// handshake packet 2 from the initiator.
|
||||||
ci.window.Update(2)
|
ci.window.Update(f.l, 2)
|
||||||
|
|
||||||
ci.peerCert = remoteCert
|
ci.peerCert = remoteCert
|
||||||
ci.dKey = NewNebulaCipherState(dKey)
|
ci.dKey = NewNebulaCipherState(dKey)
|
||||||
|
@ -203,11 +203,11 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
|
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
|
||||||
err := f.outside.WriteTo(msg, addr)
|
err := f.outside.WriteTo(msg, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||||
WithError(err).Error("Failed to send handshake message")
|
WithError(err).Error("Failed to send handshake message")
|
||||||
} else {
|
} else {
|
||||||
l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||||
Info("Handshake message sent")
|
Info("Handshake message sent")
|
||||||
}
|
}
|
||||||
|
@ -215,7 +215,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
case ErrExistingHostInfo:
|
case ErrExistingHostInfo:
|
||||||
// This means there was an existing tunnel and we didn't win
|
// This means there was an existing tunnel and we didn't win
|
||||||
// handshake avoidance
|
// handshake avoidance
|
||||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
|
@ -227,7 +227,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
return
|
return
|
||||||
case ErrLocalIndexCollision:
|
case ErrLocalIndexCollision:
|
||||||
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
||||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
|
@ -238,7 +238,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
default:
|
default:
|
||||||
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
|
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
|
||||||
// And we forget to update it here
|
// And we forget to update it here
|
||||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
|
@ -252,14 +252,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
|
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
|
||||||
err = f.outside.WriteTo(msg, addr)
|
err = f.outside.WriteTo(msg, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
WithError(err).Error("Failed to send handshake")
|
WithError(err).Error("Failed to send handshake")
|
||||||
} else {
|
} else {
|
||||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
|
@ -267,7 +267,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
Info("Handshake message sent")
|
Info("Handshake message sent")
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.handshakeComplete()
|
hostinfo.handshakeComplete(f.l)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -280,7 +280,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
defer hostinfo.Unlock()
|
defer hostinfo.Unlock()
|
||||||
|
|
||||||
if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) {
|
if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) {
|
||||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
||||||
Info("Already seen this handshake packet")
|
Info("Already seen this handshake packet")
|
||||||
return false
|
return false
|
||||||
|
@ -288,14 +288,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
|
|
||||||
ci := hostinfo.ConnectionState
|
ci := hostinfo.ConnectionState
|
||||||
// Mark packet 2 as seen so it doesn't show up as missed
|
// Mark packet 2 as seen so it doesn't show up as missed
|
||||||
ci.window.Update(2)
|
ci.window.Update(f.l, 2)
|
||||||
|
|
||||||
hostinfo.HandshakePacket[2] = make([]byte, len(packet[HeaderLen:]))
|
hostinfo.HandshakePacket[2] = make([]byte, len(packet[HeaderLen:]))
|
||||||
copy(hostinfo.HandshakePacket[2], packet[HeaderLen:])
|
copy(hostinfo.HandshakePacket[2], packet[HeaderLen:])
|
||||||
|
|
||||||
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
|
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
||||||
Error("Failed to call noise.ReadMessage")
|
Error("Failed to call noise.ReadMessage")
|
||||||
|
|
||||||
|
@ -304,7 +304,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
// near future
|
// near future
|
||||||
return false
|
return false
|
||||||
} else if dKey == nil || eKey == nil {
|
} else if dKey == nil || eKey == nil {
|
||||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
Error("Noise did not arrive at a key")
|
Error("Noise did not arrive at a key")
|
||||||
return true
|
return true
|
||||||
|
@ -313,14 +313,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
hs := &NebulaHandshake{}
|
hs := &NebulaHandshake{}
|
||||||
err = proto.Unmarshal(msg, hs)
|
err = proto.Unmarshal(msg, hs)
|
||||||
if err != nil || hs.Details == nil {
|
if err != nil || hs.Details == nil {
|
||||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
|
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||||
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
Error("Invalid certificate from host")
|
Error("Invalid certificate from host")
|
||||||
return true
|
return true
|
||||||
|
@ -330,7 +330,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
fingerprint, _ := remoteCert.Sha256Sum()
|
fingerprint, _ := remoteCert.Sha256Sum()
|
||||||
|
|
||||||
duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
|
duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
|
||||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
|
@ -362,7 +362,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
hostinfo.CreateRemoteCIDR(remoteCert)
|
hostinfo.CreateRemoteCIDR(remoteCert)
|
||||||
|
|
||||||
f.handshakeManager.Complete(hostinfo, f)
|
f.handshakeManager.Complete(hostinfo, f)
|
||||||
hostinfo.handshakeComplete()
|
hostinfo.handshakeComplete(f.l)
|
||||||
f.metricHandshakes.Update(duration)
|
f.metricHandshakes.Update(duration)
|
||||||
|
|
||||||
return false
|
return false
|
||||||
|
|
|
@ -53,11 +53,12 @@ type HandshakeManager struct {
|
||||||
InboundHandshakeTimer *SystemTimerWheel
|
InboundHandshakeTimer *SystemTimerWheel
|
||||||
|
|
||||||
messageMetrics *MessageMetrics
|
messageMetrics *MessageMetrics
|
||||||
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
|
func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
|
||||||
return &HandshakeManager{
|
return &HandshakeManager{
|
||||||
pendingHostMap: NewHostMap("pending", tunCidr, preferredRanges),
|
pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges),
|
||||||
mainHostMap: mainHostMap,
|
mainHostMap: mainHostMap,
|
||||||
lightHouse: lightHouse,
|
lightHouse: lightHouse,
|
||||||
outside: outside,
|
outside: outside,
|
||||||
|
@ -70,6 +71,7 @@ func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainH
|
||||||
InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
|
InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
|
||||||
|
|
||||||
messageMetrics: config.messageMetrics,
|
messageMetrics: config.messageMetrics,
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,7 +80,7 @@ func (c *HandshakeManager) Run(f EncWriter) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case vpnIP := <-c.trigger:
|
case vpnIP := <-c.trigger:
|
||||||
l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
|
c.l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
|
||||||
c.handleOutbound(vpnIP, f, true)
|
c.handleOutbound(vpnIP, f, true)
|
||||||
case now := <-clockSource:
|
case now := <-clockSource:
|
||||||
c.NextOutboundHandshakeTimerTick(now, f)
|
c.NextOutboundHandshakeTimerTick(now, f)
|
||||||
|
@ -149,7 +151,7 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
|
||||||
c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
|
c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
|
||||||
err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
|
err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger().WithField("udpAddr", hostinfo.remote).
|
hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote).
|
||||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||||
WithField("remoteIndex", hostinfo.remoteIndexId).
|
WithField("remoteIndex", hostinfo.remoteIndexId).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
|
@ -157,7 +159,7 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
|
||||||
} else {
|
} else {
|
||||||
//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
|
//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
|
||||||
// keep the real packet struct around for logging purposes
|
// keep the real packet struct around for logging purposes
|
||||||
hostinfo.logger().WithField("udpAddr", hostinfo.remote).
|
hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote).
|
||||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||||
WithField("remoteIndex", hostinfo.remoteIndexId).
|
WithField("remoteIndex", hostinfo.remoteIndexId).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
|
@ -245,7 +247,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
|
||||||
if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId {
|
if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId {
|
||||||
// We have a collision, but this can happen since we can't control
|
// We have a collision, but this can happen since we can't control
|
||||||
// the remote ID. Just log about the situation as a note.
|
// the remote ID. Just log about the situation as a note.
|
||||||
hostinfo.logger().
|
hostinfo.logger(c.l).
|
||||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
|
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
|
||||||
Info("New host shadows existing host remoteIndex")
|
Info("New host shadows existing host remoteIndex")
|
||||||
}
|
}
|
||||||
|
@ -280,7 +282,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
|
||||||
if found && existingRemoteIndex != nil {
|
if found && existingRemoteIndex != nil {
|
||||||
// We have a collision, but this can happen since we can't control
|
// We have a collision, but this can happen since we can't control
|
||||||
// the remote ID. Just log about the situation as a note.
|
// the remote ID. Just log about the situation as a note.
|
||||||
hostinfo.logger().
|
hostinfo.logger(c.l).
|
||||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
|
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
|
||||||
Info("New host shadows existing host remoteIndex")
|
Info("New host shadows existing host remoteIndex")
|
||||||
}
|
}
|
||||||
|
@ -298,7 +300,7 @@ func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
|
||||||
defer c.mainHostMap.RUnlock()
|
defer c.mainHostMap.RUnlock()
|
||||||
|
|
||||||
for i := 0; i < 32; i++ {
|
for i := 0; i < 32; i++ {
|
||||||
index, err := generateIndex()
|
index, err := generateIndex(c.l)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -336,7 +338,7 @@ func (c *HandshakeManager) EmitStats() {
|
||||||
|
|
||||||
// Utility functions below
|
// Utility functions below
|
||||||
|
|
||||||
func generateIndex() (uint32, error) {
|
func generateIndex(l *logrus.Logger) (uint32, error) {
|
||||||
b := make([]byte, 4)
|
b := make([]byte, 4)
|
||||||
|
|
||||||
// Let zero mean we don't know the ID, so don't generate zero
|
// Let zero mean we don't know the ID, so don't generate zero
|
||||||
|
|
|
@ -12,15 +12,15 @@ import (
|
||||||
var ips []uint32
|
var ips []uint32
|
||||||
|
|
||||||
func Test_NewHandshakeManagerIndex(t *testing.T) {
|
func Test_NewHandshakeManagerIndex(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
|
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
|
||||||
preferredRanges := []*net.IPNet{localrange}
|
preferredRanges := []*net.IPNet{localrange}
|
||||||
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
|
||||||
|
|
||||||
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
blah.NextInboundHandshakeTimerTick(now)
|
blah.NextInboundHandshakeTimerTick(now)
|
||||||
|
@ -63,15 +63,16 @@ func Test_NewHandshakeManagerIndex(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_NewHandshakeManagerVpnIP(t *testing.T) {
|
func Test_NewHandshakeManagerVpnIP(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
|
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
|
||||||
preferredRanges := []*net.IPNet{localrange}
|
preferredRanges := []*net.IPNet{localrange}
|
||||||
mw := &mockEncWriter{}
|
mw := &mockEncWriter{}
|
||||||
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
|
||||||
|
|
||||||
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
blah.NextOutboundHandshakeTimerTick(now, mw)
|
blah.NextOutboundHandshakeTimerTick(now, mw)
|
||||||
|
@ -112,16 +113,17 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_NewHandshakeManagerTrigger(t *testing.T) {
|
func Test_NewHandshakeManagerTrigger(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
ip := ip2int(net.ParseIP("172.1.1.2"))
|
ip := ip2int(net.ParseIP("172.1.1.2"))
|
||||||
preferredRanges := []*net.IPNet{localrange}
|
preferredRanges := []*net.IPNet{localrange}
|
||||||
mw := &mockEncWriter{}
|
mw := &mockEncWriter{}
|
||||||
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
|
||||||
lh := &LightHouse{}
|
lh := &LightHouse{}
|
||||||
|
|
||||||
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
|
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
blah.NextOutboundHandshakeTimerTick(now, mw)
|
blah.NextOutboundHandshakeTimerTick(now, mw)
|
||||||
|
@ -162,15 +164,16 @@ func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
|
func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
vpnIP = ip2int(net.ParseIP("172.1.1.2"))
|
vpnIP = ip2int(net.ParseIP("172.1.1.2"))
|
||||||
preferredRanges := []*net.IPNet{localrange}
|
preferredRanges := []*net.IPNet{localrange}
|
||||||
mw := &mockEncWriter{}
|
mw := &mockEncWriter{}
|
||||||
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
|
||||||
|
|
||||||
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
blah.NextOutboundHandshakeTimerTick(now, mw)
|
blah.NextOutboundHandshakeTimerTick(now, mw)
|
||||||
|
@ -216,13 +219,14 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
|
func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
preferredRanges := []*net.IPNet{localrange}
|
preferredRanges := []*net.IPNet{localrange}
|
||||||
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
|
||||||
|
|
||||||
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
blah.NextInboundHandshakeTimerTick(now)
|
blah.NextInboundHandshakeTimerTick(now)
|
||||||
|
|
52
hostmap.go
52
hostmap.go
|
@ -33,6 +33,7 @@ type HostMap struct {
|
||||||
defaultRoute uint32
|
defaultRoute uint32
|
||||||
unsafeRoutes *CIDRTree
|
unsafeRoutes *CIDRTree
|
||||||
metricsEnabled bool
|
metricsEnabled bool
|
||||||
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type HostInfo struct {
|
type HostInfo struct {
|
||||||
|
@ -83,7 +84,7 @@ type Probe struct {
|
||||||
Counter int
|
Counter int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
|
func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
|
||||||
h := map[uint32]*HostInfo{}
|
h := map[uint32]*HostInfo{}
|
||||||
i := map[uint32]*HostInfo{}
|
i := map[uint32]*HostInfo{}
|
||||||
r := map[uint32]*HostInfo{}
|
r := map[uint32]*HostInfo{}
|
||||||
|
@ -96,6 +97,7 @@ func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *
|
||||||
vpnCIDR: vpnCIDR,
|
vpnCIDR: vpnCIDR,
|
||||||
defaultRoute: 0,
|
defaultRoute: 0,
|
||||||
unsafeRoutes: NewCIDRTree(),
|
unsafeRoutes: NewCIDRTree(),
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
return &m
|
return &m
|
||||||
}
|
}
|
||||||
|
@ -160,8 +162,8 @@ func (hm *HostMap) DeleteVpnIP(vpnIP uint32) {
|
||||||
}
|
}
|
||||||
hm.Unlock()
|
hm.Unlock()
|
||||||
|
|
||||||
if l.Level >= logrus.DebugLevel {
|
if hm.l.Level >= logrus.DebugLevel {
|
||||||
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}).
|
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}).
|
||||||
Debug("Hostmap vpnIp deleted")
|
Debug("Hostmap vpnIp deleted")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -173,8 +175,8 @@ func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
|
||||||
hm.RemoteIndexes[index] = h
|
hm.RemoteIndexes[index] = h
|
||||||
hm.Unlock()
|
hm.Unlock()
|
||||||
|
|
||||||
if l.Level > logrus.DebugLevel {
|
if hm.l.Level > logrus.DebugLevel {
|
||||||
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
|
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
|
||||||
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
|
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
|
||||||
Debug("Hostmap remoteIndex added")
|
Debug("Hostmap remoteIndex added")
|
||||||
}
|
}
|
||||||
|
@ -188,8 +190,8 @@ func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) {
|
||||||
hm.RemoteIndexes[h.remoteIndexId] = h
|
hm.RemoteIndexes[h.remoteIndexId] = h
|
||||||
hm.Unlock()
|
hm.Unlock()
|
||||||
|
|
||||||
if l.Level > logrus.DebugLevel {
|
if hm.l.Level > logrus.DebugLevel {
|
||||||
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts),
|
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts),
|
||||||
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
|
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
|
||||||
Debug("Hostmap vpnIp added")
|
Debug("Hostmap vpnIp added")
|
||||||
}
|
}
|
||||||
|
@ -212,8 +214,8 @@ func (hm *HostMap) DeleteIndex(index uint32) {
|
||||||
}
|
}
|
||||||
hm.Unlock()
|
hm.Unlock()
|
||||||
|
|
||||||
if l.Level >= logrus.DebugLevel {
|
if hm.l.Level >= logrus.DebugLevel {
|
||||||
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
|
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
|
||||||
Debug("Hostmap index deleted")
|
Debug("Hostmap index deleted")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -236,8 +238,8 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) {
|
||||||
}
|
}
|
||||||
hm.Unlock()
|
hm.Unlock()
|
||||||
|
|
||||||
if l.Level >= logrus.DebugLevel {
|
if hm.l.Level >= logrus.DebugLevel {
|
||||||
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
|
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
|
||||||
Debug("Hostmap remote index deleted")
|
Debug("Hostmap remote index deleted")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -269,8 +271,8 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
|
||||||
}
|
}
|
||||||
hm.Unlock()
|
hm.Unlock()
|
||||||
|
|
||||||
if l.Level >= logrus.DebugLevel {
|
if hm.l.Level >= logrus.DebugLevel {
|
||||||
l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
|
hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
|
||||||
"vpnIp": IntIp(hostinfo.hostId), "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
"vpnIp": IntIp(hostinfo.hostId), "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
||||||
Debug("Hostmap hostInfo deleted")
|
Debug("Hostmap hostInfo deleted")
|
||||||
}
|
}
|
||||||
|
@ -313,9 +315,11 @@ func (hm *HostMap) AddRemote(vpnIp uint32, remote *udpAddr) *HostInfo {
|
||||||
}
|
}
|
||||||
i.remote = i.Remotes[0].addr
|
i.remote = i.Remotes[0].addr
|
||||||
hm.Hosts[vpnIp] = i
|
hm.Hosts[vpnIp] = i
|
||||||
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}).
|
if hm.l.Level >= logrus.DebugLevel {
|
||||||
|
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}).
|
||||||
Debug("Hostmap remote ip added")
|
Debug("Hostmap remote ip added")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
i.ForcePromoteBest(hm.preferredRanges)
|
i.ForcePromoteBest(hm.preferredRanges)
|
||||||
hm.Unlock()
|
hm.Unlock()
|
||||||
return i
|
return i
|
||||||
|
@ -377,8 +381,8 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
|
||||||
hm.Indexes[hostinfo.localIndexId] = hostinfo
|
hm.Indexes[hostinfo.localIndexId] = hostinfo
|
||||||
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
|
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
|
||||||
|
|
||||||
if l.Level >= logrus.DebugLevel {
|
if hm.l.Level >= logrus.DebugLevel {
|
||||||
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts),
|
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts),
|
||||||
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}).
|
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}).
|
||||||
Debug("Hostmap vpnIp added")
|
Debug("Hostmap vpnIp added")
|
||||||
}
|
}
|
||||||
|
@ -436,7 +440,7 @@ func (hm *HostMap) Punchy(conn *udpConn) {
|
||||||
|
|
||||||
func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
|
func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
|
||||||
for _, r := range *routes {
|
for _, r := range *routes {
|
||||||
l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
|
hm.l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
|
||||||
hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via))
|
hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -566,7 +570,7 @@ func (i *HostInfo) rotateRemote() {
|
||||||
i.remote = i.Remotes[0].addr
|
i.remote = i.Remotes[0].addr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) {
|
func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) {
|
||||||
//TODO: return the error so we can log with more context
|
//TODO: return the error so we can log with more context
|
||||||
if len(i.packetStore) < 100 {
|
if len(i.packetStore) < 100 {
|
||||||
tempPacket := make([]byte, len(packet))
|
tempPacket := make([]byte, len(packet))
|
||||||
|
@ -574,14 +578,14 @@ func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, pac
|
||||||
//l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket)
|
//l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket)
|
||||||
i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket})
|
i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket})
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Level >= logrus.DebugLevel {
|
||||||
i.logger().
|
i.logger(l).
|
||||||
WithField("length", len(i.packetStore)).
|
WithField("length", len(i.packetStore)).
|
||||||
WithField("stored", true).
|
WithField("stored", true).
|
||||||
Debugf("Packet store")
|
Debugf("Packet store")
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if l.Level >= logrus.DebugLevel {
|
} else if l.Level >= logrus.DebugLevel {
|
||||||
i.logger().
|
i.logger(l).
|
||||||
WithField("length", len(i.packetStore)).
|
WithField("length", len(i.packetStore)).
|
||||||
WithField("stored", false).
|
WithField("stored", false).
|
||||||
Debugf("Packet store")
|
Debugf("Packet store")
|
||||||
|
@ -589,7 +593,7 @@ func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, pac
|
||||||
}
|
}
|
||||||
|
|
||||||
// handshakeComplete will set the connection as ready to communicate, as well as flush any stored packets
|
// handshakeComplete will set the connection as ready to communicate, as well as flush any stored packets
|
||||||
func (i *HostInfo) handshakeComplete() {
|
func (i *HostInfo) handshakeComplete(l *logrus.Logger) {
|
||||||
//TODO: I'm not certain the distinction between handshake complete and ConnectionState being ready matters because:
|
//TODO: I'm not certain the distinction between handshake complete and ConnectionState being ready matters because:
|
||||||
//TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send
|
//TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send
|
||||||
//TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical
|
//TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical
|
||||||
|
@ -601,7 +605,7 @@ func (i *HostInfo) handshakeComplete() {
|
||||||
atomic.StoreUint64(&i.ConnectionState.atomicMessageCounter, 2)
|
atomic.StoreUint64(&i.ConnectionState.atomicMessageCounter, 2)
|
||||||
|
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Level >= logrus.DebugLevel {
|
||||||
i.logger().Debugf("Sending %d stored packets", len(i.packetStore))
|
i.logger(l).Debugf("Sending %d stored packets", len(i.packetStore))
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(i.packetStore) > 0 {
|
if len(i.packetStore) > 0 {
|
||||||
|
@ -689,7 +693,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
|
||||||
i.remoteCidr = remoteCidr
|
i.remoteCidr = remoteCidr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *HostInfo) logger() *logrus.Entry {
|
func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
||||||
if i == nil {
|
if i == nil {
|
||||||
return logrus.NewEntry(l)
|
return logrus.NewEntry(l)
|
||||||
}
|
}
|
||||||
|
@ -804,7 +808,7 @@ func (d *HostInfoDest) ProbeReceived(probeCount int) {
|
||||||
|
|
||||||
// Utility functions
|
// Utility functions
|
||||||
|
|
||||||
func localIps(allowList *AllowList) *[]net.IP {
|
func localIps(l *logrus.Logger, allowList *AllowList) *[]net.IP {
|
||||||
//FIXME: This function is pretty garbage
|
//FIXME: This function is pretty garbage
|
||||||
var ips []net.IP
|
var ips []net.IP
|
||||||
ifaces, _ := net.Interfaces()
|
ifaces, _ := net.Interfaces()
|
||||||
|
|
|
@ -64,12 +64,13 @@ func TestHostInfoDestProbe(t *testing.T) {
|
||||||
*/
|
*/
|
||||||
|
|
||||||
func TestHostmap(t *testing.T) {
|
func TestHostmap(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
|
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
|
||||||
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
|
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
|
||||||
myNets := []*net.IPNet{myNet}
|
myNets := []*net.IPNet{myNet}
|
||||||
preferredRanges := []*net.IPNet{localToMe}
|
preferredRanges := []*net.IPNet{localToMe}
|
||||||
|
|
||||||
m := NewHostMap("test", myNet, preferredRanges)
|
m := NewHostMap(l, "test", myNet, preferredRanges)
|
||||||
|
|
||||||
a := NewUDPAddrFromString("10.127.0.3:11111")
|
a := NewUDPAddrFromString("10.127.0.3:11111")
|
||||||
b := NewUDPAddrFromString("1.0.0.1:22222")
|
b := NewUDPAddrFromString("1.0.0.1:22222")
|
||||||
|
@ -103,10 +104,11 @@ func TestHostmap(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHostmapdebug(t *testing.T) {
|
func TestHostmapdebug(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
|
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
|
||||||
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
|
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
|
||||||
preferredRanges := []*net.IPNet{localToMe}
|
preferredRanges := []*net.IPNet{localToMe}
|
||||||
m := NewHostMap("test", myNet, preferredRanges)
|
m := NewHostMap(l, "test", myNet, preferredRanges)
|
||||||
|
|
||||||
a := NewUDPAddrFromString("10.127.0.3:11111")
|
a := NewUDPAddrFromString("10.127.0.3:11111")
|
||||||
b := NewUDPAddrFromString("1.0.0.1:22222")
|
b := NewUDPAddrFromString("1.0.0.1:22222")
|
||||||
|
@ -151,11 +153,12 @@ func TestHostMap_rotateRemote(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkHostmappromote2(b *testing.B) {
|
func BenchmarkHostmappromote2(b *testing.B) {
|
||||||
|
l := NewTestLogger()
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
|
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
|
||||||
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
|
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
|
||||||
preferredRanges := []*net.IPNet{localToMe}
|
preferredRanges := []*net.IPNet{localToMe}
|
||||||
m := NewHostMap("test", myNet, preferredRanges)
|
m := NewHostMap(l, "test", myNet, preferredRanges)
|
||||||
y := NewUDPAddrFromString("10.128.0.3:11111")
|
y := NewUDPAddrFromString("10.128.0.3:11111")
|
||||||
a := NewUDPAddrFromString("10.127.0.3:11111")
|
a := NewUDPAddrFromString("10.127.0.3:11111")
|
||||||
g := NewUDPAddrFromString("1.0.0.1:22222")
|
g := NewUDPAddrFromString("1.0.0.1:22222")
|
||||||
|
|
40
inside.go
40
inside.go
|
@ -10,7 +10,7 @@ import (
|
||||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) {
|
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) {
|
||||||
err := newPacket(packet, false, fwPacket)
|
err := newPacket(packet, false, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -31,8 +31,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
|
||||||
|
|
||||||
hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
|
hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
if l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
|
f.l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
|
||||||
WithField("fwPacket", fwPacket).
|
WithField("fwPacket", fwPacket).
|
||||||
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
|
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
|
||||||
}
|
}
|
||||||
|
@ -45,7 +45,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
|
||||||
// the packet queue.
|
// the packet queue.
|
||||||
ci.queueLock.Lock()
|
ci.queueLock.Lock()
|
||||||
if !ci.ready {
|
if !ci.ready {
|
||||||
hostinfo.cachePacket(message, 0, packet, f.sendMessageNow)
|
hostinfo.cachePacket(f.l, message, 0, packet, f.sendMessageNow)
|
||||||
ci.queueLock.Unlock()
|
ci.queueLock.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -59,8 +59,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
|
||||||
f.lightHouse.Query(fwPacket.RemoteIP, f)
|
f.lightHouse.Query(fwPacket.RemoteIP, f)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if l.Level >= logrus.DebugLevel {
|
} else if f.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger().
|
hostinfo.logger(f.l).
|
||||||
WithField("fwPacket", fwPacket).
|
WithField("fwPacket", fwPacket).
|
||||||
WithField("reason", dropReason).
|
WithField("reason", dropReason).
|
||||||
Debugln("dropping outbound packet")
|
Debugln("dropping outbound packet")
|
||||||
|
@ -104,7 +104,7 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
|
||||||
|
|
||||||
if ci == nil {
|
if ci == nil {
|
||||||
// if we don't have a connection state, then send a handshake initiation
|
// if we don't have a connection state, then send a handshake initiation
|
||||||
ci = f.newConnectionState(true, noise.HandshakeIX, []byte{}, 0)
|
ci = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
|
||||||
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
|
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
|
||||||
//ci = f.newConnectionState(true, noise.HandshakeXX, []byte{}, 0)
|
//ci = f.newConnectionState(true, noise.HandshakeXX, []byte{}, 0)
|
||||||
hostinfo.ConnectionState = ci
|
hostinfo.ConnectionState = ci
|
||||||
|
@ -135,15 +135,15 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
|
||||||
fp := &FirewallPacket{}
|
fp := &FirewallPacket{}
|
||||||
err := newPacket(p, false, fp)
|
err := newPacket(p, false, fp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
|
f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if packet is in outbound fw rules
|
// check if packet is in outbound fw rules
|
||||||
dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs, nil)
|
dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs, nil)
|
||||||
if dropReason != nil {
|
if dropReason != nil {
|
||||||
if l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
l.WithField("fwPacket", fp).
|
f.l.WithField("fwPacket", fp).
|
||||||
WithField("reason", dropReason).
|
WithField("reason", dropReason).
|
||||||
Debugln("dropping cached packet")
|
Debugln("dropping cached packet")
|
||||||
}
|
}
|
||||||
|
@ -160,8 +160,8 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
|
||||||
func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
|
func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
|
||||||
hostInfo := f.getOrHandshake(vpnIp)
|
hostInfo := f.getOrHandshake(vpnIp)
|
||||||
if hostInfo == nil {
|
if hostInfo == nil {
|
||||||
if l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
l.WithField("vpnIp", IntIp(vpnIp)).
|
f.l.WithField("vpnIp", IntIp(vpnIp)).
|
||||||
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
|
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -172,7 +172,7 @@ func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
|
||||||
// the packet queue.
|
// the packet queue.
|
||||||
hostInfo.ConnectionState.queueLock.Lock()
|
hostInfo.ConnectionState.queueLock.Lock()
|
||||||
if !hostInfo.ConnectionState.ready {
|
if !hostInfo.ConnectionState.ready {
|
||||||
hostInfo.cachePacket(t, st, p, f.sendMessageToVpnIp)
|
hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToVpnIp)
|
||||||
hostInfo.ConnectionState.queueLock.Unlock()
|
hostInfo.ConnectionState.queueLock.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -191,8 +191,8 @@ func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
|
||||||
func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
|
func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
|
||||||
hostInfo := f.getOrHandshake(vpnIp)
|
hostInfo := f.getOrHandshake(vpnIp)
|
||||||
if hostInfo == nil {
|
if hostInfo == nil {
|
||||||
if l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
l.WithField("vpnIp", IntIp(vpnIp)).
|
f.l.WithField("vpnIp", IntIp(vpnIp)).
|
||||||
Debugln("dropping SendMessageToAll, vpnIp not in our CIDR or in unsafe routes")
|
Debugln("dropping SendMessageToAll, vpnIp not in our CIDR or in unsafe routes")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -203,7 +203,7 @@ func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubTyp
|
||||||
// the packet queue.
|
// the packet queue.
|
||||||
hostInfo.ConnectionState.queueLock.Lock()
|
hostInfo.ConnectionState.queueLock.Lock()
|
||||||
if !hostInfo.ConnectionState.ready {
|
if !hostInfo.ConnectionState.ready {
|
||||||
hostInfo.cachePacket(t, st, p, f.sendMessageToAll)
|
hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToAll)
|
||||||
hostInfo.ConnectionState.queueLock.Unlock()
|
hostInfo.ConnectionState.queueLock.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -247,8 +247,8 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
|
||||||
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
||||||
f.lightHouse.Query(hostinfo.hostId, f)
|
f.lightHouse.Query(hostinfo.hostId, f)
|
||||||
hostinfo.lastRebindCount = f.rebindCount
|
hostinfo.lastRebindCount = f.rebindCount
|
||||||
if l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter")
|
f.l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -256,7 +256,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
|
||||||
//TODO: see above note on lock
|
//TODO: see above note on lock
|
||||||
//ci.writeLock.Unlock()
|
//ci.writeLock.Unlock()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger().WithError(err).
|
hostinfo.logger(f.l).WithError(err).
|
||||||
WithField("udpAddr", remote).WithField("counter", c).
|
WithField("udpAddr", remote).WithField("counter", c).
|
||||||
WithField("attemptedCounter", c).
|
WithField("attemptedCounter", c).
|
||||||
Error("Failed to encrypt outgoing packet")
|
Error("Failed to encrypt outgoing packet")
|
||||||
|
@ -265,7 +265,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
|
||||||
|
|
||||||
err = f.writers[q].WriteTo(out, remote)
|
err = f.writers[q].WriteTo(out, remote)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger().WithError(err).
|
hostinfo.logger(f.l).WithError(err).
|
||||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
||||||
}
|
}
|
||||||
return c
|
return c
|
||||||
|
|
40
interface.go
40
interface.go
|
@ -9,6 +9,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const mtu = 9001
|
const mtu = 9001
|
||||||
|
@ -42,6 +43,7 @@ type InterfaceConfig struct {
|
||||||
version string
|
version string
|
||||||
|
|
||||||
ConntrackCacheTimeout time.Duration
|
ConntrackCacheTimeout time.Duration
|
||||||
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type Interface struct {
|
type Interface struct {
|
||||||
|
@ -73,6 +75,7 @@ type Interface struct {
|
||||||
|
|
||||||
metricHandshakes metrics.Histogram
|
metricHandshakes metrics.Histogram
|
||||||
messageMetrics *MessageMetrics
|
messageMetrics *MessageMetrics
|
||||||
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
||||||
|
@ -113,9 +116,10 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
||||||
|
|
||||||
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||||
messageMetrics: c.MessageMetrics,
|
messageMetrics: c.MessageMetrics,
|
||||||
|
l: c.l,
|
||||||
}
|
}
|
||||||
|
|
||||||
ifce.connectionManager = newConnectionManager(ifce, c.checkInterval, c.pendingDeletionInterval)
|
ifce.connectionManager = newConnectionManager(c.l, ifce, c.checkInterval, c.pendingDeletionInterval)
|
||||||
|
|
||||||
return ifce, nil
|
return ifce, nil
|
||||||
}
|
}
|
||||||
|
@ -125,10 +129,10 @@ func (f *Interface) run() {
|
||||||
|
|
||||||
addr, err := f.outside.LocalAddr()
|
addr, err := f.outside.LocalAddr()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("Failed to get udp listen address")
|
f.l.WithError(err).Error("Failed to get udp listen address")
|
||||||
}
|
}
|
||||||
|
|
||||||
l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
|
f.l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
|
||||||
WithField("build", f.version).WithField("udpAddr", addr).
|
WithField("build", f.version).WithField("udpAddr", addr).
|
||||||
Info("Nebula interface is active")
|
Info("Nebula interface is active")
|
||||||
|
|
||||||
|
@ -140,14 +144,14 @@ func (f *Interface) run() {
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
reader, err = f.inside.NewMultiQueueReader()
|
reader, err = f.inside.NewMultiQueueReader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Fatal(err)
|
f.l.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
f.readers[i] = reader
|
f.readers[i] = reader
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := f.inside.Activate(); err != nil {
|
if err := f.inside.Activate(); err != nil {
|
||||||
l.Fatal(err)
|
f.l.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launch n queues to read packets from udp
|
// Launch n queues to read packets from udp
|
||||||
|
@ -187,12 +191,12 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||||
for {
|
for {
|
||||||
n, err := reader.Read(packet)
|
n, err := reader.Read(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("Error while reading outbound packet")
|
f.l.WithError(err).Error("Error while reading outbound packet")
|
||||||
// This only seems to happen when something fatal happens to the fd, so exit.
|
// This only seems to happen when something fatal happens to the fd, so exit.
|
||||||
os.Exit(2)
|
os.Exit(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get())
|
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -208,21 +212,21 @@ func (f *Interface) RegisterConfigChangeCallbacks(c *Config) {
|
||||||
func (f *Interface) reloadCA(c *Config) {
|
func (f *Interface) reloadCA(c *Config) {
|
||||||
// reload and check regardless
|
// reload and check regardless
|
||||||
// todo: need mutex?
|
// todo: need mutex?
|
||||||
newCAs, err := loadCAFromConfig(c)
|
newCAs, err := loadCAFromConfig(f.l, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("Could not refresh trusted CA certificates")
|
f.l.WithError(err).Error("Could not refresh trusted CA certificates")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
trustedCAs = newCAs
|
trustedCAs = newCAs
|
||||||
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed")
|
f.l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) reloadCertKey(c *Config) {
|
func (f *Interface) reloadCertKey(c *Config) {
|
||||||
// reload and check in all cases
|
// reload and check in all cases
|
||||||
cs, err := NewCertStateFromConfig(c)
|
cs, err := NewCertStateFromConfig(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("Could not refresh client cert")
|
f.l.WithError(err).Error("Could not refresh client cert")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -230,24 +234,24 @@ func (f *Interface) reloadCertKey(c *Config) {
|
||||||
oldIPs := f.certState.certificate.Details.Ips
|
oldIPs := f.certState.certificate.Details.Ips
|
||||||
newIPs := cs.certificate.Details.Ips
|
newIPs := cs.certificate.Details.Ips
|
||||||
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
|
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
|
||||||
l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
|
f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.certState = cs
|
f.certState = cs
|
||||||
l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
|
f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) reloadFirewall(c *Config) {
|
func (f *Interface) reloadFirewall(c *Config) {
|
||||||
//TODO: need to trigger/detect if the certificate changed too
|
//TODO: need to trigger/detect if the certificate changed too
|
||||||
if c.HasChanged("firewall") == false {
|
if c.HasChanged("firewall") == false {
|
||||||
l.Debug("No firewall config change detected")
|
f.l.Debug("No firewall config change detected")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fw, err := NewFirewallFromConfig(f.certState.certificate, c)
|
fw, err := NewFirewallFromConfig(f.l, f.certState.certificate, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("Error while creating firewall during reload")
|
f.l.WithError(err).Error("Error while creating firewall during reload")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -260,7 +264,7 @@ func (f *Interface) reloadFirewall(c *Config) {
|
||||||
// If rulesVersion is back to zero, we have wrapped all the way around. Be
|
// If rulesVersion is back to zero, we have wrapped all the way around. Be
|
||||||
// safe and just reset conntrack in this case.
|
// safe and just reset conntrack in this case.
|
||||||
if fw.rulesVersion == 0 {
|
if fw.rulesVersion == 0 {
|
||||||
l.WithField("firewallHash", fw.GetRuleHash()).
|
f.l.WithField("firewallHash", fw.GetRuleHash()).
|
||||||
WithField("oldFirewallHash", oldFw.GetRuleHash()).
|
WithField("oldFirewallHash", oldFw.GetRuleHash()).
|
||||||
WithField("rulesVersion", fw.rulesVersion).
|
WithField("rulesVersion", fw.rulesVersion).
|
||||||
Warn("firewall rulesVersion has overflowed, resetting conntrack")
|
Warn("firewall rulesVersion has overflowed, resetting conntrack")
|
||||||
|
@ -271,7 +275,7 @@ func (f *Interface) reloadFirewall(c *Config) {
|
||||||
f.firewall = fw
|
f.firewall = fw
|
||||||
|
|
||||||
oldFw.Destroy()
|
oldFw.Destroy()
|
||||||
l.WithField("firewallHash", fw.GetRuleHash()).
|
f.l.WithField("firewallHash", fw.GetRuleHash()).
|
||||||
WithField("oldFirewallHash", oldFw.GetRuleHash()).
|
WithField("oldFirewallHash", oldFw.GetRuleHash()).
|
||||||
WithField("rulesVersion", fw.rulesVersion).
|
WithField("rulesVersion", fw.rulesVersion).
|
||||||
Info("New firewall has been installed")
|
Info("New firewall has been installed")
|
||||||
|
|
|
@ -48,6 +48,7 @@ type LightHouse struct {
|
||||||
|
|
||||||
metrics *MessageMetrics
|
metrics *MessageMetrics
|
||||||
metricHolepunchTx metrics.Counter
|
metricHolepunchTx metrics.Counter
|
||||||
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type EncWriter interface {
|
type EncWriter interface {
|
||||||
|
@ -55,7 +56,7 @@ type EncWriter interface {
|
||||||
SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
|
SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
|
func NewLightHouse(l *logrus.Logger, amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
|
||||||
h := LightHouse{
|
h := LightHouse{
|
||||||
amLighthouse: amLighthouse,
|
amLighthouse: amLighthouse,
|
||||||
myIp: myIp,
|
myIp: myIp,
|
||||||
|
@ -67,6 +68,7 @@ func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, n
|
||||||
punchConn: pc,
|
punchConn: pc,
|
||||||
punchBack: punchBack,
|
punchBack: punchBack,
|
||||||
punchDelay: punchDelay,
|
punchDelay: punchDelay,
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
if metricsEnabled {
|
if metricsEnabled {
|
||||||
|
@ -126,7 +128,7 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
|
||||||
// Send a query to the lighthouses and hope for the best next time
|
// Send a query to the lighthouses and hope for the best next time
|
||||||
query, err := proto.Marshal(NewLhQueryByInt(ip))
|
query, err := proto.Marshal(NewLhQueryByInt(ip))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
|
lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -159,7 +161,7 @@ func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
|
||||||
lh.Lock()
|
lh.Lock()
|
||||||
//l.Debugln(lh.addrMap)
|
//l.Debugln(lh.addrMap)
|
||||||
delete(lh.addrMap, vpnIP)
|
delete(lh.addrMap, vpnIP)
|
||||||
l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP))
|
lh.l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP))
|
||||||
lh.Unlock()
|
lh.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -181,7 +183,7 @@ func (lh *LightHouse) AddRemote(vpnIP uint32, toIp *udpAddr, static bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
allow := lh.remoteAllowList.Allow(toIp.IP)
|
allow := lh.remoteAllowList.Allow(toIp.IP)
|
||||||
l.WithField("remoteIp", toIp).WithField("allow", allow).Debug("remoteAllowList.Allow")
|
lh.l.WithField("remoteIp", toIp).WithField("allow", allow).Debug("remoteAllowList.Allow")
|
||||||
if !allow {
|
if !allow {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -270,7 +272,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
|
||||||
var v4 []*IpAndPort
|
var v4 []*IpAndPort
|
||||||
var v6 []*Ip6AndPort
|
var v6 []*Ip6AndPort
|
||||||
|
|
||||||
for _, e := range *localIps(lh.localAllowList) {
|
for _, e := range *localIps(lh.l, lh.localAllowList) {
|
||||||
// Only add IPs that aren't my VPN/tun IP
|
// Only add IPs that aren't my VPN/tun IP
|
||||||
if ip2int(e) != lh.myIp {
|
if ip2int(e) != lh.myIp {
|
||||||
ipp := NewIpAndPort(e, lh.nebulaPort)
|
ipp := NewIpAndPort(e, lh.nebulaPort)
|
||||||
|
@ -297,7 +299,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
|
||||||
for vpnIp := range lh.lighthouses {
|
for vpnIp := range lh.lighthouses {
|
||||||
mm, err := proto.Marshal(m)
|
mm, err := proto.Marshal(m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Debugf("Invalid marshal to update")
|
lh.l.Debugf("Invalid marshal to update")
|
||||||
}
|
}
|
||||||
//l.Error("LIGHTHOUSE PACKET SEND", mm)
|
//l.Error("LIGHTHOUSE PACKET SEND", mm)
|
||||||
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out)
|
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out)
|
||||||
|
@ -368,14 +370,14 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
|
||||||
n := lhh.resetMeta()
|
n := lhh.resetMeta()
|
||||||
err := proto.UnmarshalMerge(p, n)
|
err := proto.UnmarshalMerge(p, n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
|
lh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
|
||||||
Error("Failed to unmarshal lighthouse packet")
|
Error("Failed to unmarshal lighthouse packet")
|
||||||
//TODO: send recv_error?
|
//TODO: send recv_error?
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if n.Details == nil {
|
if n.Details == nil {
|
||||||
l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
|
lh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
|
||||||
Error("Invalid lighthouse update")
|
Error("Invalid lighthouse update")
|
||||||
//TODO: send recv_error?
|
//TODO: send recv_error?
|
||||||
return
|
return
|
||||||
|
@ -387,7 +389,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
|
||||||
case NebulaMeta_HostQuery:
|
case NebulaMeta_HostQuery:
|
||||||
// Exit if we don't answer queries
|
// Exit if we don't answer queries
|
||||||
if !lh.amLighthouse {
|
if !lh.amLighthouse {
|
||||||
l.Debugln("I don't answer queries, but received from: ", rAddr)
|
lh.l.Debugln("I don't answer queries, but received from: ", rAddr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -422,7 +424,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
|
||||||
|
|
||||||
reply, err := proto.Marshal(n)
|
reply, err := proto.Marshal(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
|
lh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
lh.metricTx(NebulaMeta_HostQueryReply, 1)
|
lh.metricTx(NebulaMeta_HostQueryReply, 1)
|
||||||
|
@ -431,7 +433,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
|
||||||
// This signals the other side to punch some zero byte udp packets
|
// This signals the other side to punch some zero byte udp packets
|
||||||
ips, err = lh.Query(vpnIp, f)
|
ips, err = lh.Query(vpnIp, f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithField("vpnIp", IntIp(vpnIp)).Debugln("Can't notify host to punch")
|
lh.l.WithField("vpnIp", IntIp(vpnIp)).Debugln("Can't notify host to punch")
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
//l.Debugln("Notify host to punch", iap)
|
//l.Debugln("Notify host to punch", iap)
|
||||||
|
@ -492,7 +494,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
|
||||||
case NebulaMeta_HostUpdateNotification:
|
case NebulaMeta_HostUpdateNotification:
|
||||||
//Simple check that the host sent this not someone else
|
//Simple check that the host sent this not someone else
|
||||||
if n.Details.VpnIp != vpnIp {
|
if n.Details.VpnIp != vpnIp {
|
||||||
l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
|
lh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -530,9 +532,9 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
|
||||||
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if l.Level >= logrus.DebugLevel {
|
if lh.l.Level >= logrus.DebugLevel {
|
||||||
//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
|
//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
|
||||||
l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp))
|
lh.l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -549,9 +551,9 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
|
||||||
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if l.Level >= logrus.DebugLevel {
|
if lh.l.Level >= logrus.DebugLevel {
|
||||||
//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
|
//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
|
||||||
l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp))
|
lh.l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -561,7 +563,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
|
||||||
if lh.punchBack {
|
if lh.punchBack {
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(time.Second * 5)
|
time.Sleep(time.Second * 5)
|
||||||
l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp))
|
lh.l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp))
|
||||||
// TODO we have to allocate a new output buffer here since we are spawning a new goroutine
|
// TODO we have to allocate a new output buffer here since we are spawning a new goroutine
|
||||||
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
||||||
// managed by a channel.
|
// managed by a channel.
|
||||||
|
|
|
@ -65,12 +65,13 @@ func TestSetipandportsfromudpaddrs(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_lhStaticMapping(t *testing.T) {
|
func Test_lhStaticMapping(t *testing.T) {
|
||||||
|
l := NewTestLogger()
|
||||||
lh1 := "10.128.0.2"
|
lh1 := "10.128.0.2"
|
||||||
lh1IP := net.ParseIP(lh1)
|
lh1IP := net.ParseIP(lh1)
|
||||||
|
|
||||||
udpServer, _ := NewListener("0.0.0.0", 0, true)
|
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
|
||||||
|
|
||||||
meh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
meh := NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
||||||
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true)
|
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true)
|
||||||
err := meh.ValidateLHStaticEntries()
|
err := meh.ValidateLHStaticEntries()
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
@ -78,19 +79,20 @@ func Test_lhStaticMapping(t *testing.T) {
|
||||||
lh2 := "10.128.0.3"
|
lh2 := "10.128.0.3"
|
||||||
lh2IP := net.ParseIP(lh2)
|
lh2IP := net.ParseIP(lh2)
|
||||||
|
|
||||||
meh = NewLightHouse(true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
|
meh = NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
|
||||||
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true)
|
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true)
|
||||||
err = meh.ValidateLHStaticEntries()
|
err = meh.ValidateLHStaticEntries()
|
||||||
assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
|
assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||||
|
l := NewTestLogger()
|
||||||
lh1 := "10.128.0.2"
|
lh1 := "10.128.0.2"
|
||||||
lh1IP := net.ParseIP(lh1)
|
lh1IP := net.ParseIP(lh1)
|
||||||
|
|
||||||
udpServer, _ := NewListener("0.0.0.0", 0, true)
|
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
|
||||||
|
|
||||||
lh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
lh := NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
||||||
|
|
||||||
hAddr := NewUDPAddrFromString("4.5.6.7:12345")
|
hAddr := NewUDPAddrFromString("4.5.6.7:12345")
|
||||||
hAddr2 := NewUDPAddrFromString("4.5.6.7:12346")
|
hAddr2 := NewUDPAddrFromString("4.5.6.7:12346")
|
||||||
|
@ -136,7 +138,8 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_lhRemoteAllowList(t *testing.T) {
|
func Test_lhRemoteAllowList(t *testing.T) {
|
||||||
c := NewConfig()
|
l := NewTestLogger()
|
||||||
|
c := NewConfig(l)
|
||||||
c.Settings["remoteallowlist"] = map[interface{}]interface{}{
|
c.Settings["remoteallowlist"] = map[interface{}]interface{}{
|
||||||
"10.20.0.0/12": false,
|
"10.20.0.0/12": false,
|
||||||
}
|
}
|
||||||
|
@ -146,9 +149,9 @@ func Test_lhRemoteAllowList(t *testing.T) {
|
||||||
lh1 := "10.128.0.2"
|
lh1 := "10.128.0.2"
|
||||||
lh1IP := net.ParseIP(lh1)
|
lh1IP := net.ParseIP(lh1)
|
||||||
|
|
||||||
udpServer, _ := NewListener("0.0.0.0", 0, true)
|
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
|
||||||
|
|
||||||
lh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
lh := NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
||||||
lh.SetRemoteAllowList(allowList)
|
lh.SetRemoteAllowList(allowList)
|
||||||
|
|
||||||
remote1 := "10.20.0.3"
|
remote1 := "10.20.0.3"
|
||||||
|
|
29
main.go
29
main.go
|
@ -11,13 +11,10 @@ import (
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// The caller should provide a real logger, we have one just in case
|
|
||||||
var l = logrus.New()
|
|
||||||
|
|
||||||
type m map[string]interface{}
|
type m map[string]interface{}
|
||||||
|
|
||||||
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) {
|
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) {
|
||||||
l = logger
|
l := logger
|
||||||
l.Formatter = &logrus.TextFormatter{
|
l.Formatter = &logrus.TextFormatter{
|
||||||
FullTimestamp: true,
|
FullTimestamp: true,
|
||||||
}
|
}
|
||||||
|
@ -46,7 +43,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
})
|
})
|
||||||
|
|
||||||
// trustedCAs is currently a global, so loadCA operates on that global directly
|
// trustedCAs is currently a global, so loadCA operates on that global directly
|
||||||
trustedCAs, err = loadCAFromConfig(config)
|
trustedCAs, err = loadCAFromConfig(l, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//The errors coming out of loadCA are already nicely formatted
|
//The errors coming out of loadCA are already nicely formatted
|
||||||
return nil, NewContextualError("Failed to load ca from config", nil, err)
|
return nil, NewContextualError("Failed to load ca from config", nil, err)
|
||||||
|
@ -60,7 +57,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
}
|
}
|
||||||
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
|
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
|
||||||
|
|
||||||
fw, err := NewFirewallFromConfig(cs.certificate, config)
|
fw, err := NewFirewallFromConfig(l, cs.certificate, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Error while loading firewall rules", nil, err)
|
return nil, NewContextualError("Error while loading firewall rules", nil, err)
|
||||||
}
|
}
|
||||||
|
@ -78,9 +75,9 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
}
|
}
|
||||||
|
|
||||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
||||||
wireSSHReload(ssh, config)
|
wireSSHReload(l, ssh, config)
|
||||||
if config.GetBool("sshd.enabled", false) {
|
if config.GetBool("sshd.enabled", false) {
|
||||||
err = configSSH(ssh, config)
|
err = configSSH(l, ssh, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Error while configuring the sshd", nil, err)
|
return nil, NewContextualError("Error while configuring the sshd", nil, err)
|
||||||
}
|
}
|
||||||
|
@ -136,6 +133,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
tun = newDisabledTun(tunCidr, config.GetInt("tun.tx_queue", 500), config.GetBool("stats.message_metrics", false), l)
|
tun = newDisabledTun(tunCidr, config.GetInt("tun.tx_queue", 500), config.GetBool("stats.message_metrics", false), l)
|
||||||
case tunFd != nil:
|
case tunFd != nil:
|
||||||
tun, err = newTunFromFd(
|
tun, err = newTunFromFd(
|
||||||
|
l,
|
||||||
*tunFd,
|
*tunFd,
|
||||||
tunCidr,
|
tunCidr,
|
||||||
config.GetInt("tun.mtu", DEFAULT_MTU),
|
config.GetInt("tun.mtu", DEFAULT_MTU),
|
||||||
|
@ -145,6 +143,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
)
|
)
|
||||||
default:
|
default:
|
||||||
tun, err = newTun(
|
tun, err = newTun(
|
||||||
|
l,
|
||||||
config.GetString("tun.dev", ""),
|
config.GetString("tun.dev", ""),
|
||||||
tunCidr,
|
tunCidr,
|
||||||
config.GetInt("tun.mtu", DEFAULT_MTU),
|
config.GetInt("tun.mtu", DEFAULT_MTU),
|
||||||
|
@ -166,7 +165,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
|
|
||||||
if !configTest {
|
if !configTest {
|
||||||
for i := 0; i < routines; i++ {
|
for i := 0; i < routines; i++ {
|
||||||
udpServer, err := NewListener(config.GetString("listen.host", "0.0.0.0"), port, routines > 1)
|
udpServer, err := NewListener(l, config.GetString("listen.host", "0.0.0.0"), port, routines > 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||||
}
|
}
|
||||||
|
@ -222,7 +221,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hostMap := NewHostMap("main", tunCidr, preferredRanges)
|
hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
|
||||||
hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
|
hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
|
||||||
hostMap.addUnsafeRoutes(&unsafeRoutes)
|
hostMap.addUnsafeRoutes(&unsafeRoutes)
|
||||||
hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
|
hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
|
||||||
|
@ -266,6 +265,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
}
|
}
|
||||||
|
|
||||||
lightHouse := NewLightHouse(
|
lightHouse := NewLightHouse(
|
||||||
|
l,
|
||||||
amLighthouse,
|
amLighthouse,
|
||||||
ip2int(tunCidr.IP),
|
ip2int(tunCidr.IP),
|
||||||
lighthouseHosts,
|
lighthouseHosts,
|
||||||
|
@ -337,7 +337,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
messageMetrics: messageMetrics,
|
messageMetrics: messageMetrics,
|
||||||
}
|
}
|
||||||
|
|
||||||
handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig)
|
handshakeManager := NewHandshakeManager(l, tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig)
|
||||||
lightHouse.handshakeTrigger = handshakeManager.trigger
|
lightHouse.handshakeTrigger = handshakeManager.trigger
|
||||||
|
|
||||||
//TODO: These will be reused for psk
|
//TODO: These will be reused for psk
|
||||||
|
@ -367,6 +367,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
version: buildVersion,
|
version: buildVersion,
|
||||||
|
|
||||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
switch ifConfig.Cipher {
|
switch ifConfig.Cipher {
|
||||||
|
@ -395,7 +396,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
go lightHouse.LhUpdateWorker(ifce)
|
go lightHouse.LhUpdateWorker(ifce)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = startStats(config, configTest)
|
err = startStats(l, config, configTest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Failed to start stats emitter", nil, err)
|
return nil, NewContextualError("Failed to start stats emitter", nil, err)
|
||||||
}
|
}
|
||||||
|
@ -407,12 +408,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
//TODO: check if we _should_ be emitting stats
|
//TODO: check if we _should_ be emitting stats
|
||||||
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
|
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
|
||||||
|
|
||||||
attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
|
attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
|
||||||
|
|
||||||
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
|
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
|
||||||
if amLighthouse && serveDns {
|
if amLighthouse && serveDns {
|
||||||
l.Debugln("Starting dns server")
|
l.Debugln("Starting dns server")
|
||||||
go dnsMain(hostMap, config)
|
go dnsMain(l, hostMap, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Control{ifce, l}, nil
|
return &Control{ifce, l}, nil
|
||||||
|
|
29
main_test.go
29
main_test.go
|
@ -1 +1,30 @@
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewTestLogger() *logrus.Logger {
|
||||||
|
l := logrus.New()
|
||||||
|
|
||||||
|
v := os.Getenv("TEST_LOGS")
|
||||||
|
if v == "" {
|
||||||
|
l.SetOutput(ioutil.Discard)
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v {
|
||||||
|
case "1":
|
||||||
|
// This is the default level but we are being explicit
|
||||||
|
l.SetLevel(logrus.InfoLevel)
|
||||||
|
case "2":
|
||||||
|
l.SetLevel(logrus.DebugLevel)
|
||||||
|
case "3":
|
||||||
|
l.SetLevel(logrus.TraceLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
50
outside.go
50
outside.go
|
@ -24,7 +24,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
|
||||||
// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
|
// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
|
||||||
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
||||||
if len(packet) > 1 {
|
if len(packet) > 1 {
|
||||||
l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err)
|
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -57,7 +57,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
|
||||||
|
|
||||||
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
|
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger().WithError(err).WithField("udpAddr", addr).
|
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("packet", packet).
|
WithField("packet", packet).
|
||||||
Error("Failed to decrypt lighthouse packet")
|
Error("Failed to decrypt lighthouse packet")
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
|
||||||
|
|
||||||
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
|
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger().WithError(err).WithField("udpAddr", addr).
|
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("packet", packet).
|
WithField("packet", packet).
|
||||||
Error("Failed to decrypt test packet")
|
Error("Failed to decrypt test packet")
|
||||||
|
|
||||||
|
@ -115,7 +115,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.logger().WithField("udpAddr", addr).
|
hostinfo.logger(f.l).WithField("udpAddr", addr).
|
||||||
Info("Close tunnel received, tearing down.")
|
Info("Close tunnel received, tearing down.")
|
||||||
|
|
||||||
f.closeTunnel(hostinfo)
|
f.closeTunnel(hostinfo)
|
||||||
|
@ -123,7 +123,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
|
||||||
|
|
||||||
default:
|
default:
|
||||||
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
|
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
|
||||||
hostinfo.logger().Debugf("Unexpected packet received from %s", addr)
|
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,18 +143,18 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) {
|
||||||
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
|
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
|
||||||
if hostDidRoam(hostinfo.remote, addr) {
|
if hostDidRoam(hostinfo.remote, addr) {
|
||||||
if !f.lightHouse.remoteAllowList.Allow(addr.IP) {
|
if !f.lightHouse.remoteAllowList.Allow(addr.IP) {
|
||||||
hostinfo.logger().WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
|
hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
|
if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
|
||||||
if l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
|
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
|
||||||
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
|
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
|
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
|
||||||
Info("Host roamed to new udp ip/port.")
|
Info("Host roamed to new udp ip/port.")
|
||||||
hostinfo.lastRoam = time.Now()
|
hostinfo.lastRoam = time.Now()
|
||||||
remoteCopy := *hostinfo.remote
|
remoteCopy := *hostinfo.remote
|
||||||
|
@ -170,7 +170,7 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
|
||||||
func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *Header) bool {
|
func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *Header) bool {
|
||||||
// If connectionstate exists and the replay protector allows, process packet
|
// If connectionstate exists and the replay protector allows, process packet
|
||||||
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
|
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
|
||||||
if ci == nil || !ci.window.Check(header.MessageCounter) {
|
if ci == nil || !ci.window.Check(f.l, header.MessageCounter) {
|
||||||
f.sendRecvError(addr, header.RemoteIndex)
|
f.sendRecvError(addr, header.RemoteIndex)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -247,8 +247,8 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hostinfo.ConnectionState.window.Update(mc) {
|
if !hostinfo.ConnectionState.window.Update(f.l, mc) {
|
||||||
hostinfo.logger().WithField("header", header).
|
hostinfo.logger(f.l).WithField("header", header).
|
||||||
Debugln("dropping out of window packet")
|
Debugln("dropping out of window packet")
|
||||||
return nil, errors.New("out of window packet")
|
return nil, errors.New("out of window packet")
|
||||||
}
|
}
|
||||||
|
@ -261,7 +261,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||||
|
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
|
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger().WithError(err).Error("Failed to decrypt packet")
|
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
||||||
//TODO: maybe after build 64 is out? 06/14/2018 - NB
|
//TODO: maybe after build 64 is out? 06/14/2018 - NB
|
||||||
//f.sendRecvError(hostinfo.remote, header.RemoteIndex)
|
//f.sendRecvError(hostinfo.remote, header.RemoteIndex)
|
||||||
return
|
return
|
||||||
|
@ -269,21 +269,21 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||||
|
|
||||||
err = newPacket(out, true, fwPacket)
|
err = newPacket(out, true, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger().WithError(err).WithField("packet", out).
|
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
|
||||||
Warnf("Error while validating inbound packet")
|
Warnf("Error while validating inbound packet")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hostinfo.ConnectionState.window.Update(messageCounter) {
|
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
|
||||||
hostinfo.logger().WithField("fwPacket", fwPacket).
|
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
||||||
Debugln("dropping out of window packet")
|
Debugln("dropping out of window packet")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs, localCache)
|
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs, localCache)
|
||||||
if dropReason != nil {
|
if dropReason != nil {
|
||||||
if l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger().WithField("fwPacket", fwPacket).
|
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
||||||
WithField("reason", dropReason).
|
WithField("reason", dropReason).
|
||||||
Debugln("dropping inbound packet")
|
Debugln("dropping inbound packet")
|
||||||
}
|
}
|
||||||
|
@ -293,7 +293,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||||
f.connectionManager.In(hostinfo.hostId)
|
f.connectionManager.In(hostinfo.hostId)
|
||||||
_, err = f.readers[q].Write(out)
|
_, err = f.readers[q].Write(out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("Failed to write to tun")
|
f.l.WithError(err).Error("Failed to write to tun")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -303,16 +303,16 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
|
||||||
//TODO: this should be a signed message so we can trust that we should drop the index
|
//TODO: this should be a signed message so we can trust that we should drop the index
|
||||||
b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0)
|
b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0)
|
||||||
f.outside.WriteTo(b, endpoint)
|
f.outside.WriteTo(b, endpoint)
|
||||||
if l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
l.WithField("index", index).
|
f.l.WithField("index", index).
|
||||||
WithField("udpAddr", endpoint).
|
WithField("udpAddr", endpoint).
|
||||||
Debug("Recv error sent")
|
Debug("Recv error sent")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
|
func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
|
||||||
if l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
l.WithField("index", h.RemoteIndex).
|
f.l.WithField("index", h.RemoteIndex).
|
||||||
WithField("udpAddr", addr).
|
WithField("udpAddr", addr).
|
||||||
Debug("Recv error received")
|
Debug("Recv error received")
|
||||||
}
|
}
|
||||||
|
@ -322,7 +322,7 @@ func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
|
||||||
|
|
||||||
hostinfo, err := f.hostMap.QueryReverseIndex(h.RemoteIndex)
|
hostinfo, err := f.hostMap.QueryReverseIndex(h.RemoteIndex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Debugln(err, ": ", h.RemoteIndex)
|
f.l.Debugln(err, ": ", h.RemoteIndex)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -333,7 +333,7 @@ func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if hostinfo.remote != nil && hostinfo.remote.String() != addr.String() {
|
if hostinfo.remote != nil && hostinfo.remote.String() != addr.String() {
|
||||||
l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
|
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewPunchyFromConfig(t *testing.T) {
|
func TestNewPunchyFromConfig(t *testing.T) {
|
||||||
c := NewConfig()
|
l := NewTestLogger()
|
||||||
|
c := NewConfig(l)
|
||||||
|
|
||||||
// Test defaults
|
// Test defaults
|
||||||
p := NewPunchyFromConfig(c)
|
p := NewPunchyFromConfig(c)
|
||||||
|
|
20
ssh.go
20
ssh.go
|
@ -44,10 +44,10 @@ type sshCreateTunnelFlags struct {
|
||||||
Address string
|
Address string
|
||||||
}
|
}
|
||||||
|
|
||||||
func wireSSHReload(ssh *sshd.SSHServer, c *Config) {
|
func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) {
|
||||||
c.RegisterReloadCallback(func(c *Config) {
|
c.RegisterReloadCallback(func(c *Config) {
|
||||||
if c.GetBool("sshd.enabled", false) {
|
if c.GetBool("sshd.enabled", false) {
|
||||||
err := configSSH(ssh, c)
|
err := configSSH(l, ssh, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("Failed to reconfigure the sshd")
|
l.WithError(err).Error("Failed to reconfigure the sshd")
|
||||||
ssh.Stop()
|
ssh.Stop()
|
||||||
|
@ -58,7 +58,7 @@ func wireSSHReload(ssh *sshd.SSHServer, c *Config) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func configSSH(ssh *sshd.SSHServer, c *Config) error {
|
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) error {
|
||||||
//TODO conntrack list
|
//TODO conntrack list
|
||||||
//TODO print firewall rules or hash?
|
//TODO print firewall rules or hash?
|
||||||
|
|
||||||
|
@ -149,7 +149,7 @@ func configSSH(ssh *sshd.SSHServer, c *Config) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func attachCommands(ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) {
|
func attachCommands(l *logrus.Logger, ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) {
|
||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "list-hostmap",
|
Name: "list-hostmap",
|
||||||
ShortDescription: "List all known previously connected hosts",
|
ShortDescription: "List all known previously connected hosts",
|
||||||
|
@ -225,13 +225,17 @@ func attachCommands(ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostM
|
||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "log-level",
|
Name: "log-level",
|
||||||
ShortDescription: "Gets or sets the current log level",
|
ShortDescription: "Gets or sets the current log level",
|
||||||
Callback: sshLogLevel,
|
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
|
return sshLogLevel(l, fs, a, w)
|
||||||
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "log-format",
|
Name: "log-format",
|
||||||
ShortDescription: "Gets or sets the current log format",
|
ShortDescription: "Gets or sets the current log format",
|
||||||
Callback: sshLogFormat,
|
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
|
return sshLogFormat(l, fs, a, w)
|
||||||
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
|
@ -629,7 +633,7 @@ func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshLogLevel(fs interface{}, a []string, w sshd.StringWriter) error {
|
func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
if len(a) == 0 {
|
if len(a) == 0 {
|
||||||
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
|
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
|
||||||
}
|
}
|
||||||
|
@ -643,7 +647,7 @@ func sshLogLevel(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
|
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshLogFormat(fs interface{}, a []string, w sshd.StringWriter) error {
|
func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
if len(a) == 0 {
|
if len(a) == 0 {
|
||||||
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
|
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
|
||||||
}
|
}
|
||||||
|
|
11
stats.go
11
stats.go
|
@ -13,9 +13,10 @@ import (
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func startStats(c *Config, configTest bool) error {
|
func startStats(l *logrus.Logger, c *Config, configTest bool) error {
|
||||||
mType := c.GetString("stats.type", "")
|
mType := c.GetString("stats.type", "")
|
||||||
if mType == "" || mType == "none" {
|
if mType == "" || mType == "none" {
|
||||||
return nil
|
return nil
|
||||||
|
@ -28,9 +29,9 @@ func startStats(c *Config, configTest bool) error {
|
||||||
|
|
||||||
switch mType {
|
switch mType {
|
||||||
case "graphite":
|
case "graphite":
|
||||||
startGraphiteStats(interval, c, configTest)
|
startGraphiteStats(l, interval, c, configTest)
|
||||||
case "prometheus":
|
case "prometheus":
|
||||||
startPrometheusStats(interval, c, configTest)
|
startPrometheusStats(l, interval, c, configTest)
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("stats.type was not understood: %s", mType)
|
return fmt.Errorf("stats.type was not understood: %s", mType)
|
||||||
}
|
}
|
||||||
|
@ -44,7 +45,7 @@ func startStats(c *Config, configTest bool) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func startGraphiteStats(i time.Duration, c *Config, configTest bool) error {
|
func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error {
|
||||||
proto := c.GetString("stats.protocol", "tcp")
|
proto := c.GetString("stats.protocol", "tcp")
|
||||||
host := c.GetString("stats.host", "")
|
host := c.GetString("stats.host", "")
|
||||||
if host == "" {
|
if host == "" {
|
||||||
|
@ -64,7 +65,7 @@ func startGraphiteStats(i time.Duration, c *Config, configTest bool) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func startPrometheusStats(i time.Duration, c *Config, configTest bool) error {
|
func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error {
|
||||||
namespace := c.GetString("stats.namespace", "")
|
namespace := c.GetString("stats.namespace", "")
|
||||||
subsystem := c.GetString("stats.subsystem", "")
|
subsystem := c.GetString("stats.subsystem", "")
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -19,9 +20,10 @@ type Tun struct {
|
||||||
TXQueueLen int
|
TXQueueLen int
|
||||||
Routes []route
|
Routes []route
|
||||||
UnsafeRoutes []route
|
UnsafeRoutes []route
|
||||||
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||||
|
|
||||||
ifce = &Tun{
|
ifce = &Tun{
|
||||||
|
@ -33,6 +35,7 @@ func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route,
|
||||||
TXQueueLen: txQueueLen,
|
TXQueueLen: txQueueLen,
|
||||||
Routes: routes,
|
Routes: routes,
|
||||||
UnsafeRoutes: unsafeRoutes,
|
UnsafeRoutes: unsafeRoutes,
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/songgao/water"
|
"github.com/songgao/water"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,11 +18,11 @@ type Tun struct {
|
||||||
Cidr *net.IPNet
|
Cidr *net.IPNet
|
||||||
MTU int
|
MTU int
|
||||||
UnsafeRoutes []route
|
UnsafeRoutes []route
|
||||||
|
l *logrus.Logger
|
||||||
*water.Interface
|
*water.Interface
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
||||||
if len(routes) > 0 {
|
if len(routes) > 0 {
|
||||||
return nil, fmt.Errorf("route MTU not supported in Darwin")
|
return nil, fmt.Errorf("route MTU not supported in Darwin")
|
||||||
}
|
}
|
||||||
|
@ -31,10 +32,11 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
|
||||||
Cidr: cidr,
|
Cidr: cidr,
|
||||||
MTU: defaultMTU,
|
MTU: defaultMTU,
|
||||||
UnsafeRoutes: unsafeRoutes,
|
UnsafeRoutes: unsafeRoutes,
|
||||||
|
l: l,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,24 +9,23 @@ import (
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type disabledTun struct {
|
type disabledTun struct {
|
||||||
read chan []byte
|
read chan []byte
|
||||||
cidr *net.IPNet
|
cidr *net.IPNet
|
||||||
logger *log.Logger
|
|
||||||
|
|
||||||
// Track these metrics since we don't have the tun device to do it for us
|
// Track these metrics since we don't have the tun device to do it for us
|
||||||
tx metrics.Counter
|
tx metrics.Counter
|
||||||
rx metrics.Counter
|
rx metrics.Counter
|
||||||
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *log.Logger) *disabledTun {
|
func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
|
||||||
tun := &disabledTun{
|
tun := &disabledTun{
|
||||||
cidr: cidr,
|
cidr: cidr,
|
||||||
read: make(chan []byte, queueLen),
|
read: make(chan []byte, queueLen),
|
||||||
logger: l,
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
if metricsEnabled {
|
if metricsEnabled {
|
||||||
|
@ -63,8 +62,8 @@ func (t *disabledTun) Read(b []byte) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.tx.Inc(1)
|
t.tx.Inc(1)
|
||||||
if l.Level >= logrus.DebugLevel {
|
if t.l.Level >= logrus.DebugLevel {
|
||||||
t.logger.WithField("raw", prettyPacket(r)).Debugf("Write payload")
|
t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload")
|
||||||
}
|
}
|
||||||
|
|
||||||
return copy(b, r), nil
|
return copy(b, r), nil
|
||||||
|
@ -103,7 +102,7 @@ func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
|
||||||
select {
|
select {
|
||||||
case t.read <- buf:
|
case t.read <- buf:
|
||||||
default:
|
default:
|
||||||
t.logger.Debugf("tun_disabled: dropped ICMP Echo Reply response")
|
t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response")
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
|
@ -114,11 +113,11 @@ func (t *disabledTun) Write(b []byte) (int, error) {
|
||||||
|
|
||||||
// Check for ICMP Echo Request before spending time doing the full parsing
|
// Check for ICMP Echo Request before spending time doing the full parsing
|
||||||
if t.handleICMPEchoRequest(b) {
|
if t.handleICMPEchoRequest(b) {
|
||||||
if l.Level >= logrus.DebugLevel {
|
if t.l.Level >= logrus.DebugLevel {
|
||||||
t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request")
|
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request")
|
||||||
}
|
}
|
||||||
} else if l.Level >= logrus.DebugLevel {
|
} else if t.l.Level >= logrus.DebugLevel {
|
||||||
t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
|
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
|
||||||
}
|
}
|
||||||
return len(b), nil
|
return len(b), nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,8 @@ import (
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||||
|
@ -18,15 +20,16 @@ type Tun struct {
|
||||||
Cidr *net.IPNet
|
Cidr *net.IPNet
|
||||||
MTU int
|
MTU int
|
||||||
UnsafeRoutes []route
|
UnsafeRoutes []route
|
||||||
|
l *logrus.Logger
|
||||||
|
|
||||||
io.ReadWriteCloser
|
io.ReadWriteCloser
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
||||||
if len(routes) > 0 {
|
if len(routes) > 0 {
|
||||||
return nil, fmt.Errorf("Route MTU not supported in FreeBSD")
|
return nil, fmt.Errorf("Route MTU not supported in FreeBSD")
|
||||||
}
|
}
|
||||||
|
@ -41,6 +44,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
|
||||||
Cidr: cidr,
|
Cidr: cidr,
|
||||||
MTU: defaultMTU,
|
MTU: defaultMTU,
|
||||||
UnsafeRoutes: unsafeRoutes,
|
UnsafeRoutes: unsafeRoutes,
|
||||||
|
l: l,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,21 +56,21 @@ func (c *Tun) Activate() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO use syscalls instead of exec.Command
|
// TODO use syscalls instead of exec.Command
|
||||||
l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String())
|
c.l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String())
|
||||||
if err = exec.Command("/sbin/ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil {
|
if err = exec.Command("/sbin/ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil {
|
||||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||||
}
|
}
|
||||||
l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device)
|
c.l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device)
|
||||||
if err = exec.Command("/sbin/route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil {
|
if err = exec.Command("/sbin/route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil {
|
||||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||||
}
|
}
|
||||||
l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU))
|
c.l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU))
|
||||||
if err = exec.Command("/sbin/ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil {
|
if err = exec.Command("/sbin/ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil {
|
||||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||||
}
|
}
|
||||||
// Unsafe path routes
|
// Unsafe path routes
|
||||||
for _, r := range c.UnsafeRoutes {
|
for _, r := range c.UnsafeRoutes {
|
||||||
l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device)
|
c.l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device)
|
||||||
if err = exec.Command("/sbin/route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil {
|
if err = exec.Command("/sbin/route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil {
|
||||||
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.route.String(), err)
|
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.route.String(), err)
|
||||||
}
|
}
|
||||||
|
|
12
tun_linux.go
12
tun_linux.go
|
@ -10,6 +10,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
@ -24,6 +25,7 @@ type Tun struct {
|
||||||
TXQueueLen int
|
TXQueueLen int
|
||||||
Routes []route
|
Routes []route
|
||||||
UnsafeRoutes []route
|
UnsafeRoutes []route
|
||||||
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type ifReq struct {
|
type ifReq struct {
|
||||||
|
@ -78,7 +80,7 @@ type ifreqQLEN struct {
|
||||||
pad [8]byte
|
pad [8]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
|
|
||||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||||
|
|
||||||
|
@ -91,11 +93,12 @@ func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route,
|
||||||
TXQueueLen: txQueueLen,
|
TXQueueLen: txQueueLen,
|
||||||
Routes: routes,
|
Routes: routes,
|
||||||
UnsafeRoutes: unsafeRoutes,
|
UnsafeRoutes: unsafeRoutes,
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -131,6 +134,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
|
||||||
TXQueueLen: txQueueLen,
|
TXQueueLen: txQueueLen,
|
||||||
Routes: routes,
|
Routes: routes,
|
||||||
UnsafeRoutes: unsafeRoutes,
|
UnsafeRoutes: unsafeRoutes,
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -233,14 +237,14 @@ func (c Tun) Activate() error {
|
||||||
ifm := ifreqMTU{Name: devName, MTU: int32(c.MaxMTU)}
|
ifm := ifreqMTU{Name: devName, MTU: int32(c.MaxMTU)}
|
||||||
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
|
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
|
||||||
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
|
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
|
||||||
l.WithError(err).Error("Failed to set tun mtu")
|
c.l.WithError(err).Error("Failed to set tun mtu")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the transmit queue length
|
// Set the transmit queue length
|
||||||
ifrq := ifreqQLEN{Name: devName, Value: int32(c.TXQueueLen)}
|
ifrq := ifreqQLEN{Name: devName, Value: int32(c.TXQueueLen)}
|
||||||
if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
|
if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
|
||||||
// If we can't set the queue length nebula will still work but it may lead to packet loss
|
// If we can't set the queue length nebula will still work but it may lead to packet loss
|
||||||
l.WithError(err).Error("Failed to set tun tx queue length")
|
c.l.WithError(err).Error("Failed to set tun tx queue length")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bring up the interface
|
// Bring up the interface
|
||||||
|
|
|
@ -9,7 +9,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_parseRoutes(t *testing.T) {
|
func Test_parseRoutes(t *testing.T) {
|
||||||
c := NewConfig()
|
l := NewTestLogger()
|
||||||
|
c := NewConfig(l)
|
||||||
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
||||||
|
|
||||||
// test no routes config
|
// test no routes config
|
||||||
|
@ -104,7 +105,8 @@ func Test_parseRoutes(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_parseUnsafeRoutes(t *testing.T) {
|
func Test_parseUnsafeRoutes(t *testing.T) {
|
||||||
c := NewConfig()
|
l := NewTestLogger()
|
||||||
|
c := NewConfig(l)
|
||||||
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
||||||
|
|
||||||
// test no routes config
|
// test no routes config
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/songgao/water"
|
"github.com/songgao/water"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -15,15 +16,16 @@ type Tun struct {
|
||||||
Cidr *net.IPNet
|
Cidr *net.IPNet
|
||||||
MTU int
|
MTU int
|
||||||
UnsafeRoutes []route
|
UnsafeRoutes []route
|
||||||
|
l *logrus.Logger
|
||||||
|
|
||||||
*water.Interface
|
*water.Interface
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
||||||
if len(routes) > 0 {
|
if len(routes) > 0 {
|
||||||
return nil, fmt.Errorf("route MTU not supported in Windows")
|
return nil, fmt.Errorf("route MTU not supported in Windows")
|
||||||
}
|
}
|
||||||
|
@ -33,6 +35,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
|
||||||
Cidr: cidr,
|
Cidr: cidr,
|
||||||
MTU: defaultMTU,
|
MTU: defaultMTU,
|
||||||
UnsafeRoutes: unsafeRoutes,
|
UnsafeRoutes: unsafeRoutes,
|
||||||
|
l: l,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
// +build !e2e_testing
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
// +build !e2e_testing
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
// Darwin support is primarily implemented in udp_generic, besides NewListenConfig
|
// Darwin support is primarily implemented in udp_generic, besides NewListenConfig
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
// +build !e2e_testing
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig
|
// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
// +build !linux android
|
// +build !linux android
|
||||||
|
// +build !e2e_testing
|
||||||
|
|
||||||
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
|
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
|
||||||
// means it can be used on platforms like Darwin and Windows.
|
// means it can be used on platforms like Darwin and Windows.
|
||||||
|
@ -9,20 +10,23 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
type udpConn struct {
|
type udpConn struct {
|
||||||
*net.UDPConn
|
*net.UDPConn
|
||||||
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(ip string, port int, multi bool) (*udpConn, error) {
|
func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
|
||||||
lc := NewListenConfig(multi)
|
lc := NewListenConfig(multi)
|
||||||
pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port))
|
pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if uc, ok := pc.(*net.UDPConn); ok {
|
if uc, ok := pc.(*net.UDPConn); ok {
|
||||||
return &udpConn{UDPConn: uc}, nil
|
return &udpConn{UDPConn: uc, l: l}, nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
|
return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
|
||||||
}
|
}
|
||||||
|
@ -76,13 +80,13 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
|
||||||
// Just read one packet at a time
|
// Just read one packet at a time
|
||||||
n, rua, err := u.ReadFromUDP(buffer)
|
n, rua, err := u.ReadFromUDP(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("Failed to read packets")
|
f.l.WithError(err).Error("Failed to read packets")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
udpAddr.IP = rua.IP
|
udpAddr.IP = rua.IP
|
||||||
udpAddr.Port = uint16(rua.Port)
|
udpAddr.Port = uint16(rua.Port)
|
||||||
f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get())
|
f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get(f.l))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
23
udp_linux.go
23
udp_linux.go
|
@ -1,4 +1,5 @@
|
||||||
// +build !android
|
// +build !android
|
||||||
|
// +build !e2e_testing
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
|
@ -10,6 +11,7 @@ import (
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,6 +19,7 @@ import (
|
||||||
|
|
||||||
type udpConn struct {
|
type udpConn struct {
|
||||||
sysFd int
|
sysFd int
|
||||||
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
var x int
|
var x int
|
||||||
|
@ -38,7 +41,7 @@ const (
|
||||||
|
|
||||||
type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
|
type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
|
||||||
|
|
||||||
func NewListener(ip string, port int, multi bool) (*udpConn, error) {
|
func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
|
||||||
syscall.ForkLock.RLock()
|
syscall.ForkLock.RLock()
|
||||||
fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
|
fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -70,7 +73,7 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
|
||||||
//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
|
//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
|
||||||
//l.Println(v, err)
|
//l.Println(v, err)
|
||||||
|
|
||||||
return &udpConn{sysFd: fd}, err
|
return &udpConn{sysFd: fd, l: l}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) Rebind() error {
|
func (u *udpConn) Rebind() error {
|
||||||
|
@ -153,7 +156,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
|
||||||
for {
|
for {
|
||||||
n, err := read(msgs)
|
n, err := read(msgs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("Failed to read packets")
|
u.l.WithError(err).Error("Failed to read packets")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -161,7 +164,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
udpAddr.IP = names[i][8:24]
|
udpAddr.IP = names[i][8:24]
|
||||||
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
|
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
|
||||||
f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get())
|
f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -244,12 +247,12 @@ func (u *udpConn) reloadConfig(c *Config) {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
s, err := u.GetRecvBuffer()
|
s, err := u.GetRecvBuffer()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
l.WithField("size", s).Info("listen.read_buffer was set")
|
u.l.WithField("size", s).Info("listen.read_buffer was set")
|
||||||
} else {
|
} else {
|
||||||
l.WithError(err).Warn("Failed to get listen.read_buffer")
|
u.l.WithError(err).Warn("Failed to get listen.read_buffer")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
l.WithError(err).Error("Failed to set listen.read_buffer")
|
u.l.WithError(err).Error("Failed to set listen.read_buffer")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -259,12 +262,12 @@ func (u *udpConn) reloadConfig(c *Config) {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
s, err := u.GetSendBuffer()
|
s, err := u.GetSendBuffer()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
l.WithField("size", s).Info("listen.write_buffer was set")
|
u.l.WithField("size", s).Info("listen.write_buffer was set")
|
||||||
} else {
|
} else {
|
||||||
l.WithError(err).Warn("Failed to get listen.write_buffer")
|
u.l.WithError(err).Warn("Failed to get listen.write_buffer")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
l.WithError(err).Error("Failed to set listen.write_buffer")
|
u.l.WithError(err).Error("Failed to set listen.write_buffer")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
// +build linux
|
// +build linux
|
||||||
// +build 386 amd64p32 arm mips mipsle
|
// +build 386 amd64p32 arm mips mipsle
|
||||||
// +build !android
|
// +build !android
|
||||||
|
// +build !e2e_testing
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
// +build linux
|
// +build linux
|
||||||
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x
|
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x
|
||||||
// +build !android
|
// +build !android
|
||||||
|
// +build !e2e_testing
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue