Rework some things into packages (#489)
This commit is contained in:
parent
1f75fb3c73
commit
bcabcfdaca
235
allow_list.go
235
allow_list.go
|
@ -4,11 +4,15 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cidr"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AllowList struct {
|
type AllowList struct {
|
||||||
// The values of this cidrTree are `bool`, signifying allow/deny
|
// The values of this cidrTree are `bool`, signifying allow/deny
|
||||||
cidrTree *CIDR6Tree
|
cidrTree *cidr.Tree6
|
||||||
}
|
}
|
||||||
|
|
||||||
type RemoteAllowList struct {
|
type RemoteAllowList struct {
|
||||||
|
@ -16,7 +20,7 @@ type RemoteAllowList struct {
|
||||||
|
|
||||||
// Inside Range Specific, keys of this tree are inside CIDRs and values
|
// Inside Range Specific, keys of this tree are inside CIDRs and values
|
||||||
// are *AllowList
|
// are *AllowList
|
||||||
insideAllowLists *CIDR6Tree
|
insideAllowLists *cidr.Tree6
|
||||||
}
|
}
|
||||||
|
|
||||||
type LocalAllowList struct {
|
type LocalAllowList struct {
|
||||||
|
@ -31,6 +35,223 @@ type AllowListNameRule struct {
|
||||||
Allow bool
|
Allow bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewLocalAllowListFromConfig(c *config.C, k string) (*LocalAllowList, error) {
|
||||||
|
var nameRules []AllowListNameRule
|
||||||
|
handleKey := func(key string, value interface{}) (bool, error) {
|
||||||
|
if key == "interfaces" {
|
||||||
|
var err error
|
||||||
|
nameRules, err = getAllowListInterfaces(k, value)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
al, err := newAllowListFromConfig(c, k, handleKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &LocalAllowList{AllowList: al, nameRules: nameRules}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRemoteAllowListFromConfig(c *config.C, k, rangesKey string) (*RemoteAllowList, error) {
|
||||||
|
al, err := newAllowListFromConfig(c, k, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
remoteAllowRanges, err := getRemoteAllowRanges(c, rangesKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &RemoteAllowList{AllowList: al, insideAllowLists: remoteAllowRanges}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the handleKey func returns true, the rest of the parsing is skipped
|
||||||
|
// for this key. This allows parsing of special values like `interfaces`.
|
||||||
|
func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
|
||||||
|
r := c.Get(k)
|
||||||
|
if r == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return newAllowList(k, r, handleKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the handleKey func returns true, the rest of the parsing is skipped
|
||||||
|
// for this key. This allows parsing of special values like `interfaces`.
|
||||||
|
func newAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
|
||||||
|
rawMap, ok := raw.(map[interface{}]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
tree := cidr.NewTree6()
|
||||||
|
|
||||||
|
// Keep track of the rules we have added for both ipv4 and ipv6
|
||||||
|
type allowListRules struct {
|
||||||
|
firstValue bool
|
||||||
|
allValuesMatch bool
|
||||||
|
defaultSet bool
|
||||||
|
allValues bool
|
||||||
|
}
|
||||||
|
|
||||||
|
rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
|
||||||
|
rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
|
||||||
|
|
||||||
|
for rawKey, rawValue := range rawMap {
|
||||||
|
rawCIDR, ok := rawKey.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
if handleKey != nil {
|
||||||
|
handled, err := handleKey(rawCIDR, rawValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if handled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
value, ok := rawValue.(bool)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ipNet, err := net.ParseCIDR(rawCIDR)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: should we error on duplicate CIDRs in the config?
|
||||||
|
tree.AddCIDR(ipNet, value)
|
||||||
|
|
||||||
|
maskBits, maskSize := ipNet.Mask.Size()
|
||||||
|
|
||||||
|
var rules *allowListRules
|
||||||
|
if maskSize == 32 {
|
||||||
|
rules = &rules4
|
||||||
|
} else {
|
||||||
|
rules = &rules6
|
||||||
|
}
|
||||||
|
|
||||||
|
if rules.firstValue {
|
||||||
|
rules.allValues = value
|
||||||
|
rules.firstValue = false
|
||||||
|
} else {
|
||||||
|
if value != rules.allValues {
|
||||||
|
rules.allValuesMatch = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this is 0.0.0.0/0 or ::/0
|
||||||
|
if maskBits == 0 {
|
||||||
|
rules.defaultSet = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rules4.defaultSet {
|
||||||
|
if rules4.allValuesMatch {
|
||||||
|
_, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
|
||||||
|
tree.AddCIDR(zeroCIDR, !rules4.allValues)
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rules6.defaultSet {
|
||||||
|
if rules6.allValuesMatch {
|
||||||
|
_, zeroCIDR, _ := net.ParseCIDR("::/0")
|
||||||
|
tree.AddCIDR(zeroCIDR, !rules6.allValues)
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &AllowList{cidrTree: tree}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
|
||||||
|
var nameRules []AllowListNameRule
|
||||||
|
|
||||||
|
rawRules, ok := v.(map[interface{}]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
firstEntry := true
|
||||||
|
var allValues bool
|
||||||
|
for rawName, rawAllow := range rawRules {
|
||||||
|
name, ok := rawName.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
|
||||||
|
}
|
||||||
|
allow, ok := rawAllow.(bool)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
|
||||||
|
}
|
||||||
|
|
||||||
|
nameRE, err := regexp.Compile("^" + name + "$")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nameRules = append(nameRules, AllowListNameRule{
|
||||||
|
Name: nameRE,
|
||||||
|
Allow: allow,
|
||||||
|
})
|
||||||
|
|
||||||
|
if firstEntry {
|
||||||
|
allValues = allow
|
||||||
|
firstEntry = false
|
||||||
|
} else {
|
||||||
|
if allow != allValues {
|
||||||
|
return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nameRules, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6, error) {
|
||||||
|
value := c.Get(k)
|
||||||
|
if value == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteAllowRanges := cidr.NewTree6()
|
||||||
|
|
||||||
|
rawMap, ok := value.(map[interface{}]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
|
||||||
|
}
|
||||||
|
for rawKey, rawValue := range rawMap {
|
||||||
|
rawCIDR, ok := rawKey.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
allowList, err := newAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ipNet, err := net.ParseCIDR(rawCIDR)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteAllowRanges.AddCIDR(ipNet, allowList)
|
||||||
|
}
|
||||||
|
|
||||||
|
return remoteAllowRanges, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (al *AllowList) Allow(ip net.IP) bool {
|
func (al *AllowList) Allow(ip net.IP) bool {
|
||||||
if al == nil {
|
if al == nil {
|
||||||
return true
|
return true
|
||||||
|
@ -45,7 +266,7 @@ func (al *AllowList) Allow(ip net.IP) bool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (al *AllowList) AllowIpV4(ip uint32) bool {
|
func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
|
||||||
if al == nil {
|
if al == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -102,14 +323,14 @@ func (al *RemoteAllowList) AllowUnknownVpnIp(ip net.IP) bool {
|
||||||
return al.AllowList.Allow(ip)
|
return al.AllowList.Allow(ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (al *RemoteAllowList) Allow(vpnIp uint32, ip net.IP) bool {
|
func (al *RemoteAllowList) Allow(vpnIp iputil.VpnIp, ip net.IP) bool {
|
||||||
if !al.getInsideAllowList(vpnIp).Allow(ip) {
|
if !al.getInsideAllowList(vpnIp).Allow(ip) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return al.AllowList.Allow(ip)
|
return al.AllowList.Allow(ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (al *RemoteAllowList) AllowIpV4(vpnIp uint32, ip uint32) bool {
|
func (al *RemoteAllowList) AllowIpV4(vpnIp iputil.VpnIp, ip iputil.VpnIp) bool {
|
||||||
if al == nil {
|
if al == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -119,7 +340,7 @@ func (al *RemoteAllowList) AllowIpV4(vpnIp uint32, ip uint32) bool {
|
||||||
return al.AllowList.AllowIpV4(ip)
|
return al.AllowList.AllowIpV4(ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (al *RemoteAllowList) AllowIpV6(vpnIp uint32, hi, lo uint64) bool {
|
func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool {
|
||||||
if al == nil {
|
if al == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -129,7 +350,7 @@ func (al *RemoteAllowList) AllowIpV6(vpnIp uint32, hi, lo uint64) bool {
|
||||||
return al.AllowList.AllowIpV6(hi, lo)
|
return al.AllowList.AllowIpV6(hi, lo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (al *RemoteAllowList) getInsideAllowList(vpnIp uint32) *AllowList {
|
func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
|
||||||
if al.insideAllowLists != nil {
|
if al.insideAllowLists != nil {
|
||||||
inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
|
inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
|
||||||
if inside != nil {
|
if inside != nil {
|
||||||
|
|
|
@ -5,21 +5,110 @@ import (
|
||||||
"regexp"
|
"regexp"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cidr"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestNewAllowListFromConfig(t *testing.T) {
|
||||||
|
l := util.NewTestLogger()
|
||||||
|
c := config.NewC(l)
|
||||||
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
|
"192.168.0.0": true,
|
||||||
|
}
|
||||||
|
r, err := newAllowListFromConfig(c, "allowlist", nil)
|
||||||
|
assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0")
|
||||||
|
assert.Nil(t, r)
|
||||||
|
|
||||||
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
|
"192.168.0.0/16": "abc",
|
||||||
|
}
|
||||||
|
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
||||||
|
assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
|
||||||
|
|
||||||
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
|
"192.168.0.0/16": true,
|
||||||
|
"10.0.0.0/8": false,
|
||||||
|
}
|
||||||
|
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
||||||
|
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
|
||||||
|
|
||||||
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
|
"0.0.0.0/0": true,
|
||||||
|
"10.0.0.0/8": false,
|
||||||
|
"10.42.42.0/24": true,
|
||||||
|
"fd00::/8": true,
|
||||||
|
"fd00:fd00::/16": false,
|
||||||
|
}
|
||||||
|
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
||||||
|
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
|
||||||
|
|
||||||
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
|
"0.0.0.0/0": true,
|
||||||
|
"10.0.0.0/8": false,
|
||||||
|
"10.42.42.0/24": true,
|
||||||
|
}
|
||||||
|
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
||||||
|
if assert.NoError(t, err) {
|
||||||
|
assert.NotNil(t, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
|
"0.0.0.0/0": true,
|
||||||
|
"10.0.0.0/8": false,
|
||||||
|
"10.42.42.0/24": true,
|
||||||
|
"::/0": false,
|
||||||
|
"fd00::/8": true,
|
||||||
|
"fd00:fd00::/16": false,
|
||||||
|
}
|
||||||
|
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
||||||
|
if assert.NoError(t, err) {
|
||||||
|
assert.NotNil(t, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test interface names
|
||||||
|
|
||||||
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
|
"interfaces": map[interface{}]interface{}{
|
||||||
|
`docker.*`: "foo",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
lr, err := NewLocalAllowListFromConfig(c, "allowlist")
|
||||||
|
assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
|
||||||
|
|
||||||
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
|
"interfaces": map[interface{}]interface{}{
|
||||||
|
`docker.*`: false,
|
||||||
|
`eth.*`: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
lr, err = NewLocalAllowListFromConfig(c, "allowlist")
|
||||||
|
assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
|
||||||
|
|
||||||
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
|
"interfaces": map[interface{}]interface{}{
|
||||||
|
`docker.*`: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
lr, err = NewLocalAllowListFromConfig(c, "allowlist")
|
||||||
|
if assert.NoError(t, err) {
|
||||||
|
assert.NotNil(t, lr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAllowList_Allow(t *testing.T) {
|
func TestAllowList_Allow(t *testing.T) {
|
||||||
assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))
|
assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))
|
||||||
|
|
||||||
tree := NewCIDR6Tree()
|
tree := cidr.NewTree6()
|
||||||
tree.AddCIDR(getCIDR("0.0.0.0/0"), true)
|
tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true)
|
||||||
tree.AddCIDR(getCIDR("10.0.0.0/8"), false)
|
tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false)
|
||||||
tree.AddCIDR(getCIDR("10.42.42.42/32"), true)
|
tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true)
|
||||||
tree.AddCIDR(getCIDR("10.42.0.0/16"), true)
|
tree.AddCIDR(cidr.Parse("10.42.0.0/16"), true)
|
||||||
tree.AddCIDR(getCIDR("10.42.42.0/24"), true)
|
tree.AddCIDR(cidr.Parse("10.42.42.0/24"), true)
|
||||||
tree.AddCIDR(getCIDR("10.42.42.0/24"), false)
|
tree.AddCIDR(cidr.Parse("10.42.42.0/24"), false)
|
||||||
tree.AddCIDR(getCIDR("::1/128"), true)
|
tree.AddCIDR(cidr.Parse("::1/128"), true)
|
||||||
tree.AddCIDR(getCIDR("::2/128"), false)
|
tree.AddCIDR(cidr.Parse("::2/128"), false)
|
||||||
al := &AllowList{cidrTree: tree}
|
al := &AllowList{cidrTree: tree}
|
||||||
|
|
||||||
assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1")))
|
assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1")))
|
||||||
|
|
|
@ -3,11 +3,12 @@ package nebula
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBits(t *testing.T) {
|
func TestBits(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
b := NewBits(10)
|
b := NewBits(10)
|
||||||
|
|
||||||
// make sure it is the right size
|
// make sure it is the right size
|
||||||
|
@ -75,7 +76,7 @@ func TestBits(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBitsDupeCounter(t *testing.T) {
|
func TestBitsDupeCounter(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
b := NewBits(10)
|
b := NewBits(10)
|
||||||
b.lostCounter.Clear()
|
b.lostCounter.Clear()
|
||||||
b.dupeCounter.Clear()
|
b.dupeCounter.Clear()
|
||||||
|
@ -100,7 +101,7 @@ func TestBitsDupeCounter(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBitsOutOfWindowCounter(t *testing.T) {
|
func TestBitsOutOfWindowCounter(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
b := NewBits(10)
|
b := NewBits(10)
|
||||||
b.lostCounter.Clear()
|
b.lostCounter.Clear()
|
||||||
b.dupeCounter.Clear()
|
b.dupeCounter.Clear()
|
||||||
|
@ -130,7 +131,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBitsLostCounter(t *testing.T) {
|
func TestBitsLostCounter(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
b := NewBits(10)
|
b := NewBits(10)
|
||||||
b.lostCounter.Clear()
|
b.lostCounter.Clear()
|
||||||
b.dupeCounter.Clear()
|
b.dupeCounter.Clear()
|
||||||
|
|
5
cert.go
5
cert.go
|
@ -9,6 +9,7 @@ import (
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type CertState struct {
|
type CertState struct {
|
||||||
|
@ -45,7 +46,7 @@ func NewCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*Cert
|
||||||
return cs, nil
|
return cs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCertStateFromConfig(c *Config) (*CertState, error) {
|
func NewCertStateFromConfig(c *config.C) (*CertState, error) {
|
||||||
var pemPrivateKey []byte
|
var pemPrivateKey []byte
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
@ -118,7 +119,7 @@ func NewCertStateFromConfig(c *Config) (*CertState, error) {
|
||||||
return NewCertState(nebulaCert, rawKey)
|
return NewCertState(nebulaCert, rawKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadCAFromConfig(l *logrus.Logger, c *Config) (*cert.NebulaCAPool, error) {
|
func loadCAFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) {
|
||||||
var rawCA []byte
|
var rawCA []byte
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,10 @@
|
||||||
|
package cidr
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
// Parse is a convenience function that returns only the IPNet
|
||||||
|
// This function ignores errors since it is primarily a test helper, the result could be nil
|
||||||
|
func Parse(s string) *net.IPNet {
|
||||||
|
_, c, _ := net.ParseCIDR(s)
|
||||||
|
return c
|
||||||
|
}
|
|
@ -1,39 +1,39 @@
|
||||||
package nebula
|
package cidr
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type CIDRNode struct {
|
type Node struct {
|
||||||
left *CIDRNode
|
left *Node
|
||||||
right *CIDRNode
|
right *Node
|
||||||
parent *CIDRNode
|
parent *Node
|
||||||
value interface{}
|
value interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
type CIDRTree struct {
|
type Tree4 struct {
|
||||||
root *CIDRNode
|
root *Node
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
startbit = uint32(0x80000000)
|
startbit = iputil.VpnIp(0x80000000)
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewCIDRTree() *CIDRTree {
|
func NewTree4() *Tree4 {
|
||||||
tree := new(CIDRTree)
|
tree := new(Tree4)
|
||||||
tree.root = &CIDRNode{}
|
tree.root = &Node{}
|
||||||
return tree
|
return tree
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) {
|
func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
|
||||||
bit := startbit
|
bit := startbit
|
||||||
node := tree.root
|
node := tree.root
|
||||||
next := tree.root
|
next := tree.root
|
||||||
|
|
||||||
ip := ip2int(cidr.IP)
|
ip := iputil.Ip2VpnIp(cidr.IP)
|
||||||
mask := ip2int(cidr.Mask)
|
mask := iputil.Ip2VpnIp(cidr.Mask)
|
||||||
|
|
||||||
// Find our last ancestor in the tree
|
// Find our last ancestor in the tree
|
||||||
for bit&mask != 0 {
|
for bit&mask != 0 {
|
||||||
|
@ -59,7 +59,7 @@ func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) {
|
||||||
|
|
||||||
// Build up the rest of the tree we don't already have
|
// Build up the rest of the tree we don't already have
|
||||||
for bit&mask != 0 {
|
for bit&mask != 0 {
|
||||||
next = &CIDRNode{}
|
next = &Node{}
|
||||||
next.parent = node
|
next.parent = node
|
||||||
|
|
||||||
if ip&bit != 0 {
|
if ip&bit != 0 {
|
||||||
|
@ -77,7 +77,7 @@ func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finds the first match, which may be the least specific
|
// Finds the first match, which may be the least specific
|
||||||
func (tree *CIDRTree) Contains(ip uint32) (value interface{}) {
|
func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
|
||||||
bit := startbit
|
bit := startbit
|
||||||
node := tree.root
|
node := tree.root
|
||||||
|
|
||||||
|
@ -100,7 +100,7 @@ func (tree *CIDRTree) Contains(ip uint32) (value interface{}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finds the most specific match
|
// Finds the most specific match
|
||||||
func (tree *CIDRTree) MostSpecificContains(ip uint32) (value interface{}) {
|
func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
|
||||||
bit := startbit
|
bit := startbit
|
||||||
node := tree.root
|
node := tree.root
|
||||||
|
|
||||||
|
@ -122,7 +122,7 @@ func (tree *CIDRTree) MostSpecificContains(ip uint32) (value interface{}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finds the most specific match
|
// Finds the most specific match
|
||||||
func (tree *CIDRTree) Match(ip uint32) (value interface{}) {
|
func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
|
||||||
bit := startbit
|
bit := startbit
|
||||||
node := tree.root
|
node := tree.root
|
||||||
lastNode := node
|
lastNode := node
|
||||||
|
@ -143,27 +143,3 @@ func (tree *CIDRTree) Match(ip uint32) (value interface{}) {
|
||||||
}
|
}
|
||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
// A helper type to avoid converting to IP when logging
|
|
||||||
type IntIp uint32
|
|
||||||
|
|
||||||
func (ip IntIp) String() string {
|
|
||||||
return fmt.Sprintf("%v", int2ip(uint32(ip)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ip IntIp) MarshalJSON() ([]byte, error) {
|
|
||||||
return []byte(fmt.Sprintf("\"%s\"", int2ip(uint32(ip)).String())), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ip2int(ip []byte) uint32 {
|
|
||||||
if len(ip) == 16 {
|
|
||||||
return binary.BigEndian.Uint32(ip[12:16])
|
|
||||||
}
|
|
||||||
return binary.BigEndian.Uint32(ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
func int2ip(nn uint32) net.IP {
|
|
||||||
ip := make(net.IP, 4)
|
|
||||||
binary.BigEndian.PutUint32(ip, nn)
|
|
||||||
return ip
|
|
||||||
}
|
|
|
@ -0,0 +1,153 @@
|
||||||
|
package cidr
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCIDRTree_Contains(t *testing.T) {
|
||||||
|
tree := NewTree4()
|
||||||
|
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
|
||||||
|
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
|
||||||
|
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
|
||||||
|
tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
|
||||||
|
tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
|
||||||
|
tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
|
||||||
|
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
Result interface{}
|
||||||
|
IP string
|
||||||
|
}{
|
||||||
|
{"1", "1.0.0.0"},
|
||||||
|
{"1", "1.255.255.255"},
|
||||||
|
{"2", "2.1.0.0"},
|
||||||
|
{"2", "2.1.255.255"},
|
||||||
|
{"3", "3.1.1.0"},
|
||||||
|
{"3", "3.1.1.255"},
|
||||||
|
{"4a", "4.1.1.255"},
|
||||||
|
{"4a", "4.1.1.1"},
|
||||||
|
{"5", "240.0.0.0"},
|
||||||
|
{"5", "255.255.255.255"},
|
||||||
|
{nil, "239.0.0.0"},
|
||||||
|
{nil, "4.1.2.2"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
assert.Equal(t, tt.Result, tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
|
||||||
|
}
|
||||||
|
|
||||||
|
tree = NewTree4()
|
||||||
|
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
|
||||||
|
assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
|
||||||
|
assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCIDRTree_MostSpecificContains(t *testing.T) {
|
||||||
|
tree := NewTree4()
|
||||||
|
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
|
||||||
|
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
|
||||||
|
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
|
||||||
|
tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
|
||||||
|
tree.AddCIDR(Parse("4.1.1.0/30"), "4b")
|
||||||
|
tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
|
||||||
|
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
Result interface{}
|
||||||
|
IP string
|
||||||
|
}{
|
||||||
|
{"1", "1.0.0.0"},
|
||||||
|
{"1", "1.255.255.255"},
|
||||||
|
{"2", "2.1.0.0"},
|
||||||
|
{"2", "2.1.255.255"},
|
||||||
|
{"3", "3.1.1.0"},
|
||||||
|
{"3", "3.1.1.255"},
|
||||||
|
{"4a", "4.1.1.255"},
|
||||||
|
{"4b", "4.1.1.2"},
|
||||||
|
{"4c", "4.1.1.1"},
|
||||||
|
{"5", "240.0.0.0"},
|
||||||
|
{"5", "255.255.255.255"},
|
||||||
|
{nil, "239.0.0.0"},
|
||||||
|
{nil, "4.1.2.2"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
assert.Equal(t, tt.Result, tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
|
||||||
|
}
|
||||||
|
|
||||||
|
tree = NewTree4()
|
||||||
|
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
|
||||||
|
assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
|
||||||
|
assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCIDRTree_Match(t *testing.T) {
|
||||||
|
tree := NewTree4()
|
||||||
|
tree.AddCIDR(Parse("4.1.1.0/32"), "1a")
|
||||||
|
tree.AddCIDR(Parse("4.1.1.1/32"), "1b")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
Result interface{}
|
||||||
|
IP string
|
||||||
|
}{
|
||||||
|
{"1a", "4.1.1.0"},
|
||||||
|
{"1b", "4.1.1.1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
assert.Equal(t, tt.Result, tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
|
||||||
|
}
|
||||||
|
|
||||||
|
tree = NewTree4()
|
||||||
|
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
|
||||||
|
assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
|
||||||
|
assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkCIDRTree_Contains(b *testing.B) {
|
||||||
|
tree := NewTree4()
|
||||||
|
tree.AddCIDR(Parse("1.1.0.0/16"), "1")
|
||||||
|
tree.AddCIDR(Parse("1.2.1.1/32"), "1")
|
||||||
|
tree.AddCIDR(Parse("192.2.1.1/32"), "1")
|
||||||
|
tree.AddCIDR(Parse("172.2.1.1/32"), "1")
|
||||||
|
|
||||||
|
ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
|
||||||
|
b.Run("found", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tree.Contains(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
|
||||||
|
b.Run("not found", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tree.Contains(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkCIDRTree_Match(b *testing.B) {
|
||||||
|
tree := NewTree4()
|
||||||
|
tree.AddCIDR(Parse("1.1.0.0/16"), "1")
|
||||||
|
tree.AddCIDR(Parse("1.2.1.1/32"), "1")
|
||||||
|
tree.AddCIDR(Parse("192.2.1.1/32"), "1")
|
||||||
|
tree.AddCIDR(Parse("172.2.1.1/32"), "1")
|
||||||
|
|
||||||
|
ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
|
||||||
|
b.Run("found", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tree.Match(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
|
||||||
|
b.Run("not found", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tree.Match(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -1,26 +1,27 @@
|
||||||
package nebula
|
package cidr
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
)
|
)
|
||||||
|
|
||||||
const startbit6 = uint64(1 << 63)
|
const startbit6 = uint64(1 << 63)
|
||||||
|
|
||||||
type CIDR6Tree struct {
|
type Tree6 struct {
|
||||||
root4 *CIDRNode
|
root4 *Node
|
||||||
root6 *CIDRNode
|
root6 *Node
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCIDR6Tree() *CIDR6Tree {
|
func NewTree6() *Tree6 {
|
||||||
tree := new(CIDR6Tree)
|
tree := new(Tree6)
|
||||||
tree.root4 = &CIDRNode{}
|
tree.root4 = &Node{}
|
||||||
tree.root6 = &CIDRNode{}
|
tree.root6 = &Node{}
|
||||||
return tree
|
return tree
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
|
func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) {
|
||||||
var node, next *CIDRNode
|
var node, next *Node
|
||||||
|
|
||||||
cidrIP, ipv4 := isIPV4(cidr.IP)
|
cidrIP, ipv4 := isIPV4(cidr.IP)
|
||||||
if ipv4 {
|
if ipv4 {
|
||||||
|
@ -33,8 +34,8 @@ func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < len(cidrIP); i += 4 {
|
for i := 0; i < len(cidrIP); i += 4 {
|
||||||
ip := binary.BigEndian.Uint32(cidrIP[i : i+4])
|
ip := iputil.Ip2VpnIp(cidrIP[i : i+4])
|
||||||
mask := binary.BigEndian.Uint32(cidr.Mask[i : i+4])
|
mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4])
|
||||||
bit := startbit
|
bit := startbit
|
||||||
|
|
||||||
// Find our last ancestor in the tree
|
// Find our last ancestor in the tree
|
||||||
|
@ -55,7 +56,7 @@ func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
|
||||||
|
|
||||||
// Build up the rest of the tree we don't already have
|
// Build up the rest of the tree we don't already have
|
||||||
for bit&mask != 0 {
|
for bit&mask != 0 {
|
||||||
next = &CIDRNode{}
|
next = &Node{}
|
||||||
next.parent = node
|
next.parent = node
|
||||||
|
|
||||||
if ip&bit != 0 {
|
if ip&bit != 0 {
|
||||||
|
@ -74,8 +75,8 @@ func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finds the most specific match
|
// Finds the most specific match
|
||||||
func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) {
|
func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) {
|
||||||
var node *CIDRNode
|
var node *Node
|
||||||
|
|
||||||
wholeIP, ipv4 := isIPV4(ip)
|
wholeIP, ipv4 := isIPV4(ip)
|
||||||
if ipv4 {
|
if ipv4 {
|
||||||
|
@ -85,7 +86,7 @@ func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < len(wholeIP); i += 4 {
|
for i := 0; i < len(wholeIP); i += 4 {
|
||||||
ip := ip2int(wholeIP[i : i+4])
|
ip := iputil.Ip2VpnIp(wholeIP[i : i+4])
|
||||||
bit := startbit
|
bit := startbit
|
||||||
|
|
||||||
for node != nil {
|
for node != nil {
|
||||||
|
@ -110,7 +111,7 @@ func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) {
|
||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tree *CIDR6Tree) MostSpecificContainsIpV4(ip uint32) (value interface{}) {
|
func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{}) {
|
||||||
bit := startbit
|
bit := startbit
|
||||||
node := tree.root4
|
node := tree.root4
|
||||||
|
|
||||||
|
@ -131,7 +132,7 @@ func (tree *CIDR6Tree) MostSpecificContainsIpV4(ip uint32) (value interface{}) {
|
||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tree *CIDR6Tree) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
|
func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
|
||||||
ip := hi
|
ip := hi
|
||||||
node := tree.root6
|
node := tree.root6
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package nebula
|
package cidr
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -8,17 +9,17 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
|
func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
|
||||||
tree := NewCIDR6Tree()
|
tree := NewTree6()
|
||||||
tree.AddCIDR(getCIDR("1.0.0.0/8"), "1")
|
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
|
||||||
tree.AddCIDR(getCIDR("2.1.0.0/16"), "2")
|
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
|
||||||
tree.AddCIDR(getCIDR("3.1.1.0/24"), "3")
|
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
|
||||||
tree.AddCIDR(getCIDR("4.1.1.1/24"), "4a")
|
tree.AddCIDR(Parse("4.1.1.1/24"), "4a")
|
||||||
tree.AddCIDR(getCIDR("4.1.1.1/30"), "4b")
|
tree.AddCIDR(Parse("4.1.1.1/30"), "4b")
|
||||||
tree.AddCIDR(getCIDR("4.1.1.1/32"), "4c")
|
tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
|
||||||
tree.AddCIDR(getCIDR("254.0.0.0/4"), "5")
|
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
|
||||||
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/64"), "6a")
|
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
|
||||||
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/80"), "6b")
|
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
|
||||||
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/96"), "6c")
|
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
Result interface{}
|
Result interface{}
|
||||||
|
@ -46,9 +47,9 @@ func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
|
||||||
assert.Equal(t, tt.Result, tree.MostSpecificContains(net.ParseIP(tt.IP)))
|
assert.Equal(t, tt.Result, tree.MostSpecificContains(net.ParseIP(tt.IP)))
|
||||||
}
|
}
|
||||||
|
|
||||||
tree = NewCIDR6Tree()
|
tree = NewTree6()
|
||||||
tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
|
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
|
||||||
tree.AddCIDR(getCIDR("::/0"), "cool6")
|
tree.AddCIDR(Parse("::/0"), "cool6")
|
||||||
assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("0.0.0.0")))
|
assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("0.0.0.0")))
|
||||||
assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("255.255.255.255")))
|
assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("255.255.255.255")))
|
||||||
assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("::")))
|
assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("::")))
|
||||||
|
@ -56,10 +57,10 @@ func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
|
func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
|
||||||
tree := NewCIDR6Tree()
|
tree := NewTree6()
|
||||||
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/64"), "6a")
|
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
|
||||||
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/80"), "6b")
|
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
|
||||||
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/96"), "6c")
|
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
Result interface{}
|
Result interface{}
|
||||||
|
@ -71,7 +72,10 @@ func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
ip := NewIp6AndPort(net.ParseIP(tt.IP), 0)
|
ip := net.ParseIP(tt.IP)
|
||||||
assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(ip.Hi, ip.Lo))
|
hi := binary.BigEndian.Uint64(ip[:8])
|
||||||
|
lo := binary.BigEndian.Uint64(ip[8:])
|
||||||
|
|
||||||
|
assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(hi, lo))
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -1,157 +0,0 @@
|
||||||
package nebula
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCIDRTree_Contains(t *testing.T) {
|
|
||||||
tree := NewCIDRTree()
|
|
||||||
tree.AddCIDR(getCIDR("1.0.0.0/8"), "1")
|
|
||||||
tree.AddCIDR(getCIDR("2.1.0.0/16"), "2")
|
|
||||||
tree.AddCIDR(getCIDR("3.1.1.0/24"), "3")
|
|
||||||
tree.AddCIDR(getCIDR("4.1.1.0/24"), "4a")
|
|
||||||
tree.AddCIDR(getCIDR("4.1.1.1/32"), "4b")
|
|
||||||
tree.AddCIDR(getCIDR("4.1.2.1/32"), "4c")
|
|
||||||
tree.AddCIDR(getCIDR("254.0.0.0/4"), "5")
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
Result interface{}
|
|
||||||
IP string
|
|
||||||
}{
|
|
||||||
{"1", "1.0.0.0"},
|
|
||||||
{"1", "1.255.255.255"},
|
|
||||||
{"2", "2.1.0.0"},
|
|
||||||
{"2", "2.1.255.255"},
|
|
||||||
{"3", "3.1.1.0"},
|
|
||||||
{"3", "3.1.1.255"},
|
|
||||||
{"4a", "4.1.1.255"},
|
|
||||||
{"4a", "4.1.1.1"},
|
|
||||||
{"5", "240.0.0.0"},
|
|
||||||
{"5", "255.255.255.255"},
|
|
||||||
{nil, "239.0.0.0"},
|
|
||||||
{nil, "4.1.2.2"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
assert.Equal(t, tt.Result, tree.Contains(ip2int(net.ParseIP(tt.IP))))
|
|
||||||
}
|
|
||||||
|
|
||||||
tree = NewCIDRTree()
|
|
||||||
tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
|
|
||||||
assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("0.0.0.0"))))
|
|
||||||
assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("255.255.255.255"))))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCIDRTree_MostSpecificContains(t *testing.T) {
|
|
||||||
tree := NewCIDRTree()
|
|
||||||
tree.AddCIDR(getCIDR("1.0.0.0/8"), "1")
|
|
||||||
tree.AddCIDR(getCIDR("2.1.0.0/16"), "2")
|
|
||||||
tree.AddCIDR(getCIDR("3.1.1.0/24"), "3")
|
|
||||||
tree.AddCIDR(getCIDR("4.1.1.0/24"), "4a")
|
|
||||||
tree.AddCIDR(getCIDR("4.1.1.0/30"), "4b")
|
|
||||||
tree.AddCIDR(getCIDR("4.1.1.1/32"), "4c")
|
|
||||||
tree.AddCIDR(getCIDR("254.0.0.0/4"), "5")
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
Result interface{}
|
|
||||||
IP string
|
|
||||||
}{
|
|
||||||
{"1", "1.0.0.0"},
|
|
||||||
{"1", "1.255.255.255"},
|
|
||||||
{"2", "2.1.0.0"},
|
|
||||||
{"2", "2.1.255.255"},
|
|
||||||
{"3", "3.1.1.0"},
|
|
||||||
{"3", "3.1.1.255"},
|
|
||||||
{"4a", "4.1.1.255"},
|
|
||||||
{"4b", "4.1.1.2"},
|
|
||||||
{"4c", "4.1.1.1"},
|
|
||||||
{"5", "240.0.0.0"},
|
|
||||||
{"5", "255.255.255.255"},
|
|
||||||
{nil, "239.0.0.0"},
|
|
||||||
{nil, "4.1.2.2"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
assert.Equal(t, tt.Result, tree.MostSpecificContains(ip2int(net.ParseIP(tt.IP))))
|
|
||||||
}
|
|
||||||
|
|
||||||
tree = NewCIDRTree()
|
|
||||||
tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
|
|
||||||
assert.Equal(t, "cool", tree.MostSpecificContains(ip2int(net.ParseIP("0.0.0.0"))))
|
|
||||||
assert.Equal(t, "cool", tree.MostSpecificContains(ip2int(net.ParseIP("255.255.255.255"))))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCIDRTree_Match(t *testing.T) {
|
|
||||||
tree := NewCIDRTree()
|
|
||||||
tree.AddCIDR(getCIDR("4.1.1.0/32"), "1a")
|
|
||||||
tree.AddCIDR(getCIDR("4.1.1.1/32"), "1b")
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
Result interface{}
|
|
||||||
IP string
|
|
||||||
}{
|
|
||||||
{"1a", "4.1.1.0"},
|
|
||||||
{"1b", "4.1.1.1"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
assert.Equal(t, tt.Result, tree.Match(ip2int(net.ParseIP(tt.IP))))
|
|
||||||
}
|
|
||||||
|
|
||||||
tree = NewCIDRTree()
|
|
||||||
tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
|
|
||||||
assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("0.0.0.0"))))
|
|
||||||
assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("255.255.255.255"))))
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkCIDRTree_Contains(b *testing.B) {
|
|
||||||
tree := NewCIDRTree()
|
|
||||||
tree.AddCIDR(getCIDR("1.1.0.0/16"), "1")
|
|
||||||
tree.AddCIDR(getCIDR("1.2.1.1/32"), "1")
|
|
||||||
tree.AddCIDR(getCIDR("192.2.1.1/32"), "1")
|
|
||||||
tree.AddCIDR(getCIDR("172.2.1.1/32"), "1")
|
|
||||||
|
|
||||||
ip := ip2int(net.ParseIP("1.2.1.1"))
|
|
||||||
b.Run("found", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
tree.Contains(ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
ip = ip2int(net.ParseIP("1.2.1.255"))
|
|
||||||
b.Run("not found", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
tree.Contains(ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkCIDRTree_Match(b *testing.B) {
|
|
||||||
tree := NewCIDRTree()
|
|
||||||
tree.AddCIDR(getCIDR("1.1.0.0/16"), "1")
|
|
||||||
tree.AddCIDR(getCIDR("1.2.1.1/32"), "1")
|
|
||||||
tree.AddCIDR(getCIDR("192.2.1.1/32"), "1")
|
|
||||||
tree.AddCIDR(getCIDR("172.2.1.1/32"), "1")
|
|
||||||
|
|
||||||
ip := ip2int(net.ParseIP("1.2.1.1"))
|
|
||||||
b.Run("found", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
tree.Match(ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
ip = ip2int(net.ParseIP("1.2.1.255"))
|
|
||||||
b.Run("not found", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
tree.Match(ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func getCIDR(s string) *net.IPNet {
|
|
||||||
_, c, _ := net.ParseCIDR(s)
|
|
||||||
return c
|
|
||||||
}
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A version string that can be set with
|
// A version string that can be set with
|
||||||
|
@ -49,14 +50,14 @@ func main() {
|
||||||
l := logrus.New()
|
l := logrus.New()
|
||||||
l.Out = os.Stdout
|
l.Out = os.Stdout
|
||||||
|
|
||||||
config := nebula.NewConfig(l)
|
c := config.NewC(l)
|
||||||
err := config.Load(*configPath)
|
err := c.Load(*configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("failed to load config: %s", err)
|
fmt.Printf("failed to load config: %s", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
c, err := nebula.Main(config, *configTest, Build, l, nil)
|
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
||||||
|
|
||||||
switch v := err.(type) {
|
switch v := err.(type) {
|
||||||
case nebula.ContextualError:
|
case nebula.ContextualError:
|
||||||
|
@ -68,8 +69,8 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !*configTest {
|
if !*configTest {
|
||||||
c.Start()
|
ctrl.Start()
|
||||||
c.ShutdownBlock()
|
ctrl.ShutdownBlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger service.Logger
|
var logger service.Logger
|
||||||
|
@ -27,13 +28,13 @@ func (p *program) Start(s service.Service) error {
|
||||||
l := logrus.New()
|
l := logrus.New()
|
||||||
HookLogger(l)
|
HookLogger(l)
|
||||||
|
|
||||||
config := nebula.NewConfig(l)
|
c := config.NewC(l)
|
||||||
err := config.Load(*p.configPath)
|
err := c.Load(*p.configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to load config: %s", err)
|
return fmt.Errorf("failed to load config: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
|
p.control, err = nebula.Main(c, *p.configTest, Build, l, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A version string that can be set with
|
// A version string that can be set with
|
||||||
|
@ -43,14 +44,14 @@ func main() {
|
||||||
l := logrus.New()
|
l := logrus.New()
|
||||||
l.Out = os.Stdout
|
l.Out = os.Stdout
|
||||||
|
|
||||||
config := nebula.NewConfig(l)
|
c := config.NewC(l)
|
||||||
err := config.Load(*configPath)
|
err := c.Load(*configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("failed to load config: %s", err)
|
fmt.Printf("failed to load config: %s", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
c, err := nebula.Main(config, *configTest, Build, l, nil)
|
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
||||||
|
|
||||||
switch v := err.(type) {
|
switch v := err.(type) {
|
||||||
case nebula.ContextualError:
|
case nebula.ContextualError:
|
||||||
|
@ -62,8 +63,8 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !*configTest {
|
if !*configTest {
|
||||||
c.Start()
|
ctrl.Start()
|
||||||
c.ShutdownBlock()
|
ctrl.ShutdownBlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|
611
config.go
611
config.go
|
@ -1,611 +0,0 @@
|
||||||
package nebula
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"path/filepath"
|
|
||||||
"regexp"
|
|
||||||
"sort"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/imdario/mergo"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"gopkg.in/yaml.v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
path string
|
|
||||||
files []string
|
|
||||||
Settings map[interface{}]interface{}
|
|
||||||
oldSettings map[interface{}]interface{}
|
|
||||||
callbacks []func(*Config)
|
|
||||||
l *logrus.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewConfig(l *logrus.Logger) *Config {
|
|
||||||
return &Config{
|
|
||||||
Settings: make(map[interface{}]interface{}),
|
|
||||||
l: l,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load will find all yaml files within path and load them in lexical order
|
|
||||||
func (c *Config) Load(path string) error {
|
|
||||||
c.path = path
|
|
||||||
c.files = make([]string, 0)
|
|
||||||
|
|
||||||
err := c.resolve(path, true)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(c.files) == 0 {
|
|
||||||
return fmt.Errorf("no config files found at %s", path)
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Strings(c.files)
|
|
||||||
|
|
||||||
err = c.parse()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) LoadString(raw string) error {
|
|
||||||
if raw == "" {
|
|
||||||
return errors.New("Empty configuration")
|
|
||||||
}
|
|
||||||
return c.parseRaw([]byte(raw))
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
|
|
||||||
// here should decide if they need to make a change to the current process before making the change. HasChanged can be
|
|
||||||
// used to help decide if a change is necessary.
|
|
||||||
// These functions should return quickly or spawn their own go routine if they will take a while
|
|
||||||
func (c *Config) RegisterReloadCallback(f func(*Config)) {
|
|
||||||
c.callbacks = append(c.callbacks, f)
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
|
|
||||||
// k in both the old and new settings will be serialized, the result of the string comparison is returned.
|
|
||||||
// If k is an empty string the entire config is tested.
|
|
||||||
// It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating
|
|
||||||
// there is change when there actually wasn't any.
|
|
||||||
func (c *Config) HasChanged(k string) bool {
|
|
||||||
if c.oldSettings == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
nv interface{}
|
|
||||||
ov interface{}
|
|
||||||
)
|
|
||||||
|
|
||||||
if k == "" {
|
|
||||||
nv = c.Settings
|
|
||||||
ov = c.oldSettings
|
|
||||||
k = "all settings"
|
|
||||||
} else {
|
|
||||||
nv = c.get(k, c.Settings)
|
|
||||||
ov = c.get(k, c.oldSettings)
|
|
||||||
}
|
|
||||||
|
|
||||||
newVals, err := yaml.Marshal(nv)
|
|
||||||
if err != nil {
|
|
||||||
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
|
|
||||||
}
|
|
||||||
|
|
||||||
oldVals, err := yaml.Marshal(ov)
|
|
||||||
if err != nil {
|
|
||||||
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(newVals) != string(oldVals)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
|
|
||||||
// original path provided to Load. The old settings are shallow copied for change detection after the reload.
|
|
||||||
func (c *Config) CatchHUP(ctx context.Context) {
|
|
||||||
ch := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(ch, syscall.SIGHUP)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
signal.Stop(ch)
|
|
||||||
close(ch)
|
|
||||||
return
|
|
||||||
case <-ch:
|
|
||||||
c.l.Info("Caught HUP, reloading config")
|
|
||||||
c.ReloadConfig()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) ReloadConfig() {
|
|
||||||
c.oldSettings = make(map[interface{}]interface{})
|
|
||||||
for k, v := range c.Settings {
|
|
||||||
c.oldSettings[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
err := c.Load(c.path)
|
|
||||||
if err != nil {
|
|
||||||
c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, v := range c.callbacks {
|
|
||||||
v(c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetString will get the string for k or return the default d if not found or invalid
|
|
||||||
func (c *Config) GetString(k, d string) string {
|
|
||||||
r := c.Get(k)
|
|
||||||
if r == nil {
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("%v", r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetStringSlice will get the slice of strings for k or return the default d if not found or invalid
|
|
||||||
func (c *Config) GetStringSlice(k string, d []string) []string {
|
|
||||||
r := c.Get(k)
|
|
||||||
if r == nil {
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
rv, ok := r.([]interface{})
|
|
||||||
if !ok {
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
v := make([]string, len(rv))
|
|
||||||
for i := 0; i < len(v); i++ {
|
|
||||||
v[i] = fmt.Sprintf("%v", rv[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMap will get the map for k or return the default d if not found or invalid
|
|
||||||
func (c *Config) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
|
|
||||||
r := c.Get(k)
|
|
||||||
if r == nil {
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
v, ok := r.(map[interface{}]interface{})
|
|
||||||
if !ok {
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetInt will get the int for k or return the default d if not found or invalid
|
|
||||||
func (c *Config) GetInt(k string, d int) int {
|
|
||||||
r := c.GetString(k, strconv.Itoa(d))
|
|
||||||
v, err := strconv.Atoi(r)
|
|
||||||
if err != nil {
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetBool will get the bool for k or return the default d if not found or invalid
|
|
||||||
func (c *Config) GetBool(k string, d bool) bool {
|
|
||||||
r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
|
|
||||||
v, err := strconv.ParseBool(r)
|
|
||||||
if err != nil {
|
|
||||||
switch r {
|
|
||||||
case "y", "yes":
|
|
||||||
return true
|
|
||||||
case "n", "no":
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDuration will get the duration for k or return the default d if not found or invalid
|
|
||||||
func (c *Config) GetDuration(k string, d time.Duration) time.Duration {
|
|
||||||
r := c.GetString(k, "")
|
|
||||||
v, err := time.ParseDuration(r)
|
|
||||||
if err != nil {
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) GetLocalAllowList(k string) (*LocalAllowList, error) {
|
|
||||||
var nameRules []AllowListNameRule
|
|
||||||
handleKey := func(key string, value interface{}) (bool, error) {
|
|
||||||
if key == "interfaces" {
|
|
||||||
var err error
|
|
||||||
nameRules, err = c.getAllowListInterfaces(k, value)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
al, err := c.GetAllowList(k, handleKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &LocalAllowList{AllowList: al, nameRules: nameRules}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) GetRemoteAllowList(k, rangesKey string) (*RemoteAllowList, error) {
|
|
||||||
al, err := c.GetAllowList(k, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
remoteAllowRanges, err := c.getRemoteAllowRanges(rangesKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &RemoteAllowList{AllowList: al, insideAllowLists: remoteAllowRanges}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) getRemoteAllowRanges(k string) (*CIDR6Tree, error) {
|
|
||||||
value := c.Get(k)
|
|
||||||
if value == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteAllowRanges := NewCIDR6Tree()
|
|
||||||
|
|
||||||
rawMap, ok := value.(map[interface{}]interface{})
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
|
|
||||||
}
|
|
||||||
for rawKey, rawValue := range rawMap {
|
|
||||||
rawCIDR, ok := rawKey.(string)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
allowList, err := c.getAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, cidr, err := net.ParseCIDR(rawCIDR)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteAllowRanges.AddCIDR(cidr, allowList)
|
|
||||||
}
|
|
||||||
|
|
||||||
return remoteAllowRanges, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the handleKey func returns true, the rest of the parsing is skipped
|
|
||||||
// for this key. This allows parsing of special values like `interfaces`.
|
|
||||||
func (c *Config) GetAllowList(k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
|
|
||||||
r := c.Get(k)
|
|
||||||
if r == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.getAllowList(k, r, handleKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the handleKey func returns true, the rest of the parsing is skipped
|
|
||||||
// for this key. This allows parsing of special values like `interfaces`.
|
|
||||||
func (c *Config) getAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
|
|
||||||
rawMap, ok := raw.(map[interface{}]interface{})
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
|
|
||||||
}
|
|
||||||
|
|
||||||
tree := NewCIDR6Tree()
|
|
||||||
|
|
||||||
// Keep track of the rules we have added for both ipv4 and ipv6
|
|
||||||
type allowListRules struct {
|
|
||||||
firstValue bool
|
|
||||||
allValuesMatch bool
|
|
||||||
defaultSet bool
|
|
||||||
allValues bool
|
|
||||||
}
|
|
||||||
rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
|
|
||||||
rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
|
|
||||||
|
|
||||||
for rawKey, rawValue := range rawMap {
|
|
||||||
rawCIDR, ok := rawKey.(string)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
if handleKey != nil {
|
|
||||||
handled, err := handleKey(rawCIDR, rawValue)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if handled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
value, ok := rawValue.(bool)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, cidr, err := net.ParseCIDR(rawCIDR)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: should we error on duplicate CIDRs in the config?
|
|
||||||
tree.AddCIDR(cidr, value)
|
|
||||||
|
|
||||||
maskBits, maskSize := cidr.Mask.Size()
|
|
||||||
|
|
||||||
var rules *allowListRules
|
|
||||||
if maskSize == 32 {
|
|
||||||
rules = &rules4
|
|
||||||
} else {
|
|
||||||
rules = &rules6
|
|
||||||
}
|
|
||||||
|
|
||||||
if rules.firstValue {
|
|
||||||
rules.allValues = value
|
|
||||||
rules.firstValue = false
|
|
||||||
} else {
|
|
||||||
if value != rules.allValues {
|
|
||||||
rules.allValuesMatch = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if this is 0.0.0.0/0 or ::/0
|
|
||||||
if maskBits == 0 {
|
|
||||||
rules.defaultSet = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !rules4.defaultSet {
|
|
||||||
if rules4.allValuesMatch {
|
|
||||||
_, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
|
|
||||||
tree.AddCIDR(zeroCIDR, !rules4.allValues)
|
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !rules6.defaultSet {
|
|
||||||
if rules6.allValuesMatch {
|
|
||||||
_, zeroCIDR, _ := net.ParseCIDR("::/0")
|
|
||||||
tree.AddCIDR(zeroCIDR, !rules6.allValues)
|
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &AllowList{cidrTree: tree}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
|
|
||||||
var nameRules []AllowListNameRule
|
|
||||||
|
|
||||||
rawRules, ok := v.(map[interface{}]interface{})
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
firstEntry := true
|
|
||||||
var allValues bool
|
|
||||||
for rawName, rawAllow := range rawRules {
|
|
||||||
name, ok := rawName.(string)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
|
|
||||||
}
|
|
||||||
allow, ok := rawAllow.(bool)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
|
|
||||||
}
|
|
||||||
|
|
||||||
nameRE, err := regexp.Compile("^" + name + "$")
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
nameRules = append(nameRules, AllowListNameRule{
|
|
||||||
Name: nameRE,
|
|
||||||
Allow: allow,
|
|
||||||
})
|
|
||||||
|
|
||||||
if firstEntry {
|
|
||||||
allValues = allow
|
|
||||||
firstEntry = false
|
|
||||||
} else {
|
|
||||||
if allow != allValues {
|
|
||||||
return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nameRules, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) Get(k string) interface{} {
|
|
||||||
return c.get(k, c.Settings)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) IsSet(k string) bool {
|
|
||||||
return c.get(k, c.Settings) != nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) get(k string, v interface{}) interface{} {
|
|
||||||
parts := strings.Split(k, ".")
|
|
||||||
for _, p := range parts {
|
|
||||||
m, ok := v.(map[interface{}]interface{})
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
v, ok = m[p]
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
// direct signifies if this is the config path directly specified by the user,
|
|
||||||
// versus a file/dir found by recursing into that path
|
|
||||||
func (c *Config) resolve(path string, direct bool) error {
|
|
||||||
i, err := os.Stat(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !i.IsDir() {
|
|
||||||
c.addFile(path, direct)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
paths, err := readDirNames(path)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("problem while reading directory %s: %s", path, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, p := range paths {
|
|
||||||
err := c.resolve(filepath.Join(path, p), false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) addFile(path string, direct bool) error {
|
|
||||||
ext := filepath.Ext(path)
|
|
||||||
|
|
||||||
if !direct && ext != ".yaml" && ext != ".yml" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ap, err := filepath.Abs(path)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.files = append(c.files, ap)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) parseRaw(b []byte) error {
|
|
||||||
var m map[interface{}]interface{}
|
|
||||||
|
|
||||||
err := yaml.Unmarshal(b, &m)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Settings = m
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) parse() error {
|
|
||||||
var m map[interface{}]interface{}
|
|
||||||
|
|
||||||
for _, path := range c.files {
|
|
||||||
b, err := ioutil.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var nm map[interface{}]interface{}
|
|
||||||
err = yaml.Unmarshal(b, &nm)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// We need to use WithAppendSlice so that firewall rules in separate
|
|
||||||
// files are appended together
|
|
||||||
err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
|
|
||||||
m = nm
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Settings = m
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func readDirNames(path string) ([]string, error) {
|
|
||||||
f, err := os.Open(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
paths, err := f.Readdirnames(-1)
|
|
||||||
f.Close()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Strings(paths)
|
|
||||||
return paths, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func configLogger(c *Config) error {
|
|
||||||
// set up our logging level
|
|
||||||
logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
|
|
||||||
}
|
|
||||||
c.l.SetLevel(logLevel)
|
|
||||||
|
|
||||||
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
|
|
||||||
timestampFormat := c.GetString("logging.timestamp_format", "")
|
|
||||||
fullTimestamp := (timestampFormat != "")
|
|
||||||
if timestampFormat == "" {
|
|
||||||
timestampFormat = time.RFC3339
|
|
||||||
}
|
|
||||||
|
|
||||||
logFormat := strings.ToLower(c.GetString("logging.format", "text"))
|
|
||||||
switch logFormat {
|
|
||||||
case "text":
|
|
||||||
c.l.Formatter = &logrus.TextFormatter{
|
|
||||||
TimestampFormat: timestampFormat,
|
|
||||||
FullTimestamp: fullTimestamp,
|
|
||||||
DisableTimestamp: disableTimestamp,
|
|
||||||
}
|
|
||||||
case "json":
|
|
||||||
c.l.Formatter = &logrus.JSONFormatter{
|
|
||||||
TimestampFormat: timestampFormat,
|
|
||||||
DisableTimestamp: disableTimestamp,
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -0,0 +1,358 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/imdario/mergo"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"gopkg.in/yaml.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type C struct {
|
||||||
|
path string
|
||||||
|
files []string
|
||||||
|
Settings map[interface{}]interface{}
|
||||||
|
oldSettings map[interface{}]interface{}
|
||||||
|
callbacks []func(*C)
|
||||||
|
l *logrus.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewC(l *logrus.Logger) *C {
|
||||||
|
return &C{
|
||||||
|
Settings: make(map[interface{}]interface{}),
|
||||||
|
l: l,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load will find all yaml files within path and load them in lexical order
|
||||||
|
func (c *C) Load(path string) error {
|
||||||
|
c.path = path
|
||||||
|
c.files = make([]string, 0)
|
||||||
|
|
||||||
|
err := c.resolve(path, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c.files) == 0 {
|
||||||
|
return fmt.Errorf("no config files found at %s", path)
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(c.files)
|
||||||
|
|
||||||
|
err = c.parse()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *C) LoadString(raw string) error {
|
||||||
|
if raw == "" {
|
||||||
|
return errors.New("Empty configuration")
|
||||||
|
}
|
||||||
|
return c.parseRaw([]byte(raw))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
|
||||||
|
// here should decide if they need to make a change to the current process before making the change. HasChanged can be
|
||||||
|
// used to help decide if a change is necessary.
|
||||||
|
// These functions should return quickly or spawn their own go routine if they will take a while
|
||||||
|
func (c *C) RegisterReloadCallback(f func(*C)) {
|
||||||
|
c.callbacks = append(c.callbacks, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
|
||||||
|
// k in both the old and new settings will be serialized, the result of the string comparison is returned.
|
||||||
|
// If k is an empty string the entire config is tested.
|
||||||
|
// It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating
|
||||||
|
// there is change when there actually wasn't any.
|
||||||
|
func (c *C) HasChanged(k string) bool {
|
||||||
|
if c.oldSettings == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
nv interface{}
|
||||||
|
ov interface{}
|
||||||
|
)
|
||||||
|
|
||||||
|
if k == "" {
|
||||||
|
nv = c.Settings
|
||||||
|
ov = c.oldSettings
|
||||||
|
k = "all settings"
|
||||||
|
} else {
|
||||||
|
nv = c.get(k, c.Settings)
|
||||||
|
ov = c.get(k, c.oldSettings)
|
||||||
|
}
|
||||||
|
|
||||||
|
newVals, err := yaml.Marshal(nv)
|
||||||
|
if err != nil {
|
||||||
|
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
|
||||||
|
}
|
||||||
|
|
||||||
|
oldVals, err := yaml.Marshal(ov)
|
||||||
|
if err != nil {
|
||||||
|
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(newVals) != string(oldVals)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
|
||||||
|
// original path provided to Load. The old settings are shallow copied for change detection after the reload.
|
||||||
|
func (c *C) CatchHUP(ctx context.Context) {
|
||||||
|
ch := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(ch, syscall.SIGHUP)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
signal.Stop(ch)
|
||||||
|
close(ch)
|
||||||
|
return
|
||||||
|
case <-ch:
|
||||||
|
c.l.Info("Caught HUP, reloading config")
|
||||||
|
c.ReloadConfig()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *C) ReloadConfig() {
|
||||||
|
c.oldSettings = make(map[interface{}]interface{})
|
||||||
|
for k, v := range c.Settings {
|
||||||
|
c.oldSettings[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.Load(c.path)
|
||||||
|
if err != nil {
|
||||||
|
c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range c.callbacks {
|
||||||
|
v(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetString will get the string for k or return the default d if not found or invalid
|
||||||
|
func (c *C) GetString(k, d string) string {
|
||||||
|
r := c.Get(k)
|
||||||
|
if r == nil {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%v", r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStringSlice will get the slice of strings for k or return the default d if not found or invalid
|
||||||
|
func (c *C) GetStringSlice(k string, d []string) []string {
|
||||||
|
r := c.Get(k)
|
||||||
|
if r == nil {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
rv, ok := r.([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
v := make([]string, len(rv))
|
||||||
|
for i := 0; i < len(v); i++ {
|
||||||
|
v[i] = fmt.Sprintf("%v", rv[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMap will get the map for k or return the default d if not found or invalid
|
||||||
|
func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
|
||||||
|
r := c.Get(k)
|
||||||
|
if r == nil {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
v, ok := r.(map[interface{}]interface{})
|
||||||
|
if !ok {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetInt will get the int for k or return the default d if not found or invalid
|
||||||
|
func (c *C) GetInt(k string, d int) int {
|
||||||
|
r := c.GetString(k, strconv.Itoa(d))
|
||||||
|
v, err := strconv.Atoi(r)
|
||||||
|
if err != nil {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBool will get the bool for k or return the default d if not found or invalid
|
||||||
|
func (c *C) GetBool(k string, d bool) bool {
|
||||||
|
r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
|
||||||
|
v, err := strconv.ParseBool(r)
|
||||||
|
if err != nil {
|
||||||
|
switch r {
|
||||||
|
case "y", "yes":
|
||||||
|
return true
|
||||||
|
case "n", "no":
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDuration will get the duration for k or return the default d if not found or invalid
|
||||||
|
func (c *C) GetDuration(k string, d time.Duration) time.Duration {
|
||||||
|
r := c.GetString(k, "")
|
||||||
|
v, err := time.ParseDuration(r)
|
||||||
|
if err != nil {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *C) Get(k string) interface{} {
|
||||||
|
return c.get(k, c.Settings)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *C) IsSet(k string) bool {
|
||||||
|
return c.get(k, c.Settings) != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *C) get(k string, v interface{}) interface{} {
|
||||||
|
parts := strings.Split(k, ".")
|
||||||
|
for _, p := range parts {
|
||||||
|
m, ok := v.(map[interface{}]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
v, ok = m[p]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// direct signifies if this is the config path directly specified by the user,
|
||||||
|
// versus a file/dir found by recursing into that path
|
||||||
|
func (c *C) resolve(path string, direct bool) error {
|
||||||
|
i, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !i.IsDir() {
|
||||||
|
c.addFile(path, direct)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
paths, err := readDirNames(path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("problem while reading directory %s: %s", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range paths {
|
||||||
|
err := c.resolve(filepath.Join(path, p), false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *C) addFile(path string, direct bool) error {
|
||||||
|
ext := filepath.Ext(path)
|
||||||
|
|
||||||
|
if !direct && ext != ".yaml" && ext != ".yml" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ap, err := filepath.Abs(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.files = append(c.files, ap)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *C) parseRaw(b []byte) error {
|
||||||
|
var m map[interface{}]interface{}
|
||||||
|
|
||||||
|
err := yaml.Unmarshal(b, &m)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Settings = m
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *C) parse() error {
|
||||||
|
var m map[interface{}]interface{}
|
||||||
|
|
||||||
|
for _, path := range c.files {
|
||||||
|
b, err := ioutil.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var nm map[interface{}]interface{}
|
||||||
|
err = yaml.Unmarshal(b, &nm)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// We need to use WithAppendSlice so that firewall rules in separate
|
||||||
|
// files are appended together
|
||||||
|
err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
|
||||||
|
m = nm
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Settings = m
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readDirNames(path string) ([]string, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
paths, err := f.Readdirnames(-1)
|
||||||
|
f.Close()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(paths)
|
||||||
|
return paths, nil
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package nebula
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
@ -7,19 +7,20 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConfig_Load(t *testing.T) {
|
func TestConfig_Load(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
dir, err := ioutil.TempDir("", "config-test")
|
dir, err := ioutil.TempDir("", "config-test")
|
||||||
// invalid yaml
|
// invalid yaml
|
||||||
c := NewConfig(l)
|
c := NewC(l)
|
||||||
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
|
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
|
||||||
assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
|
assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
|
||||||
|
|
||||||
// simple multi config merge
|
// simple multi config merge
|
||||||
c = NewConfig(l)
|
c = NewC(l)
|
||||||
os.RemoveAll(dir)
|
os.RemoveAll(dir)
|
||||||
os.Mkdir(dir, 0755)
|
os.Mkdir(dir, 0755)
|
||||||
|
|
||||||
|
@ -41,9 +42,9 @@ func TestConfig_Load(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_Get(t *testing.T) {
|
func TestConfig_Get(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
// test simple type
|
// test simple type
|
||||||
c := NewConfig(l)
|
c := NewC(l)
|
||||||
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
|
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
|
||||||
assert.Equal(t, "hi", c.Get("firewall.outbound"))
|
assert.Equal(t, "hi", c.Get("firewall.outbound"))
|
||||||
|
|
||||||
|
@ -57,15 +58,15 @@ func TestConfig_Get(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_GetStringSlice(t *testing.T) {
|
func TestConfig_GetStringSlice(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
c := NewConfig(l)
|
c := NewC(l)
|
||||||
c.Settings["slice"] = []interface{}{"one", "two"}
|
c.Settings["slice"] = []interface{}{"one", "two"}
|
||||||
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
|
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_GetBool(t *testing.T) {
|
func TestConfig_GetBool(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
c := NewConfig(l)
|
c := NewC(l)
|
||||||
c.Settings["bool"] = true
|
c.Settings["bool"] = true
|
||||||
assert.Equal(t, true, c.GetBool("bool", false))
|
assert.Equal(t, true, c.GetBool("bool", false))
|
||||||
|
|
||||||
|
@ -91,108 +92,22 @@ func TestConfig_GetBool(t *testing.T) {
|
||||||
assert.Equal(t, false, c.GetBool("bool", true))
|
assert.Equal(t, false, c.GetBool("bool", true))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_GetAllowList(t *testing.T) {
|
|
||||||
l := NewTestLogger()
|
|
||||||
c := NewConfig(l)
|
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
|
||||||
"192.168.0.0": true,
|
|
||||||
}
|
|
||||||
r, err := c.GetAllowList("allowlist", nil)
|
|
||||||
assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0")
|
|
||||||
assert.Nil(t, r)
|
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
|
||||||
"192.168.0.0/16": "abc",
|
|
||||||
}
|
|
||||||
r, err = c.GetAllowList("allowlist", nil)
|
|
||||||
assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
|
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
|
||||||
"192.168.0.0/16": true,
|
|
||||||
"10.0.0.0/8": false,
|
|
||||||
}
|
|
||||||
r, err = c.GetAllowList("allowlist", nil)
|
|
||||||
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
|
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
|
||||||
"0.0.0.0/0": true,
|
|
||||||
"10.0.0.0/8": false,
|
|
||||||
"10.42.42.0/24": true,
|
|
||||||
"fd00::/8": true,
|
|
||||||
"fd00:fd00::/16": false,
|
|
||||||
}
|
|
||||||
r, err = c.GetAllowList("allowlist", nil)
|
|
||||||
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
|
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
|
||||||
"0.0.0.0/0": true,
|
|
||||||
"10.0.0.0/8": false,
|
|
||||||
"10.42.42.0/24": true,
|
|
||||||
}
|
|
||||||
r, err = c.GetAllowList("allowlist", nil)
|
|
||||||
if assert.NoError(t, err) {
|
|
||||||
assert.NotNil(t, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
|
||||||
"0.0.0.0/0": true,
|
|
||||||
"10.0.0.0/8": false,
|
|
||||||
"10.42.42.0/24": true,
|
|
||||||
"::/0": false,
|
|
||||||
"fd00::/8": true,
|
|
||||||
"fd00:fd00::/16": false,
|
|
||||||
}
|
|
||||||
r, err = c.GetAllowList("allowlist", nil)
|
|
||||||
if assert.NoError(t, err) {
|
|
||||||
assert.NotNil(t, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test interface names
|
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
|
||||||
"interfaces": map[interface{}]interface{}{
|
|
||||||
`docker.*`: "foo",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
lr, err := c.GetLocalAllowList("allowlist")
|
|
||||||
assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
|
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
|
||||||
"interfaces": map[interface{}]interface{}{
|
|
||||||
`docker.*`: false,
|
|
||||||
`eth.*`: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
lr, err = c.GetLocalAllowList("allowlist")
|
|
||||||
assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
|
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
|
||||||
"interfaces": map[interface{}]interface{}{
|
|
||||||
`docker.*`: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
lr, err = c.GetLocalAllowList("allowlist")
|
|
||||||
if assert.NoError(t, err) {
|
|
||||||
assert.NotNil(t, lr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConfig_HasChanged(t *testing.T) {
|
func TestConfig_HasChanged(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
// No reload has occurred, return false
|
// No reload has occurred, return false
|
||||||
c := NewConfig(l)
|
c := NewC(l)
|
||||||
c.Settings["test"] = "hi"
|
c.Settings["test"] = "hi"
|
||||||
assert.False(t, c.HasChanged(""))
|
assert.False(t, c.HasChanged(""))
|
||||||
|
|
||||||
// Test key change
|
// Test key change
|
||||||
c = NewConfig(l)
|
c = NewC(l)
|
||||||
c.Settings["test"] = "hi"
|
c.Settings["test"] = "hi"
|
||||||
c.oldSettings = map[interface{}]interface{}{"test": "no"}
|
c.oldSettings = map[interface{}]interface{}{"test": "no"}
|
||||||
assert.True(t, c.HasChanged("test"))
|
assert.True(t, c.HasChanged("test"))
|
||||||
assert.True(t, c.HasChanged(""))
|
assert.True(t, c.HasChanged(""))
|
||||||
|
|
||||||
// No key change
|
// No key change
|
||||||
c = NewConfig(l)
|
c = NewC(l)
|
||||||
c.Settings["test"] = "hi"
|
c.Settings["test"] = "hi"
|
||||||
c.oldSettings = map[interface{}]interface{}{"test": "hi"}
|
c.oldSettings = map[interface{}]interface{}{"test": "hi"}
|
||||||
assert.False(t, c.HasChanged("test"))
|
assert.False(t, c.HasChanged("test"))
|
||||||
|
@ -200,13 +115,13 @@ func TestConfig_HasChanged(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_ReloadConfig(t *testing.T) {
|
func TestConfig_ReloadConfig(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
done := make(chan bool, 1)
|
done := make(chan bool, 1)
|
||||||
dir, err := ioutil.TempDir("", "config-test")
|
dir, err := ioutil.TempDir("", "config-test")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
|
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
|
||||||
|
|
||||||
c := NewConfig(l)
|
c := NewC(l)
|
||||||
assert.Nil(t, c.Load(dir))
|
assert.Nil(t, c.Load(dir))
|
||||||
|
|
||||||
assert.False(t, c.HasChanged("outer.inner"))
|
assert.False(t, c.HasChanged("outer.inner"))
|
||||||
|
@ -215,7 +130,7 @@ func TestConfig_ReloadConfig(t *testing.T) {
|
||||||
|
|
||||||
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: ho"), 0644)
|
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: ho"), 0644)
|
||||||
|
|
||||||
c.RegisterReloadCallback(func(c *Config) {
|
c.RegisterReloadCallback(func(c *C) {
|
||||||
done <- true
|
done <- true
|
||||||
})
|
})
|
||||||
|
|
|
@ -6,6 +6,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet
|
// TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet
|
||||||
|
@ -13,16 +15,16 @@ import (
|
||||||
|
|
||||||
type connectionManager struct {
|
type connectionManager struct {
|
||||||
hostMap *HostMap
|
hostMap *HostMap
|
||||||
in map[uint32]struct{}
|
in map[iputil.VpnIp]struct{}
|
||||||
inLock *sync.RWMutex
|
inLock *sync.RWMutex
|
||||||
inCount int
|
inCount int
|
||||||
out map[uint32]struct{}
|
out map[iputil.VpnIp]struct{}
|
||||||
outLock *sync.RWMutex
|
outLock *sync.RWMutex
|
||||||
outCount int
|
outCount int
|
||||||
TrafficTimer *SystemTimerWheel
|
TrafficTimer *SystemTimerWheel
|
||||||
intf *Interface
|
intf *Interface
|
||||||
|
|
||||||
pendingDeletion map[uint32]int
|
pendingDeletion map[iputil.VpnIp]int
|
||||||
pendingDeletionLock *sync.RWMutex
|
pendingDeletionLock *sync.RWMutex
|
||||||
pendingDeletionTimer *SystemTimerWheel
|
pendingDeletionTimer *SystemTimerWheel
|
||||||
|
|
||||||
|
@ -36,15 +38,15 @@ type connectionManager struct {
|
||||||
func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
|
func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
|
||||||
nc := &connectionManager{
|
nc := &connectionManager{
|
||||||
hostMap: intf.hostMap,
|
hostMap: intf.hostMap,
|
||||||
in: make(map[uint32]struct{}),
|
in: make(map[iputil.VpnIp]struct{}),
|
||||||
inLock: &sync.RWMutex{},
|
inLock: &sync.RWMutex{},
|
||||||
inCount: 0,
|
inCount: 0,
|
||||||
out: make(map[uint32]struct{}),
|
out: make(map[iputil.VpnIp]struct{}),
|
||||||
outLock: &sync.RWMutex{},
|
outLock: &sync.RWMutex{},
|
||||||
outCount: 0,
|
outCount: 0,
|
||||||
TrafficTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
|
TrafficTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
|
||||||
intf: intf,
|
intf: intf,
|
||||||
pendingDeletion: make(map[uint32]int),
|
pendingDeletion: make(map[iputil.VpnIp]int),
|
||||||
pendingDeletionLock: &sync.RWMutex{},
|
pendingDeletionLock: &sync.RWMutex{},
|
||||||
pendingDeletionTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
|
pendingDeletionTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
|
||||||
checkInterval: checkInterval,
|
checkInterval: checkInterval,
|
||||||
|
@ -55,7 +57,7 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface
|
||||||
return nc
|
return nc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) In(ip uint32) {
|
func (n *connectionManager) In(ip iputil.VpnIp) {
|
||||||
n.inLock.RLock()
|
n.inLock.RLock()
|
||||||
// If this already exists, return
|
// If this already exists, return
|
||||||
if _, ok := n.in[ip]; ok {
|
if _, ok := n.in[ip]; ok {
|
||||||
|
@ -68,7 +70,7 @@ func (n *connectionManager) In(ip uint32) {
|
||||||
n.inLock.Unlock()
|
n.inLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) Out(ip uint32) {
|
func (n *connectionManager) Out(ip iputil.VpnIp) {
|
||||||
n.outLock.RLock()
|
n.outLock.RLock()
|
||||||
// If this already exists, return
|
// If this already exists, return
|
||||||
if _, ok := n.out[ip]; ok {
|
if _, ok := n.out[ip]; ok {
|
||||||
|
@ -87,9 +89,9 @@ func (n *connectionManager) Out(ip uint32) {
|
||||||
n.outLock.Unlock()
|
n.outLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) CheckIn(vpnIP uint32) bool {
|
func (n *connectionManager) CheckIn(vpnIp iputil.VpnIp) bool {
|
||||||
n.inLock.RLock()
|
n.inLock.RLock()
|
||||||
if _, ok := n.in[vpnIP]; ok {
|
if _, ok := n.in[vpnIp]; ok {
|
||||||
n.inLock.RUnlock()
|
n.inLock.RUnlock()
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -97,7 +99,7 @@ func (n *connectionManager) CheckIn(vpnIP uint32) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) ClearIP(ip uint32) {
|
func (n *connectionManager) ClearIP(ip iputil.VpnIp) {
|
||||||
n.inLock.Lock()
|
n.inLock.Lock()
|
||||||
n.outLock.Lock()
|
n.outLock.Lock()
|
||||||
delete(n.in, ip)
|
delete(n.in, ip)
|
||||||
|
@ -106,13 +108,13 @@ func (n *connectionManager) ClearIP(ip uint32) {
|
||||||
n.outLock.Unlock()
|
n.outLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) ClearPendingDeletion(ip uint32) {
|
func (n *connectionManager) ClearPendingDeletion(ip iputil.VpnIp) {
|
||||||
n.pendingDeletionLock.Lock()
|
n.pendingDeletionLock.Lock()
|
||||||
delete(n.pendingDeletion, ip)
|
delete(n.pendingDeletion, ip)
|
||||||
n.pendingDeletionLock.Unlock()
|
n.pendingDeletionLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) AddPendingDeletion(ip uint32) {
|
func (n *connectionManager) AddPendingDeletion(ip iputil.VpnIp) {
|
||||||
n.pendingDeletionLock.Lock()
|
n.pendingDeletionLock.Lock()
|
||||||
if _, ok := n.pendingDeletion[ip]; ok {
|
if _, ok := n.pendingDeletion[ip]; ok {
|
||||||
n.pendingDeletion[ip] += 1
|
n.pendingDeletion[ip] += 1
|
||||||
|
@ -123,7 +125,7 @@ func (n *connectionManager) AddPendingDeletion(ip uint32) {
|
||||||
n.pendingDeletionLock.Unlock()
|
n.pendingDeletionLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) checkPendingDeletion(ip uint32) bool {
|
func (n *connectionManager) checkPendingDeletion(ip iputil.VpnIp) bool {
|
||||||
n.pendingDeletionLock.RLock()
|
n.pendingDeletionLock.RLock()
|
||||||
if _, ok := n.pendingDeletion[ip]; ok {
|
if _, ok := n.pendingDeletion[ip]; ok {
|
||||||
|
|
||||||
|
@ -134,8 +136,8 @@ func (n *connectionManager) checkPendingDeletion(ip uint32) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) AddTrafficWatch(vpnIP uint32, seconds int) {
|
func (n *connectionManager) AddTrafficWatch(vpnIp iputil.VpnIp, seconds int) {
|
||||||
n.TrafficTimer.Add(vpnIP, time.Second*time.Duration(seconds))
|
n.TrafficTimer.Add(vpnIp, time.Second*time.Duration(seconds))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) Start(ctx context.Context) {
|
func (n *connectionManager) Start(ctx context.Context) {
|
||||||
|
@ -169,23 +171,23 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnIP := ep.(uint32)
|
vpnIp := ep.(iputil.VpnIp)
|
||||||
|
|
||||||
// Check for traffic coming back in from this host.
|
// Check for traffic coming back in from this host.
|
||||||
traf := n.CheckIn(vpnIP)
|
traf := n.CheckIn(vpnIp)
|
||||||
|
|
||||||
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
|
hostinfo, err := n.hostMap.QueryVpnIp(vpnIp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
|
n.l.Debugf("Not found in hostmap: %s", vpnIp)
|
||||||
|
|
||||||
if !n.intf.disconnectInvalid {
|
if !n.intf.disconnectInvalid {
|
||||||
n.ClearIP(vpnIP)
|
n.ClearIP(vpnIp)
|
||||||
n.ClearPendingDeletion(vpnIP)
|
n.ClearPendingDeletion(vpnIp)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if n.handleInvalidCertificate(now, vpnIP, hostinfo) {
|
if n.handleInvalidCertificate(now, vpnIp, hostinfo) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -193,12 +195,12 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
|
||||||
// expired, just ignore.
|
// expired, just ignore.
|
||||||
if traf {
|
if traf {
|
||||||
if n.l.Level >= logrus.DebugLevel {
|
if n.l.Level >= logrus.DebugLevel {
|
||||||
n.l.WithField("vpnIp", IntIp(vpnIP)).
|
n.l.WithField("vpnIp", vpnIp).
|
||||||
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
|
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
|
||||||
Debug("Tunnel status")
|
Debug("Tunnel status")
|
||||||
}
|
}
|
||||||
n.ClearIP(vpnIP)
|
n.ClearIP(vpnIp)
|
||||||
n.ClearPendingDeletion(vpnIP)
|
n.ClearPendingDeletion(vpnIp)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -208,12 +210,12 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
|
||||||
|
|
||||||
if hostinfo != nil && hostinfo.ConnectionState != nil {
|
if hostinfo != nil && hostinfo.ConnectionState != nil {
|
||||||
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
||||||
n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, p, nb, out)
|
n.intf.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, p, nb, out)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
hostinfo.logger(n.l).Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
|
hostinfo.logger(n.l).Debugf("Hostinfo sadness: %s", vpnIp)
|
||||||
}
|
}
|
||||||
n.AddPendingDeletion(vpnIP)
|
n.AddPendingDeletion(vpnIp)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -226,38 +228,38 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnIP := ep.(uint32)
|
vpnIp := ep.(iputil.VpnIp)
|
||||||
|
|
||||||
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
|
hostinfo, err := n.hostMap.QueryVpnIp(vpnIp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
|
n.l.Debugf("Not found in hostmap: %s", vpnIp)
|
||||||
|
|
||||||
if !n.intf.disconnectInvalid {
|
if !n.intf.disconnectInvalid {
|
||||||
n.ClearIP(vpnIP)
|
n.ClearIP(vpnIp)
|
||||||
n.ClearPendingDeletion(vpnIP)
|
n.ClearPendingDeletion(vpnIp)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if n.handleInvalidCertificate(now, vpnIP, hostinfo) {
|
if n.handleInvalidCertificate(now, vpnIp, hostinfo) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we saw an incoming packets from this ip and peer's certificate is not
|
// If we saw an incoming packets from this ip and peer's certificate is not
|
||||||
// expired, just ignore.
|
// expired, just ignore.
|
||||||
traf := n.CheckIn(vpnIP)
|
traf := n.CheckIn(vpnIp)
|
||||||
if traf {
|
if traf {
|
||||||
n.l.WithField("vpnIp", IntIp(vpnIP)).
|
n.l.WithField("vpnIp", vpnIp).
|
||||||
WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
|
WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
|
||||||
Debug("Tunnel status")
|
Debug("Tunnel status")
|
||||||
|
|
||||||
n.ClearIP(vpnIP)
|
n.ClearIP(vpnIp)
|
||||||
n.ClearPendingDeletion(vpnIP)
|
n.ClearPendingDeletion(vpnIp)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// If it comes around on deletion wheel and hasn't resolved itself, delete
|
// If it comes around on deletion wheel and hasn't resolved itself, delete
|
||||||
if n.checkPendingDeletion(vpnIP) {
|
if n.checkPendingDeletion(vpnIp) {
|
||||||
cn := ""
|
cn := ""
|
||||||
if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
|
if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
|
||||||
cn = hostinfo.ConnectionState.peerCert.Details.Name
|
cn = hostinfo.ConnectionState.peerCert.Details.Name
|
||||||
|
@ -267,22 +269,22 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
|
||||||
WithField("certName", cn).
|
WithField("certName", cn).
|
||||||
Info("Tunnel status")
|
Info("Tunnel status")
|
||||||
|
|
||||||
n.ClearIP(vpnIP)
|
n.ClearIP(vpnIp)
|
||||||
n.ClearPendingDeletion(vpnIP)
|
n.ClearPendingDeletion(vpnIp)
|
||||||
// TODO: This is only here to let tests work. Should do proper mocking
|
// TODO: This is only here to let tests work. Should do proper mocking
|
||||||
if n.intf.lightHouse != nil {
|
if n.intf.lightHouse != nil {
|
||||||
n.intf.lightHouse.DeleteVpnIP(vpnIP)
|
n.intf.lightHouse.DeleteVpnIp(vpnIp)
|
||||||
}
|
}
|
||||||
n.hostMap.DeleteHostInfo(hostinfo)
|
n.hostMap.DeleteHostInfo(hostinfo)
|
||||||
} else {
|
} else {
|
||||||
n.ClearIP(vpnIP)
|
n.ClearIP(vpnIp)
|
||||||
n.ClearPendingDeletion(vpnIP)
|
n.ClearPendingDeletion(vpnIp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
|
// handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
|
||||||
func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32, hostinfo *HostInfo) bool {
|
func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIp iputil.VpnIp, hostinfo *HostInfo) bool {
|
||||||
if !n.intf.disconnectInvalid {
|
if !n.intf.disconnectInvalid {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -298,7 +300,7 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
fingerprint, _ := remoteCert.Sha256Sum()
|
fingerprint, _ := remoteCert.Sha256Sum()
|
||||||
n.l.WithField("vpnIp", IntIp(vpnIP)).WithError(err).
|
n.l.WithField("vpnIp", vpnIp).WithError(err).
|
||||||
WithField("certName", remoteCert.Details.Name).
|
WithField("certName", remoteCert.Details.Name).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||||
|
@ -307,7 +309,7 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32
|
||||||
n.intf.sendCloseTunnel(hostinfo)
|
n.intf.sendCloseTunnel(hostinfo)
|
||||||
n.intf.closeTunnel(hostinfo, false)
|
n.intf.closeTunnel(hostinfo, false)
|
||||||
|
|
||||||
n.ClearIP(vpnIP)
|
n.ClearIP(vpnIp)
|
||||||
n.ClearPendingDeletion(vpnIP)
|
n.ClearPendingDeletion(vpnIp)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,17 +10,20 @@ import (
|
||||||
|
|
||||||
"github.com/flynn/noise"
|
"github.com/flynn/noise"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
var vpnIP uint32
|
var vpnIp iputil.VpnIp
|
||||||
|
|
||||||
func Test_NewConnectionManagerTest(t *testing.T) {
|
func Test_NewConnectionManagerTest(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
vpnIP = ip2int(net.ParseIP("172.1.1.2"))
|
vpnIp = iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
|
||||||
preferredRanges := []*net.IPNet{localrange}
|
preferredRanges := []*net.IPNet{localrange}
|
||||||
|
|
||||||
// Very incomplete mock objects
|
// Very incomplete mock objects
|
||||||
|
@ -32,15 +35,15 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||||
rawCertificateNoKey: []byte{},
|
rawCertificateNoKey: []byte{},
|
||||||
}
|
}
|
||||||
|
|
||||||
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
|
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &Tun{},
|
inside: &Tun{},
|
||||||
outside: &udpConn{},
|
outside: &udp.Conn{},
|
||||||
certState: cs,
|
certState: cs,
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
|
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
@ -54,16 +57,16 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
nc.HandleMonitorTick(now, p, nb, out)
|
nc.HandleMonitorTick(now, p, nb, out)
|
||||||
// Add an ip we have established a connection w/ to hostmap
|
// Add an ip we have established a connection w/ to hostmap
|
||||||
hostinfo := nc.hostMap.AddVpnIP(vpnIP)
|
hostinfo := nc.hostMap.AddVpnIp(vpnIp)
|
||||||
hostinfo.ConnectionState = &ConnectionState{
|
hostinfo.ConnectionState = &ConnectionState{
|
||||||
certState: cs,
|
certState: cs,
|
||||||
H: &noise.HandshakeState{},
|
H: &noise.HandshakeState{},
|
||||||
}
|
}
|
||||||
|
|
||||||
// We saw traffic out to vpnIP
|
// We saw traffic out to vpnIp
|
||||||
nc.Out(vpnIP)
|
nc.Out(vpnIp)
|
||||||
assert.NotContains(t, nc.pendingDeletion, vpnIP)
|
assert.NotContains(t, nc.pendingDeletion, vpnIp)
|
||||||
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
|
assert.Contains(t, nc.hostMap.Hosts, vpnIp)
|
||||||
// Move ahead 5s. Nothing should happen
|
// Move ahead 5s. Nothing should happen
|
||||||
next_tick := now.Add(5 * time.Second)
|
next_tick := now.Add(5 * time.Second)
|
||||||
nc.HandleMonitorTick(next_tick, p, nb, out)
|
nc.HandleMonitorTick(next_tick, p, nb, out)
|
||||||
|
@ -73,20 +76,20 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||||
nc.HandleMonitorTick(next_tick, p, nb, out)
|
nc.HandleMonitorTick(next_tick, p, nb, out)
|
||||||
nc.HandleDeletionTick(next_tick)
|
nc.HandleDeletionTick(next_tick)
|
||||||
// This host should now be up for deletion
|
// This host should now be up for deletion
|
||||||
assert.Contains(t, nc.pendingDeletion, vpnIP)
|
assert.Contains(t, nc.pendingDeletion, vpnIp)
|
||||||
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
|
assert.Contains(t, nc.hostMap.Hosts, vpnIp)
|
||||||
// Move ahead some more
|
// Move ahead some more
|
||||||
next_tick = now.Add(45 * time.Second)
|
next_tick = now.Add(45 * time.Second)
|
||||||
nc.HandleMonitorTick(next_tick, p, nb, out)
|
nc.HandleMonitorTick(next_tick, p, nb, out)
|
||||||
nc.HandleDeletionTick(next_tick)
|
nc.HandleDeletionTick(next_tick)
|
||||||
// The host should be evicted
|
// The host should be evicted
|
||||||
assert.NotContains(t, nc.pendingDeletion, vpnIP)
|
assert.NotContains(t, nc.pendingDeletion, vpnIp)
|
||||||
assert.NotContains(t, nc.hostMap.Hosts, vpnIP)
|
assert.NotContains(t, nc.hostMap.Hosts, vpnIp)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_NewConnectionManagerTest2(t *testing.T) {
|
func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
|
@ -101,15 +104,15 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||||
rawCertificateNoKey: []byte{},
|
rawCertificateNoKey: []byte{},
|
||||||
}
|
}
|
||||||
|
|
||||||
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
|
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &Tun{},
|
inside: &Tun{},
|
||||||
outside: &udpConn{},
|
outside: &udp.Conn{},
|
||||||
certState: cs,
|
certState: cs,
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
|
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
@ -123,16 +126,16 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
nc.HandleMonitorTick(now, p, nb, out)
|
nc.HandleMonitorTick(now, p, nb, out)
|
||||||
// Add an ip we have established a connection w/ to hostmap
|
// Add an ip we have established a connection w/ to hostmap
|
||||||
hostinfo := nc.hostMap.AddVpnIP(vpnIP)
|
hostinfo := nc.hostMap.AddVpnIp(vpnIp)
|
||||||
hostinfo.ConnectionState = &ConnectionState{
|
hostinfo.ConnectionState = &ConnectionState{
|
||||||
certState: cs,
|
certState: cs,
|
||||||
H: &noise.HandshakeState{},
|
H: &noise.HandshakeState{},
|
||||||
}
|
}
|
||||||
|
|
||||||
// We saw traffic out to vpnIP
|
// We saw traffic out to vpnIp
|
||||||
nc.Out(vpnIP)
|
nc.Out(vpnIp)
|
||||||
assert.NotContains(t, nc.pendingDeletion, vpnIP)
|
assert.NotContains(t, nc.pendingDeletion, vpnIp)
|
||||||
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
|
assert.Contains(t, nc.hostMap.Hosts, vpnIp)
|
||||||
// Move ahead 5s. Nothing should happen
|
// Move ahead 5s. Nothing should happen
|
||||||
next_tick := now.Add(5 * time.Second)
|
next_tick := now.Add(5 * time.Second)
|
||||||
nc.HandleMonitorTick(next_tick, p, nb, out)
|
nc.HandleMonitorTick(next_tick, p, nb, out)
|
||||||
|
@ -142,17 +145,17 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||||
nc.HandleMonitorTick(next_tick, p, nb, out)
|
nc.HandleMonitorTick(next_tick, p, nb, out)
|
||||||
nc.HandleDeletionTick(next_tick)
|
nc.HandleDeletionTick(next_tick)
|
||||||
// This host should now be up for deletion
|
// This host should now be up for deletion
|
||||||
assert.Contains(t, nc.pendingDeletion, vpnIP)
|
assert.Contains(t, nc.pendingDeletion, vpnIp)
|
||||||
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
|
assert.Contains(t, nc.hostMap.Hosts, vpnIp)
|
||||||
// We heard back this time
|
// We heard back this time
|
||||||
nc.In(vpnIP)
|
nc.In(vpnIp)
|
||||||
// Move ahead some more
|
// Move ahead some more
|
||||||
next_tick = now.Add(45 * time.Second)
|
next_tick = now.Add(45 * time.Second)
|
||||||
nc.HandleMonitorTick(next_tick, p, nb, out)
|
nc.HandleMonitorTick(next_tick, p, nb, out)
|
||||||
nc.HandleDeletionTick(next_tick)
|
nc.HandleDeletionTick(next_tick)
|
||||||
// The host should be evicted
|
// The host should be evicted
|
||||||
assert.NotContains(t, nc.pendingDeletion, vpnIP)
|
assert.NotContains(t, nc.pendingDeletion, vpnIp)
|
||||||
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
|
assert.Contains(t, nc.hostMap.Hosts, vpnIp)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -161,7 +164,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||||
// Disconnect only if disconnectInvalid: true is set.
|
// Disconnect only if disconnectInvalid: true is set.
|
||||||
func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
ipNet := net.IPNet{
|
ipNet := net.IPNet{
|
||||||
IP: net.IPv4(172, 1, 1, 2),
|
IP: net.IPv4(172, 1, 1, 2),
|
||||||
Mask: net.IPMask{255, 255, 255, 0},
|
Mask: net.IPMask{255, 255, 255, 0},
|
||||||
|
@ -210,15 +213,15 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
||||||
rawCertificateNoKey: []byte{},
|
rawCertificateNoKey: []byte{},
|
||||||
}
|
}
|
||||||
|
|
||||||
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
|
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &Tun{},
|
inside: &Tun{},
|
||||||
outside: &udpConn{},
|
outside: &udp.Conn{},
|
||||||
certState: cs,
|
certState: cs,
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
|
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
|
||||||
l: l,
|
l: l,
|
||||||
disconnectInvalid: true,
|
disconnectInvalid: true,
|
||||||
caPool: ncp,
|
caPool: ncp,
|
||||||
|
@ -229,7 +232,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
nc := newConnectionManager(ctx, l, ifce, 5, 10)
|
nc := newConnectionManager(ctx, l, ifce, 5, 10)
|
||||||
ifce.connectionManager = nc
|
ifce.connectionManager = nc
|
||||||
hostinfo := nc.hostMap.AddVpnIP(vpnIP)
|
hostinfo := nc.hostMap.AddVpnIp(vpnIp)
|
||||||
hostinfo.ConnectionState = &ConnectionState{
|
hostinfo.ConnectionState = &ConnectionState{
|
||||||
certState: cs,
|
certState: cs,
|
||||||
peerCert: &peerCert,
|
peerCert: &peerCert,
|
||||||
|
@ -240,13 +243,13 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
||||||
// Check if to disconnect with invalid certificate.
|
// Check if to disconnect with invalid certificate.
|
||||||
// Should be alive.
|
// Should be alive.
|
||||||
nextTick := now.Add(45 * time.Second)
|
nextTick := now.Add(45 * time.Second)
|
||||||
destroyed := nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo)
|
destroyed := nc.handleInvalidCertificate(nextTick, vpnIp, hostinfo)
|
||||||
assert.False(t, destroyed)
|
assert.False(t, destroyed)
|
||||||
|
|
||||||
// Move ahead 61s.
|
// Move ahead 61s.
|
||||||
// Check if to disconnect with invalid certificate.
|
// Check if to disconnect with invalid certificate.
|
||||||
// Should be disconnected.
|
// Should be disconnected.
|
||||||
nextTick = now.Add(61 * time.Second)
|
nextTick = now.Add(61 * time.Second)
|
||||||
destroyed = nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo)
|
destroyed = nc.handleInvalidCertificate(nextTick, vpnIp, hostinfo)
|
||||||
assert.True(t, destroyed)
|
assert.True(t, destroyed)
|
||||||
}
|
}
|
||||||
|
|
33
control.go
33
control.go
|
@ -10,6 +10,9 @@ import (
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
||||||
|
@ -25,14 +28,14 @@ type Control struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type ControlHostInfo struct {
|
type ControlHostInfo struct {
|
||||||
VpnIP net.IP `json:"vpnIp"`
|
VpnIp net.IP `json:"vpnIp"`
|
||||||
LocalIndex uint32 `json:"localIndex"`
|
LocalIndex uint32 `json:"localIndex"`
|
||||||
RemoteIndex uint32 `json:"remoteIndex"`
|
RemoteIndex uint32 `json:"remoteIndex"`
|
||||||
RemoteAddrs []*udpAddr `json:"remoteAddrs"`
|
RemoteAddrs []*udp.Addr `json:"remoteAddrs"`
|
||||||
CachedPackets int `json:"cachedPackets"`
|
CachedPackets int `json:"cachedPackets"`
|
||||||
Cert *cert.NebulaCertificate `json:"cert"`
|
Cert *cert.NebulaCertificate `json:"cert"`
|
||||||
MessageCounter uint64 `json:"messageCounter"`
|
MessageCounter uint64 `json:"messageCounter"`
|
||||||
CurrentRemote *udpAddr `json:"currentRemote"`
|
CurrentRemote *udp.Addr `json:"currentRemote"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
|
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
|
||||||
|
@ -95,8 +98,8 @@ func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found
|
// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
|
||||||
func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInfo {
|
func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo {
|
||||||
var hm *HostMap
|
var hm *HostMap
|
||||||
if pending {
|
if pending {
|
||||||
hm = c.f.handshakeManager.pendingHostMap
|
hm = c.f.handshakeManager.pendingHostMap
|
||||||
|
@ -104,7 +107,7 @@ func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInf
|
||||||
hm = c.f.hostMap
|
hm = c.f.hostMap
|
||||||
}
|
}
|
||||||
|
|
||||||
h, err := hm.QueryVpnIP(vpnIP)
|
h, err := hm.QueryVpnIp(vpnIp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -114,8 +117,8 @@ func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInf
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetRemoteForTunnel forces a tunnel to use a specific remote
|
// SetRemoteForTunnel forces a tunnel to use a specific remote
|
||||||
func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInfo {
|
func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo {
|
||||||
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
|
hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -126,15 +129,15 @@ func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInf
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
|
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
|
||||||
func (c *Control) CloseTunnel(vpnIP uint32, localOnly bool) bool {
|
func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool {
|
||||||
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
|
hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !localOnly {
|
if !localOnly {
|
||||||
c.f.send(
|
c.f.send(
|
||||||
closeTunnel,
|
header.CloseTunnel,
|
||||||
0,
|
0,
|
||||||
hostInfo.ConnectionState,
|
hostInfo.ConnectionState,
|
||||||
hostInfo,
|
hostInfo,
|
||||||
|
@ -156,16 +159,16 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
|
||||||
c.f.hostMap.Lock()
|
c.f.hostMap.Lock()
|
||||||
for _, h := range c.f.hostMap.Hosts {
|
for _, h := range c.f.hostMap.Hosts {
|
||||||
if excludeLighthouses {
|
if excludeLighthouses {
|
||||||
if _, ok := c.f.lightHouse.lighthouses[h.hostId]; ok {
|
if _, ok := c.f.lightHouse.lighthouses[h.vpnIp]; ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.ConnectionState.ready {
|
if h.ConnectionState.ready {
|
||||||
c.f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
||||||
c.f.closeTunnel(h, true)
|
c.f.closeTunnel(h, true)
|
||||||
|
|
||||||
c.l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
|
c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote).
|
||||||
Debug("Sending close tunnel message")
|
Debug("Sending close tunnel message")
|
||||||
closed++
|
closed++
|
||||||
}
|
}
|
||||||
|
@ -176,7 +179,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
|
||||||
|
|
||||||
func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
|
func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
|
||||||
chi := ControlHostInfo{
|
chi := ControlHostInfo{
|
||||||
VpnIP: int2ip(h.hostId),
|
VpnIp: h.vpnIp.ToIP(),
|
||||||
LocalIndex: h.localIndexId,
|
LocalIndex: h.localIndexId,
|
||||||
RemoteIndex: h.remoteIndexId,
|
RemoteIndex: h.remoteIndexId,
|
||||||
RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),
|
RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),
|
||||||
|
|
|
@ -8,17 +8,19 @@ import (
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestControl_GetHostInfoByVpnIP(t *testing.T) {
|
func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
|
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
|
||||||
// To properly ensure we are not exposing core memory to the caller
|
// To properly ensure we are not exposing core memory to the caller
|
||||||
hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
|
hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
|
||||||
remote1 := NewUDPAddr(int2ip(100), 4444)
|
remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
|
||||||
remote2 := NewUDPAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
|
remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
|
||||||
ipNet := net.IPNet{
|
ipNet := net.IPNet{
|
||||||
IP: net.IPv4(1, 2, 3, 4),
|
IP: net.IPv4(1, 2, 3, 4),
|
||||||
Mask: net.IPMask{255, 255, 255, 0},
|
Mask: net.IPMask{255, 255, 255, 0},
|
||||||
|
@ -48,7 +50,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
|
||||||
remotes := NewRemoteList()
|
remotes := NewRemoteList()
|
||||||
remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
|
remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
|
||||||
remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
|
remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
|
||||||
hm.Add(ip2int(ipNet.IP), &HostInfo{
|
hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{
|
||||||
remote: remote1,
|
remote: remote1,
|
||||||
remotes: remotes,
|
remotes: remotes,
|
||||||
ConnectionState: &ConnectionState{
|
ConnectionState: &ConnectionState{
|
||||||
|
@ -56,10 +58,10 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
|
||||||
},
|
},
|
||||||
remoteIndexId: 200,
|
remoteIndexId: 200,
|
||||||
localIndexId: 201,
|
localIndexId: 201,
|
||||||
hostId: ip2int(ipNet.IP),
|
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
|
||||||
})
|
})
|
||||||
|
|
||||||
hm.Add(ip2int(ipNet2.IP), &HostInfo{
|
hm.Add(iputil.Ip2VpnIp(ipNet2.IP), &HostInfo{
|
||||||
remote: remote1,
|
remote: remote1,
|
||||||
remotes: remotes,
|
remotes: remotes,
|
||||||
ConnectionState: &ConnectionState{
|
ConnectionState: &ConnectionState{
|
||||||
|
@ -67,7 +69,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
|
||||||
},
|
},
|
||||||
remoteIndexId: 200,
|
remoteIndexId: 200,
|
||||||
localIndexId: 201,
|
localIndexId: 201,
|
||||||
hostId: ip2int(ipNet2.IP),
|
vpnIp: iputil.Ip2VpnIp(ipNet2.IP),
|
||||||
})
|
})
|
||||||
|
|
||||||
c := Control{
|
c := Control{
|
||||||
|
@ -77,26 +79,26 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
|
||||||
l: logrus.New(),
|
l: logrus.New(),
|
||||||
}
|
}
|
||||||
|
|
||||||
thi := c.GetHostInfoByVpnIP(ip2int(ipNet.IP), false)
|
thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false)
|
||||||
|
|
||||||
expectedInfo := ControlHostInfo{
|
expectedInfo := ControlHostInfo{
|
||||||
VpnIP: net.IPv4(1, 2, 3, 4).To4(),
|
VpnIp: net.IPv4(1, 2, 3, 4).To4(),
|
||||||
LocalIndex: 201,
|
LocalIndex: 201,
|
||||||
RemoteIndex: 200,
|
RemoteIndex: 200,
|
||||||
RemoteAddrs: []*udpAddr{remote2, remote1},
|
RemoteAddrs: []*udp.Addr{remote2, remote1},
|
||||||
CachedPackets: 0,
|
CachedPackets: 0,
|
||||||
Cert: crt.Copy(),
|
Cert: crt.Copy(),
|
||||||
MessageCounter: 0,
|
MessageCounter: 0,
|
||||||
CurrentRemote: NewUDPAddr(int2ip(100), 4444),
|
CurrentRemote: udp.NewAddr(net.ParseIP("0.0.0.100"), 4444),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure we don't have any unexpected fields
|
// Make sure we don't have any unexpected fields
|
||||||
assertFields(t, []string{"VpnIP", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
|
assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
|
||||||
util.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
util.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
||||||
|
|
||||||
// Make sure we don't panic if the host info doesn't have a cert yet
|
// Make sure we don't panic if the host info doesn't have a cert yet
|
||||||
assert.NotPanics(t, func() {
|
assert.NotPanics(t, func() {
|
||||||
thi = c.GetHostInfoByVpnIP(ip2int(ipNet2.IP), false)
|
thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,12 +8,15 @@ import (
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WaitForTypeByIndex will pipe all messages from this control device into the pipeTo control device
|
// WaitForTypeByIndex will pipe all messages from this control device into the pipeTo control device
|
||||||
// returning after a message matching the criteria has been piped
|
// returning after a message matching the criteria has been piped
|
||||||
func (c *Control) WaitForType(msgType NebulaMessageType, subType NebulaMessageSubType, pipeTo *Control) {
|
func (c *Control) WaitForType(msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) {
|
||||||
h := &Header{}
|
h := &header.H{}
|
||||||
for {
|
for {
|
||||||
p := c.f.outside.Get(true)
|
p := c.f.outside.Get(true)
|
||||||
if err := h.Parse(p.Data); err != nil {
|
if err := h.Parse(p.Data); err != nil {
|
||||||
|
@ -28,8 +31,8 @@ func (c *Control) WaitForType(msgType NebulaMessageType, subType NebulaMessageSu
|
||||||
|
|
||||||
// WaitForTypeByIndex is similar to WaitForType except it adds an index check
|
// WaitForTypeByIndex is similar to WaitForType except it adds an index check
|
||||||
// Useful if you have many nodes communicating and want to wait to find a specific nodes packet
|
// Useful if you have many nodes communicating and want to wait to find a specific nodes packet
|
||||||
func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType NebulaMessageType, subType NebulaMessageSubType, pipeTo *Control) {
|
func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) {
|
||||||
h := &Header{}
|
h := &header.H{}
|
||||||
for {
|
for {
|
||||||
p := c.f.outside.Get(true)
|
p := c.f.outside.Get(true)
|
||||||
if err := h.Parse(p.Data); err != nil {
|
if err := h.Parse(p.Data); err != nil {
|
||||||
|
@ -46,12 +49,12 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType NebulaMessageType,
|
||||||
// This is necessary if you did not configure static hosts or are not running a lighthouse
|
// This is necessary if you did not configure static hosts or are not running a lighthouse
|
||||||
func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
|
func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
|
||||||
c.f.lightHouse.Lock()
|
c.f.lightHouse.Lock()
|
||||||
remoteList := c.f.lightHouse.unlockedGetRemoteList(ip2int(vpnIp))
|
remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp))
|
||||||
remoteList.Lock()
|
remoteList.Lock()
|
||||||
defer remoteList.Unlock()
|
defer remoteList.Unlock()
|
||||||
c.f.lightHouse.Unlock()
|
c.f.lightHouse.Unlock()
|
||||||
|
|
||||||
iVpnIp := ip2int(vpnIp)
|
iVpnIp := iputil.Ip2VpnIp(vpnIp)
|
||||||
if v4 := toAddr.IP.To4(); v4 != nil {
|
if v4 := toAddr.IP.To4(); v4 != nil {
|
||||||
remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
|
remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
|
||||||
} else {
|
} else {
|
||||||
|
@ -65,12 +68,12 @@ func (c *Control) GetFromTun(block bool) []byte {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFromUDP will pull a udp packet off the udp side of nebula
|
// GetFromUDP will pull a udp packet off the udp side of nebula
|
||||||
func (c *Control) GetFromUDP(block bool) *UdpPacket {
|
func (c *Control) GetFromUDP(block bool) *udp.Packet {
|
||||||
return c.f.outside.Get(block)
|
return c.f.outside.Get(block)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) GetUDPTxChan() <-chan *UdpPacket {
|
func (c *Control) GetUDPTxChan() <-chan *udp.Packet {
|
||||||
return c.f.outside.txPackets
|
return c.f.outside.TxPackets
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) GetTunTxChan() <-chan []byte {
|
func (c *Control) GetTunTxChan() <-chan []byte {
|
||||||
|
@ -78,7 +81,7 @@ func (c *Control) GetTunTxChan() <-chan []byte {
|
||||||
}
|
}
|
||||||
|
|
||||||
// InjectUDPPacket will inject a packet into the udp side of nebula
|
// InjectUDPPacket will inject a packet into the udp side of nebula
|
||||||
func (c *Control) InjectUDPPacket(p *UdpPacket) {
|
func (c *Control) InjectUDPPacket(p *udp.Packet) {
|
||||||
c.f.outside.Send(p)
|
c.f.outside.Send(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -115,11 +118,11 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) GetUDPAddr() string {
|
func (c *Control) GetUDPAddr() string {
|
||||||
return c.f.outside.addr.String()
|
return c.f.outside.Addr.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
|
func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
|
||||||
hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[ip2int(vpnIp)]
|
hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[iputil.Ip2VpnIp(vpnIp)]
|
||||||
if !ok {
|
if !ok {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,8 @@ import (
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// This whole thing should be rewritten to use context
|
// This whole thing should be rewritten to use context
|
||||||
|
@ -44,8 +46,8 @@ func (d *dnsRecords) QueryCert(data string) string {
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
iip := ip2int(ip)
|
iip := iputil.Ip2VpnIp(ip)
|
||||||
hostinfo, err := d.hostMap.QueryVpnIP(iip)
|
hostinfo, err := d.hostMap.QueryVpnIp(iip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
@ -109,7 +111,7 @@ func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
|
||||||
w.WriteMsg(m)
|
w.WriteMsg(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() {
|
func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
|
||||||
dnsR = newDnsRecords(hostMap)
|
dnsR = newDnsRecords(hostMap)
|
||||||
|
|
||||||
// attach request handler func
|
// attach request handler func
|
||||||
|
@ -117,7 +119,7 @@ func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() {
|
||||||
handleDnsRequest(l, w, r)
|
handleDnsRequest(l, w, r)
|
||||||
})
|
})
|
||||||
|
|
||||||
c.RegisterReloadCallback(func(c *Config) {
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
reloadDns(l, c)
|
reloadDns(l, c)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -126,11 +128,11 @@ func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDnsServerAddr(c *Config) string {
|
func getDnsServerAddr(c *config.C) string {
|
||||||
return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))
|
return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))
|
||||||
}
|
}
|
||||||
|
|
||||||
func startDns(l *logrus.Logger, c *Config) {
|
func startDns(l *logrus.Logger, c *config.C) {
|
||||||
dnsAddr = getDnsServerAddr(c)
|
dnsAddr = getDnsServerAddr(c)
|
||||||
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
|
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
|
||||||
l.WithField("dnsListener", dnsAddr).Infof("Starting DNS responder")
|
l.WithField("dnsListener", dnsAddr).Infof("Starting DNS responder")
|
||||||
|
@ -141,7 +143,7 @@ func startDns(l *logrus.Logger, c *Config) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func reloadDns(l *logrus.Logger, c *Config) {
|
func reloadDns(l *logrus.Logger, c *config.C) {
|
||||||
if dnsAddr == getDnsServerAddr(c) {
|
if dnsAddr == getDnsServerAddr(c) {
|
||||||
l.Debug("No DNS server config change detected")
|
l.Debug("No DNS server config change detected")
|
||||||
return
|
return
|
||||||
|
|
|
@ -10,6 +10,9 @@ import (
|
||||||
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -37,7 +40,7 @@ func TestGoodHandshake(t *testing.T) {
|
||||||
t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
|
t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
|
||||||
// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
|
// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
|
||||||
badPacket := stage1Packet.Copy()
|
badPacket := stage1Packet.Copy()
|
||||||
badPacket.Data = badPacket.Data[:len(badPacket.Data)-nebula.HeaderLen]
|
badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
|
||||||
myControl.InjectUDPPacket(badPacket)
|
myControl.InjectUDPPacket(badPacket)
|
||||||
|
|
||||||
t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
|
t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
|
||||||
|
@ -87,8 +90,8 @@ func TestWrongResponderHandshake(t *testing.T) {
|
||||||
|
|
||||||
t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
|
t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
|
myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
|
||||||
r.RouteForAllExitFunc(func(p *nebula.UdpPacket, c *nebula.Control) router.ExitType {
|
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
||||||
h := &nebula.Header{}
|
h := &header.H{}
|
||||||
err := h.Parse(p.Data)
|
err := h.Parse(p.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
@ -115,8 +118,8 @@ func TestWrongResponderHandshake(t *testing.T) {
|
||||||
r.FlushAll()
|
r.FlushAll()
|
||||||
|
|
||||||
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
|
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
|
||||||
assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), true), "My pending hostmap should not contain evil")
|
assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), true), "My pending hostmap should not contain evil")
|
||||||
assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), false), "My main hostmap should not contain evil")
|
assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), false), "My main hostmap should not contain evil")
|
||||||
//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
|
//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
|
||||||
|
|
||||||
//TODO: assert hostmaps for everyone
|
//TODO: assert hostmaps for everyone
|
||||||
|
|
|
@ -5,7 +5,6 @@ package e2e
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
@ -19,7 +18,9 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.org/x/crypto/curve25519"
|
"golang.org/x/crypto/curve25519"
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
|
@ -82,10 +83,10 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config := nebula.NewConfig(l)
|
c := config.NewC(l)
|
||||||
config.LoadString(string(cb))
|
c.LoadString(string(cb))
|
||||||
|
|
||||||
control, err := nebula.Main(config, false, "e2e-test", l, nil)
|
control, err := nebula.Main(c, false, "e2e-test", l, nil)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
@ -200,19 +201,6 @@ func x25519Keypair() ([]byte, []byte) {
|
||||||
return pubkey, privkey
|
return pubkey, privkey
|
||||||
}
|
}
|
||||||
|
|
||||||
func ip2int(ip []byte) uint32 {
|
|
||||||
if len(ip) == 16 {
|
|
||||||
return binary.BigEndian.Uint32(ip[12:16])
|
|
||||||
}
|
|
||||||
return binary.BigEndian.Uint32(ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
func int2ip(nn uint32) net.IP {
|
|
||||||
ip := make(net.IP, 4)
|
|
||||||
binary.BigEndian.PutUint32(ip, nn)
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
|
|
||||||
type doneCb func()
|
type doneCb func()
|
||||||
|
|
||||||
func deadline(t *testing.T, seconds time.Duration) doneCb {
|
func deadline(t *testing.T, seconds time.Duration) doneCb {
|
||||||
|
@ -245,15 +233,15 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebul
|
||||||
|
|
||||||
func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) {
|
func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) {
|
||||||
// Get both host infos
|
// Get both host infos
|
||||||
hBinA := controlA.GetHostInfoByVpnIP(ip2int(vpnIpB), false)
|
hBinA := controlA.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpB), false)
|
||||||
assert.NotNil(t, hBinA, "Host B was not found by vpnIP in controlA")
|
assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
|
||||||
|
|
||||||
hAinB := controlB.GetHostInfoByVpnIP(ip2int(vpnIpA), false)
|
hAinB := controlB.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpA), false)
|
||||||
assert.NotNil(t, hAinB, "Host A was not found by vpnIP in controlB")
|
assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
|
||||||
|
|
||||||
// Check that both vpn and real addr are correct
|
// Check that both vpn and real addr are correct
|
||||||
assert.Equal(t, vpnIpB, hBinA.VpnIP, "Host B VpnIp is wrong in control A")
|
assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
|
||||||
assert.Equal(t, vpnIpA, hAinB.VpnIP, "Host A VpnIp is wrong in control B")
|
assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
|
||||||
|
|
||||||
assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A")
|
assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A")
|
||||||
assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B")
|
assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B")
|
||||||
|
|
|
@ -11,6 +11,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
type R struct {
|
type R struct {
|
||||||
|
@ -41,7 +43,7 @@ const (
|
||||||
RouteAndExit ExitType = 2
|
RouteAndExit ExitType = 2
|
||||||
)
|
)
|
||||||
|
|
||||||
type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) ExitType
|
type ExitFunc func(packet *udp.Packet, receiver *nebula.Control) ExitType
|
||||||
|
|
||||||
func NewR(controls ...*nebula.Control) *R {
|
func NewR(controls ...*nebula.Control) *R {
|
||||||
r := &R{
|
r := &R{
|
||||||
|
@ -79,7 +81,7 @@ func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) {
|
||||||
// OnceFrom will route a single packet from sender then return
|
// OnceFrom will route a single packet from sender then return
|
||||||
// If the router doesn't have the nebula controller for that address, we panic
|
// If the router doesn't have the nebula controller for that address, we panic
|
||||||
func (r *R) OnceFrom(sender *nebula.Control) {
|
func (r *R) OnceFrom(sender *nebula.Control) {
|
||||||
r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) ExitType {
|
r.RouteExitFunc(sender, func(*udp.Packet, *nebula.Control) ExitType {
|
||||||
return RouteAndExit
|
return RouteAndExit
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -119,7 +121,7 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
|
||||||
// - routeAndExit: this call will return immediately after routing the last packet from sender
|
// - routeAndExit: this call will return immediately after routing the last packet from sender
|
||||||
// - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
|
// - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
|
||||||
func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
|
func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
|
||||||
h := &nebula.Header{}
|
h := &header.H{}
|
||||||
for {
|
for {
|
||||||
p := sender.GetFromUDP(true)
|
p := sender.GetFromUDP(true)
|
||||||
r.Lock()
|
r.Lock()
|
||||||
|
@ -159,9 +161,9 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
|
||||||
|
|
||||||
// RouteUntilAfterMsgType will route for sender until a message type is seen and sent from sender
|
// RouteUntilAfterMsgType will route for sender until a message type is seen and sent from sender
|
||||||
// If the router doesn't have the nebula controller for that address, we panic
|
// If the router doesn't have the nebula controller for that address, we panic
|
||||||
func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType nebula.NebulaMessageType, subType nebula.NebulaMessageSubType) {
|
func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType header.MessageType, subType header.MessageSubType) {
|
||||||
h := &nebula.Header{}
|
h := &header.H{}
|
||||||
r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
|
r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
|
||||||
if err := h.Parse(p.Data); err != nil {
|
if err := h.Parse(p.Data); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -181,7 +183,7 @@ func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr
|
||||||
finish = RouteAndExit
|
finish = RouteAndExit
|
||||||
}
|
}
|
||||||
|
|
||||||
r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
|
r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
|
||||||
if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
|
if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
|
||||||
return finish
|
return finish
|
||||||
}
|
}
|
||||||
|
@ -215,7 +217,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
|
||||||
x, rx, _ := reflect.Select(sc)
|
x, rx, _ := reflect.Select(sc)
|
||||||
r.Lock()
|
r.Lock()
|
||||||
|
|
||||||
p := rx.Interface().(*nebula.UdpPacket)
|
p := rx.Interface().(*udp.Packet)
|
||||||
|
|
||||||
outAddr := cm[x].GetUDPAddr()
|
outAddr := cm[x].GetUDPAddr()
|
||||||
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
|
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
|
||||||
|
@ -277,7 +279,7 @@ func (r *R) FlushAll() {
|
||||||
}
|
}
|
||||||
r.Lock()
|
r.Lock()
|
||||||
|
|
||||||
p := rx.Interface().(*nebula.UdpPacket)
|
p := rx.Interface().(*udp.Packet)
|
||||||
|
|
||||||
outAddr := cm[x].GetUDPAddr()
|
outAddr := cm[x].GetUDPAddr()
|
||||||
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
|
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
|
||||||
|
@ -292,7 +294,7 @@ func (r *R) FlushAll() {
|
||||||
|
|
||||||
// getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
|
// getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
|
||||||
// This is an internal router function, the caller must hold the lock
|
// This is an internal router function, the caller must hold the lock
|
||||||
func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Control {
|
func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control {
|
||||||
if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok {
|
if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok {
|
||||||
p.FromIp = newAddr.IP
|
p.FromIp = newAddr.IP
|
||||||
p.FromPort = uint16(newAddr.Port)
|
p.FromPort = uint16(newAddr.Port)
|
||||||
|
|
192
firewall.go
192
firewall.go
|
@ -4,7 +4,6 @@ import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
@ -12,22 +11,14 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
)
|
"github.com/slackhq/nebula/cidr"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
const (
|
"github.com/slackhq/nebula/firewall"
|
||||||
fwProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
|
|
||||||
fwProtoTCP = 6
|
|
||||||
fwProtoUDP = 17
|
|
||||||
fwProtoICMP = 1
|
|
||||||
|
|
||||||
fwPortAny = 0 // Special value for matching `port: any`
|
|
||||||
fwPortFragment = -1 // Special value for matching `port: fragment`
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const tcpACK = 0x10
|
const tcpACK = 0x10
|
||||||
|
@ -63,7 +54,7 @@ type Firewall struct {
|
||||||
DefaultTimeout time.Duration //linux: 600s
|
DefaultTimeout time.Duration //linux: 600s
|
||||||
|
|
||||||
// Used to ensure we don't emit local packets for ips we don't own
|
// Used to ensure we don't emit local packets for ips we don't own
|
||||||
localIps *CIDRTree
|
localIps *cidr.Tree4
|
||||||
|
|
||||||
rules string
|
rules string
|
||||||
rulesVersion uint16
|
rulesVersion uint16
|
||||||
|
@ -85,7 +76,7 @@ type firewallMetrics struct {
|
||||||
type FirewallConntrack struct {
|
type FirewallConntrack struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
Conns map[FirewallPacket]*conn
|
Conns map[firewall.Packet]*conn
|
||||||
TimerWheel *TimerWheel
|
TimerWheel *TimerWheel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -116,55 +107,13 @@ type FirewallRule struct {
|
||||||
Any bool
|
Any bool
|
||||||
Hosts map[string]struct{}
|
Hosts map[string]struct{}
|
||||||
Groups [][]string
|
Groups [][]string
|
||||||
CIDR *CIDRTree
|
CIDR *cidr.Tree4
|
||||||
}
|
}
|
||||||
|
|
||||||
// Even though ports are uint16, int32 maps are faster for lookup
|
// Even though ports are uint16, int32 maps are faster for lookup
|
||||||
// Plus we can use `-1` for fragment rules
|
// Plus we can use `-1` for fragment rules
|
||||||
type firewallPort map[int32]*FirewallCA
|
type firewallPort map[int32]*FirewallCA
|
||||||
|
|
||||||
type FirewallPacket struct {
|
|
||||||
LocalIP uint32
|
|
||||||
RemoteIP uint32
|
|
||||||
LocalPort uint16
|
|
||||||
RemotePort uint16
|
|
||||||
Protocol uint8
|
|
||||||
Fragment bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fp *FirewallPacket) Copy() *FirewallPacket {
|
|
||||||
return &FirewallPacket{
|
|
||||||
LocalIP: fp.LocalIP,
|
|
||||||
RemoteIP: fp.RemoteIP,
|
|
||||||
LocalPort: fp.LocalPort,
|
|
||||||
RemotePort: fp.RemotePort,
|
|
||||||
Protocol: fp.Protocol,
|
|
||||||
Fragment: fp.Fragment,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fp FirewallPacket) MarshalJSON() ([]byte, error) {
|
|
||||||
var proto string
|
|
||||||
switch fp.Protocol {
|
|
||||||
case fwProtoTCP:
|
|
||||||
proto = "tcp"
|
|
||||||
case fwProtoICMP:
|
|
||||||
proto = "icmp"
|
|
||||||
case fwProtoUDP:
|
|
||||||
proto = "udp"
|
|
||||||
default:
|
|
||||||
proto = fmt.Sprintf("unknown %v", fp.Protocol)
|
|
||||||
}
|
|
||||||
return json.Marshal(m{
|
|
||||||
"LocalIP": int2ip(fp.LocalIP).String(),
|
|
||||||
"RemoteIP": int2ip(fp.RemoteIP).String(),
|
|
||||||
"LocalPort": fp.LocalPort,
|
|
||||||
"RemotePort": fp.RemotePort,
|
|
||||||
"Protocol": proto,
|
|
||||||
"Fragment": fp.Fragment,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||||
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
|
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
|
||||||
//TODO: error on 0 duration
|
//TODO: error on 0 duration
|
||||||
|
@ -184,7 +133,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
||||||
max = defaultTimeout
|
max = defaultTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
localIps := NewCIDRTree()
|
localIps := cidr.NewTree4()
|
||||||
for _, ip := range c.Details.Ips {
|
for _, ip := range c.Details.Ips {
|
||||||
localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
|
localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
|
||||||
}
|
}
|
||||||
|
@ -195,7 +144,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
||||||
|
|
||||||
return &Firewall{
|
return &Firewall{
|
||||||
Conntrack: &FirewallConntrack{
|
Conntrack: &FirewallConntrack{
|
||||||
Conns: make(map[FirewallPacket]*conn),
|
Conns: make(map[firewall.Packet]*conn),
|
||||||
TimerWheel: NewTimerWheel(min, max),
|
TimerWheel: NewTimerWheel(min, max),
|
||||||
},
|
},
|
||||||
InRules: newFirewallTable(),
|
InRules: newFirewallTable(),
|
||||||
|
@ -220,7 +169,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
|
func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *config.C) (*Firewall, error) {
|
||||||
fw := NewFirewall(
|
fw := NewFirewall(
|
||||||
l,
|
l,
|
||||||
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
|
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
|
||||||
|
@ -278,13 +227,13 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
||||||
}
|
}
|
||||||
|
|
||||||
switch proto {
|
switch proto {
|
||||||
case fwProtoTCP:
|
case firewall.ProtoTCP:
|
||||||
fp = ft.TCP
|
fp = ft.TCP
|
||||||
case fwProtoUDP:
|
case firewall.ProtoUDP:
|
||||||
fp = ft.UDP
|
fp = ft.UDP
|
||||||
case fwProtoICMP:
|
case firewall.ProtoICMP:
|
||||||
fp = ft.ICMP
|
fp = ft.ICMP
|
||||||
case fwProtoAny:
|
case firewall.ProtoAny:
|
||||||
fp = ft.AnyProto
|
fp = ft.AnyProto
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unknown protocol %v", proto)
|
return fmt.Errorf("unknown protocol %v", proto)
|
||||||
|
@ -299,7 +248,7 @@ func (f *Firewall) GetRuleHash() string {
|
||||||
return hex.EncodeToString(sum[:])
|
return hex.EncodeToString(sum[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, fw FirewallInterface) error {
|
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
|
||||||
var table string
|
var table string
|
||||||
if inbound {
|
if inbound {
|
||||||
table = "firewall.inbound"
|
table = "firewall.inbound"
|
||||||
|
@ -307,7 +256,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config,
|
||||||
table = "firewall.outbound"
|
table = "firewall.outbound"
|
||||||
}
|
}
|
||||||
|
|
||||||
r := config.Get(table)
|
r := c.Get(table)
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -362,13 +311,13 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config,
|
||||||
var proto uint8
|
var proto uint8
|
||||||
switch r.Proto {
|
switch r.Proto {
|
||||||
case "any":
|
case "any":
|
||||||
proto = fwProtoAny
|
proto = firewall.ProtoAny
|
||||||
case "tcp":
|
case "tcp":
|
||||||
proto = fwProtoTCP
|
proto = firewall.ProtoTCP
|
||||||
case "udp":
|
case "udp":
|
||||||
proto = fwProtoUDP
|
proto = firewall.ProtoUDP
|
||||||
case "icmp":
|
case "icmp":
|
||||||
proto = fwProtoICMP
|
proto = firewall.ProtoICMP
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
|
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
|
||||||
}
|
}
|
||||||
|
@ -396,7 +345,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
||||||
|
|
||||||
// Drop returns an error if the packet should be dropped, explaining why. It
|
// Drop returns an error if the packet should be dropped, explaining why. It
|
||||||
// returns nil if the packet should not be dropped.
|
// returns nil if the packet should not be dropped.
|
||||||
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) error {
|
func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error {
|
||||||
// Check if we spoke to this tuple, if we did then allow this packet
|
// Check if we spoke to this tuple, if we did then allow this packet
|
||||||
if f.inConns(packet, fp, incoming, h, caPool, localCache) {
|
if f.inConns(packet, fp, incoming, h, caPool, localCache) {
|
||||||
return nil
|
return nil
|
||||||
|
@ -410,7 +359,7 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *Host
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Simple case: Certificate has one IP and no subnets
|
// Simple case: Certificate has one IP and no subnets
|
||||||
if fp.RemoteIP != h.hostId {
|
if fp.RemoteIP != h.vpnIp {
|
||||||
f.metrics(incoming).droppedRemoteIP.Inc(1)
|
f.metrics(incoming).droppedRemoteIP.Inc(1)
|
||||||
return ErrInvalidRemoteIP
|
return ErrInvalidRemoteIP
|
||||||
}
|
}
|
||||||
|
@ -462,7 +411,7 @@ func (f *Firewall) EmitStats() {
|
||||||
metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
|
metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) bool {
|
func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
|
||||||
if localCache != nil {
|
if localCache != nil {
|
||||||
if _, ok := localCache[fp]; ok {
|
if _, ok := localCache[fp]; ok {
|
||||||
return true
|
return true
|
||||||
|
@ -520,14 +469,14 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
|
||||||
}
|
}
|
||||||
|
|
||||||
switch fp.Protocol {
|
switch fp.Protocol {
|
||||||
case fwProtoTCP:
|
case firewall.ProtoTCP:
|
||||||
c.Expires = time.Now().Add(f.TCPTimeout)
|
c.Expires = time.Now().Add(f.TCPTimeout)
|
||||||
if incoming {
|
if incoming {
|
||||||
f.checkTCPRTT(c, packet)
|
f.checkTCPRTT(c, packet)
|
||||||
} else {
|
} else {
|
||||||
setTCPRTTTracking(c, packet)
|
setTCPRTTTracking(c, packet)
|
||||||
}
|
}
|
||||||
case fwProtoUDP:
|
case firewall.ProtoUDP:
|
||||||
c.Expires = time.Now().Add(f.UDPTimeout)
|
c.Expires = time.Now().Add(f.UDPTimeout)
|
||||||
default:
|
default:
|
||||||
c.Expires = time.Now().Add(f.DefaultTimeout)
|
c.Expires = time.Now().Add(f.DefaultTimeout)
|
||||||
|
@ -542,17 +491,17 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
|
func (f *Firewall) addConn(packet []byte, fp firewall.Packet, incoming bool) {
|
||||||
var timeout time.Duration
|
var timeout time.Duration
|
||||||
c := &conn{}
|
c := &conn{}
|
||||||
|
|
||||||
switch fp.Protocol {
|
switch fp.Protocol {
|
||||||
case fwProtoTCP:
|
case firewall.ProtoTCP:
|
||||||
timeout = f.TCPTimeout
|
timeout = f.TCPTimeout
|
||||||
if !incoming {
|
if !incoming {
|
||||||
setTCPRTTTracking(c, packet)
|
setTCPRTTTracking(c, packet)
|
||||||
}
|
}
|
||||||
case fwProtoUDP:
|
case firewall.ProtoUDP:
|
||||||
timeout = f.UDPTimeout
|
timeout = f.UDPTimeout
|
||||||
default:
|
default:
|
||||||
timeout = f.DefaultTimeout
|
timeout = f.DefaultTimeout
|
||||||
|
@ -575,7 +524,7 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
|
||||||
|
|
||||||
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
|
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
|
||||||
// Caller must own the connMutex lock!
|
// Caller must own the connMutex lock!
|
||||||
func (f *Firewall) evict(p FirewallPacket) {
|
func (f *Firewall) evict(p firewall.Packet) {
|
||||||
//TODO: report a stat if the tcp rtt tracking was never resolved?
|
//TODO: report a stat if the tcp rtt tracking was never resolved?
|
||||||
// Are we still tracking this conn?
|
// Are we still tracking this conn?
|
||||||
conntrack := f.Conntrack
|
conntrack := f.Conntrack
|
||||||
|
@ -596,21 +545,21 @@ func (f *Firewall) evict(p FirewallPacket) {
|
||||||
delete(conntrack.Conns, p)
|
delete(conntrack.Conns, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
||||||
if ft.AnyProto.match(p, incoming, c, caPool) {
|
if ft.AnyProto.match(p, incoming, c, caPool) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
switch p.Protocol {
|
switch p.Protocol {
|
||||||
case fwProtoTCP:
|
case firewall.ProtoTCP:
|
||||||
if ft.TCP.match(p, incoming, c, caPool) {
|
if ft.TCP.match(p, incoming, c, caPool) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
case fwProtoUDP:
|
case firewall.ProtoUDP:
|
||||||
if ft.UDP.match(p, incoming, c, caPool) {
|
if ft.UDP.match(p, incoming, c, caPool) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
case fwProtoICMP:
|
case firewall.ProtoICMP:
|
||||||
if ft.ICMP.match(p, incoming, c, caPool) {
|
if ft.ICMP.match(p, incoming, c, caPool) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -640,7 +589,7 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
||||||
// We don't have any allowed ports, bail
|
// We don't have any allowed ports, bail
|
||||||
if fp == nil {
|
if fp == nil {
|
||||||
return false
|
return false
|
||||||
|
@ -649,7 +598,7 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
|
||||||
var port int32
|
var port int32
|
||||||
|
|
||||||
if p.Fragment {
|
if p.Fragment {
|
||||||
port = fwPortFragment
|
port = firewall.PortFragment
|
||||||
} else if incoming {
|
} else if incoming {
|
||||||
port = int32(p.LocalPort)
|
port = int32(p.LocalPort)
|
||||||
} else {
|
} else {
|
||||||
|
@ -660,7 +609,7 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
return fp[fwPortAny].match(p, c, caPool)
|
return fp[firewall.PortAny].match(p, c, caPool)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
|
func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
|
||||||
|
@ -668,7 +617,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
|
||||||
return &FirewallRule{
|
return &FirewallRule{
|
||||||
Hosts: make(map[string]struct{}),
|
Hosts: make(map[string]struct{}),
|
||||||
Groups: make([][]string, 0),
|
Groups: make([][]string, 0),
|
||||||
CIDR: NewCIDRTree(),
|
CIDR: cidr.NewTree4(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -703,7 +652,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
||||||
if fc == nil {
|
if fc == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -736,7 +685,7 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) err
|
||||||
// If it's any we need to wipe out any pre-existing rules to save on memory
|
// If it's any we need to wipe out any pre-existing rules to save on memory
|
||||||
fr.Groups = make([][]string, 0)
|
fr.Groups = make([][]string, 0)
|
||||||
fr.Hosts = make(map[string]struct{})
|
fr.Hosts = make(map[string]struct{})
|
||||||
fr.CIDR = NewCIDRTree()
|
fr.CIDR = cidr.NewTree4()
|
||||||
} else {
|
} else {
|
||||||
if len(groups) > 0 {
|
if len(groups) > 0 {
|
||||||
fr.Groups = append(fr.Groups, groups)
|
fr.Groups = append(fr.Groups, groups)
|
||||||
|
@ -776,7 +725,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate) bool {
|
func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
|
||||||
if fr == nil {
|
if fr == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -885,12 +834,12 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
|
||||||
|
|
||||||
func parsePort(s string) (startPort, endPort int32, err error) {
|
func parsePort(s string) (startPort, endPort int32, err error) {
|
||||||
if s == "any" {
|
if s == "any" {
|
||||||
startPort = fwPortAny
|
startPort = firewall.PortAny
|
||||||
endPort = fwPortAny
|
endPort = firewall.PortAny
|
||||||
|
|
||||||
} else if s == "fragment" {
|
} else if s == "fragment" {
|
||||||
startPort = fwPortFragment
|
startPort = firewall.PortFragment
|
||||||
endPort = fwPortFragment
|
endPort = firewall.PortFragment
|
||||||
|
|
||||||
} else if strings.Contains(s, `-`) {
|
} else if strings.Contains(s, `-`) {
|
||||||
sPorts := strings.SplitN(s, `-`, 2)
|
sPorts := strings.SplitN(s, `-`, 2)
|
||||||
|
@ -914,8 +863,8 @@ func parsePort(s string) (startPort, endPort int32, err error) {
|
||||||
startPort = int32(rStartPort)
|
startPort = int32(rStartPort)
|
||||||
endPort = int32(rEndPort)
|
endPort = int32(rEndPort)
|
||||||
|
|
||||||
if startPort == fwPortAny {
|
if startPort == firewall.PortAny {
|
||||||
endPort = fwPortAny
|
endPort = firewall.PortAny
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
@ -968,54 +917,3 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
|
||||||
c.Seq = 0
|
c.Seq = 0
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConntrackCache is used as a local routine cache to know if a given flow
|
|
||||||
// has been seen in the conntrack table.
|
|
||||||
type ConntrackCache map[FirewallPacket]struct{}
|
|
||||||
|
|
||||||
type ConntrackCacheTicker struct {
|
|
||||||
cacheV uint64
|
|
||||||
cacheTick uint64
|
|
||||||
|
|
||||||
cache ConntrackCache
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
|
|
||||||
if d == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
c := &ConntrackCacheTicker{
|
|
||||||
cache: ConntrackCache{},
|
|
||||||
}
|
|
||||||
|
|
||||||
go c.tick(d)
|
|
||||||
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ConntrackCacheTicker) tick(d time.Duration) {
|
|
||||||
for {
|
|
||||||
time.Sleep(d)
|
|
||||||
atomic.AddUint64(&c.cacheTick, 1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get checks if the cache ticker has moved to the next version before returning
|
|
||||||
// the map. If it has moved, we reset the map.
|
|
||||||
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
|
|
||||||
if c == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
|
|
||||||
c.cacheV = tick
|
|
||||||
if ll := len(c.cache); ll > 0 {
|
|
||||||
if l.Level == logrus.DebugLevel {
|
|
||||||
l.WithField("len", ll).Debug("resetting conntrack cache")
|
|
||||||
}
|
|
||||||
c.cache = make(ConntrackCache, ll)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.cache
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConntrackCache is used as a local routine cache to know if a given flow
|
||||||
|
// has been seen in the conntrack table.
|
||||||
|
type ConntrackCache map[Packet]struct{}
|
||||||
|
|
||||||
|
type ConntrackCacheTicker struct {
|
||||||
|
cacheV uint64
|
||||||
|
cacheTick uint64
|
||||||
|
|
||||||
|
cache ConntrackCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
|
||||||
|
if d == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c := &ConntrackCacheTicker{
|
||||||
|
cache: ConntrackCache{},
|
||||||
|
}
|
||||||
|
|
||||||
|
go c.tick(d)
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConntrackCacheTicker) tick(d time.Duration) {
|
||||||
|
for {
|
||||||
|
time.Sleep(d)
|
||||||
|
atomic.AddUint64(&c.cacheTick, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get checks if the cache ticker has moved to the next version before returning
|
||||||
|
// the map. If it has moved, we reset the map.
|
||||||
|
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
|
||||||
|
c.cacheV = tick
|
||||||
|
if ll := len(c.cache); ll > 0 {
|
||||||
|
if l.Level == logrus.DebugLevel {
|
||||||
|
l.WithField("len", ll).Debug("resetting conntrack cache")
|
||||||
|
}
|
||||||
|
c.cache = make(ConntrackCache, ll)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.cache
|
||||||
|
}
|
|
@ -0,0 +1,62 @@
|
||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
)
|
||||||
|
|
||||||
|
type m map[string]interface{}
|
||||||
|
|
||||||
|
const (
|
||||||
|
ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
|
||||||
|
ProtoTCP = 6
|
||||||
|
ProtoUDP = 17
|
||||||
|
ProtoICMP = 1
|
||||||
|
|
||||||
|
PortAny = 0 // Special value for matching `port: any`
|
||||||
|
PortFragment = -1 // Special value for matching `port: fragment`
|
||||||
|
)
|
||||||
|
|
||||||
|
type Packet struct {
|
||||||
|
LocalIP iputil.VpnIp
|
||||||
|
RemoteIP iputil.VpnIp
|
||||||
|
LocalPort uint16
|
||||||
|
RemotePort uint16
|
||||||
|
Protocol uint8
|
||||||
|
Fragment bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fp *Packet) Copy() *Packet {
|
||||||
|
return &Packet{
|
||||||
|
LocalIP: fp.LocalIP,
|
||||||
|
RemoteIP: fp.RemoteIP,
|
||||||
|
LocalPort: fp.LocalPort,
|
||||||
|
RemotePort: fp.RemotePort,
|
||||||
|
Protocol: fp.Protocol,
|
||||||
|
Fragment: fp.Fragment,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fp Packet) MarshalJSON() ([]byte, error) {
|
||||||
|
var proto string
|
||||||
|
switch fp.Protocol {
|
||||||
|
case ProtoTCP:
|
||||||
|
proto = "tcp"
|
||||||
|
case ProtoICMP:
|
||||||
|
proto = "icmp"
|
||||||
|
case ProtoUDP:
|
||||||
|
proto = "udp"
|
||||||
|
default:
|
||||||
|
proto = fmt.Sprintf("unknown %v", fp.Protocol)
|
||||||
|
}
|
||||||
|
return json.Marshal(m{
|
||||||
|
"LocalIP": fp.LocalIP.String(),
|
||||||
|
"RemoteIP": fp.RemoteIP.String(),
|
||||||
|
"LocalPort": fp.LocalPort,
|
||||||
|
"RemotePort": fp.RemotePort,
|
||||||
|
"Protocol": proto,
|
||||||
|
"Fragment": fp.Fragment,
|
||||||
|
})
|
||||||
|
}
|
210
firewall_test.go
210
firewall_test.go
|
@ -11,11 +11,15 @@ import (
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/firewall"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewFirewall(t *testing.T) {
|
func TestNewFirewall(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
c := &cert.NebulaCertificate{}
|
c := &cert.NebulaCertificate{}
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
conntrack := fw.Conntrack
|
conntrack := fw.Conntrack
|
||||||
|
@ -54,7 +58,7 @@ func TestNewFirewall(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_AddRule(t *testing.T) {
|
func TestFirewall_AddRule(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
|
@ -65,92 +69,80 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||||
|
|
||||||
_, ti, _ := net.ParseCIDR("1.2.3.4/32")
|
_, ti, _ := net.ParseCIDR("1.2.3.4/32")
|
||||||
|
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, "", ""))
|
||||||
// An empty rule is any
|
// An empty rule is any
|
||||||
assert.True(t, fw.InRules.TCP[1].Any.Any)
|
assert.True(t, fw.InRules.TCP[1].Any.Any)
|
||||||
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
|
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
|
||||||
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
|
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
|
||||||
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left)
|
|
||||||
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
|
|
||||||
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
|
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
|
||||||
assert.False(t, fw.InRules.UDP[1].Any.Any)
|
assert.False(t, fw.InRules.UDP[1].Any.Any)
|
||||||
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
|
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
|
||||||
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
|
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
|
||||||
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.left)
|
|
||||||
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
|
|
||||||
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
|
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
|
||||||
assert.False(t, fw.InRules.ICMP[1].Any.Any)
|
assert.False(t, fw.InRules.ICMP[1].Any.Any)
|
||||||
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
||||||
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
|
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
|
||||||
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.left)
|
|
||||||
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
|
|
||||||
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
|
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
|
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, "", ""))
|
||||||
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
|
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
|
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
|
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
|
||||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP)))
|
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
|
||||||
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
|
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
|
||||||
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
|
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
|
||||||
|
|
||||||
// Set any and clear fields
|
// Set any and clear fields
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
|
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
|
||||||
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
|
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
|
||||||
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
|
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
|
||||||
assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(ip2int(ti.IP)))
|
assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
|
||||||
|
|
||||||
// run twice just to make sure
|
// run twice just to make sure
|
||||||
//TODO: these ANY rules should clear the CA firewall portion
|
//TODO: these ANY rules should clear the CA firewall portion
|
||||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
||||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
|
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
|
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
|
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
|
||||||
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.left)
|
|
||||||
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
|
|
||||||
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
|
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
|
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
|
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
|
||||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
|
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
||||||
|
|
||||||
// Test error conditions
|
// Test error conditions
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
|
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
|
||||||
assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", ""))
|
assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, "", ""))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop(t *testing.T) {
|
func TestFirewall_Drop(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
p := FirewallPacket{
|
p := firewall.Packet{
|
||||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
|
||||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
|
||||||
10,
|
10,
|
||||||
90,
|
90,
|
||||||
fwProtoUDP,
|
firewall.ProtoUDP,
|
||||||
false,
|
false,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -172,12 +164,12 @@ func TestFirewall_Drop(t *testing.T) {
|
||||||
ConnectionState: &ConnectionState{
|
ConnectionState: &ConnectionState{
|
||||||
peerCert: &c,
|
peerCert: &c,
|
||||||
},
|
},
|
||||||
hostId: ip2int(ipNet.IP),
|
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
|
||||||
}
|
}
|
||||||
h.CreateRemoteCIDR(&c)
|
h.CreateRemoteCIDR(&c)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
|
@ -190,34 +182,34 @@ func TestFirewall_Drop(t *testing.T) {
|
||||||
|
|
||||||
// test remote mismatch
|
// test remote mismatch
|
||||||
oldRemote := p.RemoteIP
|
oldRemote := p.RemoteIP
|
||||||
p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
|
p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10))
|
||||||
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||||
p.RemoteIP = oldRemote
|
p.RemoteIP = oldRemote
|
||||||
|
|
||||||
// ensure signer doesn't get in the way of group checks
|
// ensure signer doesn't get in the way of group checks
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
|
||||||
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caSha doesn't drop on match
|
// test caSha doesn't drop on match
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
|
||||||
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
|
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
|
||||||
|
|
||||||
// ensure ca name doesn't get in the way of group checks
|
// ensure ca name doesn't get in the way of group checks
|
||||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
|
||||||
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caName doesn't drop on match
|
// test caName doesn't drop on match
|
||||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
|
||||||
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
|
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -237,14 +229,14 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||||
b.Run("fail on proto", func(b *testing.B) {
|
b.Run("fail on proto", func(b *testing.B) {
|
||||||
c := &cert.NebulaCertificate{}
|
c := &cert.NebulaCertificate{}
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
ft.match(FirewallPacket{Protocol: fwProtoUDP}, true, c, cp)
|
ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("fail on port", func(b *testing.B) {
|
b.Run("fail on port", func(b *testing.B) {
|
||||||
c := &cert.NebulaCertificate{}
|
c := &cert.NebulaCertificate{}
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 1}, true, c, cp)
|
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -258,7 +250,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
|
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -270,7 +262,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
|
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -282,12 +274,12 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
|
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("pass on ip", func(b *testing.B) {
|
b.Run("pass on ip", func(b *testing.B) {
|
||||||
ip := ip2int(net.IPv4(172, 1, 1, 1))
|
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||||
c := &cert.NebulaCertificate{
|
c := &cert.NebulaCertificate{
|
||||||
Details: cert.NebulaCertificateDetails{
|
Details: cert.NebulaCertificateDetails{
|
||||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||||
|
@ -295,14 +287,14 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
|
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
|
_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
|
||||||
|
|
||||||
b.Run("pass on ip with any port", func(b *testing.B) {
|
b.Run("pass on ip with any port", func(b *testing.B) {
|
||||||
ip := ip2int(net.IPv4(172, 1, 1, 1))
|
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||||
c := &cert.NebulaCertificate{
|
c := &cert.NebulaCertificate{
|
||||||
Details: cert.NebulaCertificateDetails{
|
Details: cert.NebulaCertificateDetails{
|
||||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||||
|
@ -310,22 +302,22 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
|
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop2(t *testing.T) {
|
func TestFirewall_Drop2(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
p := FirewallPacket{
|
p := firewall.Packet{
|
||||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
|
||||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
|
||||||
10,
|
10,
|
||||||
90,
|
90,
|
||||||
fwProtoUDP,
|
firewall.ProtoUDP,
|
||||||
false,
|
false,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -345,7 +337,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||||
ConnectionState: &ConnectionState{
|
ConnectionState: &ConnectionState{
|
||||||
peerCert: &c,
|
peerCert: &c,
|
||||||
},
|
},
|
||||||
hostId: ip2int(ipNet.IP),
|
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
|
||||||
}
|
}
|
||||||
h.CreateRemoteCIDR(&c)
|
h.CreateRemoteCIDR(&c)
|
||||||
|
|
||||||
|
@ -364,7 +356,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||||
h1.CreateRemoteCIDR(&c1)
|
h1.CreateRemoteCIDR(&c1)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// h1/c1 lacks the proper groups
|
// h1/c1 lacks the proper groups
|
||||||
|
@ -375,16 +367,16 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop3(t *testing.T) {
|
func TestFirewall_Drop3(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
p := FirewallPacket{
|
p := firewall.Packet{
|
||||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
|
||||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
|
||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
fwProtoUDP,
|
firewall.ProtoUDP,
|
||||||
false,
|
false,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -411,7 +403,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||||
ConnectionState: &ConnectionState{
|
ConnectionState: &ConnectionState{
|
||||||
peerCert: &c1,
|
peerCert: &c1,
|
||||||
},
|
},
|
||||||
hostId: ip2int(ipNet.IP),
|
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
|
||||||
}
|
}
|
||||||
h1.CreateRemoteCIDR(&c1)
|
h1.CreateRemoteCIDR(&c1)
|
||||||
|
|
||||||
|
@ -426,7 +418,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||||
ConnectionState: &ConnectionState{
|
ConnectionState: &ConnectionState{
|
||||||
peerCert: &c2,
|
peerCert: &c2,
|
||||||
},
|
},
|
||||||
hostId: ip2int(ipNet.IP),
|
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
|
||||||
}
|
}
|
||||||
h2.CreateRemoteCIDR(&c2)
|
h2.CreateRemoteCIDR(&c2)
|
||||||
|
|
||||||
|
@ -441,13 +433,13 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||||
ConnectionState: &ConnectionState{
|
ConnectionState: &ConnectionState{
|
||||||
peerCert: &c3,
|
peerCert: &c3,
|
||||||
},
|
},
|
||||||
hostId: ip2int(ipNet.IP),
|
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
|
||||||
}
|
}
|
||||||
h3.CreateRemoteCIDR(&c3)
|
h3.CreateRemoteCIDR(&c3)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// c1 should pass because host match
|
// c1 should pass because host match
|
||||||
|
@ -461,16 +453,16 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
p := FirewallPacket{
|
p := firewall.Packet{
|
||||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
|
||||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
|
||||||
10,
|
10,
|
||||||
90,
|
90,
|
||||||
fwProtoUDP,
|
firewall.ProtoUDP,
|
||||||
false,
|
false,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -492,12 +484,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||||
ConnectionState: &ConnectionState{
|
ConnectionState: &ConnectionState{
|
||||||
peerCert: &c,
|
peerCert: &c,
|
||||||
},
|
},
|
||||||
hostId: ip2int(ipNet.IP),
|
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
|
||||||
}
|
}
|
||||||
h.CreateRemoteCIDR(&c)
|
h.CreateRemoteCIDR(&c)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
|
@ -510,7 +502,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||||
|
|
||||||
oldFw := fw
|
oldFw := fw
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
|
||||||
fw.Conntrack = oldFw.Conntrack
|
fw.Conntrack = oldFw.Conntrack
|
||||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||||
|
|
||||||
|
@ -519,7 +511,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||||
|
|
||||||
oldFw = fw
|
oldFw = fw
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
|
||||||
fw.Conntrack = oldFw.Conntrack
|
fw.Conntrack = oldFw.Conntrack
|
||||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||||
|
|
||||||
|
@ -643,28 +635,28 @@ func Test_parsePort(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewFirewallFromConfig(t *testing.T) {
|
func TestNewFirewallFromConfig(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
// Test a bad rule definition
|
// Test a bad rule definition
|
||||||
c := &cert.NebulaCertificate{}
|
c := &cert.NebulaCertificate{}
|
||||||
conf := NewConfig(l)
|
conf := config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
|
||||||
_, err := NewFirewallFromConfig(l, c, conf)
|
_, err := NewFirewallFromConfig(l, c, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
||||||
|
|
||||||
// Test both port and code
|
// Test both port and code
|
||||||
conf = NewConfig(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
|
||||||
_, err = NewFirewallFromConfig(l, c, conf)
|
_, err = NewFirewallFromConfig(l, c, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
||||||
|
|
||||||
// Test missing host, group, cidr, ca_name and ca_sha
|
// Test missing host, group, cidr, ca_name and ca_sha
|
||||||
conf = NewConfig(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
|
||||||
_, err = NewFirewallFromConfig(l, c, conf)
|
_, err = NewFirewallFromConfig(l, c, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
|
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
|
||||||
|
|
||||||
// Test code/port error
|
// Test code/port error
|
||||||
conf = NewConfig(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
|
||||||
_, err = NewFirewallFromConfig(l, c, conf)
|
_, err = NewFirewallFromConfig(l, c, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
||||||
|
@ -674,91 +666,91 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
||||||
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
||||||
|
|
||||||
// Test proto error
|
// Test proto error
|
||||||
conf = NewConfig(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
|
||||||
_, err = NewFirewallFromConfig(l, c, conf)
|
_, err = NewFirewallFromConfig(l, c, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
||||||
|
|
||||||
// Test cidr parse error
|
// Test cidr parse error
|
||||||
conf = NewConfig(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
|
||||||
_, err = NewFirewallFromConfig(l, c, conf)
|
_, err = NewFirewallFromConfig(l, c, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
|
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
|
||||||
|
|
||||||
// Test both group and groups
|
// Test both group and groups
|
||||||
conf = NewConfig(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
||||||
_, err = NewFirewallFromConfig(l, c, conf)
|
_, err = NewFirewallFromConfig(l, c, conf)
|
||||||
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddFirewallRulesFromConfig(t *testing.T) {
|
func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
// Test adding tcp rule
|
// Test adding tcp rule
|
||||||
conf := NewConfig(l)
|
conf := config.NewC(l)
|
||||||
mf := &mockFirewall{}
|
mf := &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
|
||||||
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding udp rule
|
// Test adding udp rule
|
||||||
conf = NewConfig(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
|
||||||
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding icmp rule
|
// Test adding icmp rule
|
||||||
conf = NewConfig(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
|
||||||
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding any rule
|
// Test adding any rule
|
||||||
conf = NewConfig(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
|
||||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with ca_sha
|
// Test adding rule with ca_sha
|
||||||
conf = NewConfig(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
||||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with ca_name
|
// Test adding rule with ca_name
|
||||||
conf = NewConfig(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
||||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
|
||||||
|
|
||||||
// Test single group
|
// Test single group
|
||||||
conf = NewConfig(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
|
||||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
|
||||||
|
|
||||||
// Test single groups
|
// Test single groups
|
||||||
conf = NewConfig(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
|
||||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
|
||||||
|
|
||||||
// Test multiple AND groups
|
// Test multiple AND groups
|
||||||
conf = NewConfig(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
||||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
|
||||||
|
|
||||||
// Test Add error
|
// Test Add error
|
||||||
conf = NewConfig(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
mf.nextCallReturn = errors.New("test error")
|
mf.nextCallReturn = errors.New("test error")
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
|
||||||
|
@ -857,7 +849,7 @@ func TestTCPRTTTracking(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_convertRule(t *testing.T) {
|
func TestFirewall_convertRule(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
|
@ -929,6 +921,6 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
|
||||||
|
|
||||||
func resetConntrack(fw *Firewall) {
|
func resetConntrack(fw *Firewall) {
|
||||||
fw.Conntrack.Lock()
|
fw.Conntrack.Lock()
|
||||||
fw.Conntrack.Conns = map[FirewallPacket]*conn{}
|
fw.Conntrack.Conns = map[firewall.Packet]*conn{}
|
||||||
fw.Conntrack.Unlock()
|
fw.Conntrack.Unlock()
|
||||||
}
|
}
|
||||||
|
|
10
handshake.go
10
handshake.go
|
@ -1,11 +1,11 @@
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
const (
|
import (
|
||||||
handshakeIXPSK0 = 0
|
"github.com/slackhq/nebula/header"
|
||||||
handshakeXXPSK0 = 1
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) {
|
func HandleIncomingHandshake(f *Interface, addr *udp.Addr, packet []byte, h *header.H, hostinfo *HostInfo) {
|
||||||
// First remote allow list check before we know the vpnIp
|
// First remote allow list check before we know the vpnIp
|
||||||
if !f.lightHouse.remoteAllowList.AllowUnknownVpnIp(addr.IP) {
|
if !f.lightHouse.remoteAllowList.AllowUnknownVpnIp(addr.IP) {
|
||||||
f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||||
|
@ -13,7 +13,7 @@ func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Head
|
||||||
}
|
}
|
||||||
|
|
||||||
switch h.Subtype {
|
switch h.Subtype {
|
||||||
case handshakeIXPSK0:
|
case header.HandshakeIXPSK0:
|
||||||
switch h.MessageCounter {
|
switch h.MessageCounter {
|
||||||
case 1:
|
case 1:
|
||||||
ixHandshakeStage1(f, addr, packet, h)
|
ixHandshakeStage1(f, addr, packet, h)
|
||||||
|
|
115
handshake_ix.go
115
handshake_ix.go
|
@ -6,13 +6,16 @@ import (
|
||||||
|
|
||||||
"github.com/flynn/noise"
|
"github.com/flynn/noise"
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NOISE IX Handshakes
|
// NOISE IX Handshakes
|
||||||
|
|
||||||
// This function constructs a handshake packet, but does not actually send it
|
// This function constructs a handshake packet, but does not actually send it
|
||||||
// Sending is done by the handshake manager
|
// Sending is done by the handshake manager
|
||||||
func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
|
func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
|
||||||
// This queries the lighthouse if we don't know a remote for the host
|
// This queries the lighthouse if we don't know a remote for the host
|
||||||
// We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send
|
// We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send
|
||||||
// more quickly, effect is a quicker handshake.
|
// more quickly, effect is a quicker handshake.
|
||||||
|
@ -22,7 +25,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
|
||||||
|
|
||||||
err := f.handshakeManager.AddIndexHostInfo(hostinfo)
|
err := f.handshakeManager.AddIndexHostInfo(hostinfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
|
f.l.WithError(err).WithField("vpnIp", vpnIp).
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -43,17 +46,17 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
|
||||||
hsBytes, err = proto.Marshal(hs)
|
hsBytes, err = proto.Marshal(hs)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
|
f.l.WithError(err).WithField("vpnIp", vpnIp).
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, 0, 1)
|
h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1)
|
||||||
atomic.AddUint64(&ci.atomicMessageCounter, 1)
|
atomic.AddUint64(&ci.atomicMessageCounter, 1)
|
||||||
|
|
||||||
msg, _, _, err := ci.H.WriteMessage(header, hsBytes)
|
msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
|
f.l.WithError(err).WithField("vpnIp", vpnIp).
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -67,12 +70,12 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
|
||||||
hostinfo.handshakeStart = time.Now()
|
hostinfo.handshakeStart = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
func ixHandshakeStage1(f *Interface, addr *udp.Addr, packet []byte, h *header.H) {
|
||||||
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
|
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
|
||||||
// Mark packet 1 as seen so it doesn't show up as missed
|
// Mark packet 1 as seen so it doesn't show up as missed
|
||||||
ci.window.Update(f.l, 1)
|
ci.window.Update(f.l, 1)
|
||||||
|
|
||||||
msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
|
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
|
||||||
|
@ -97,13 +100,13 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
Info("Invalid certificate from host")
|
Info("Invalid certificate from host")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
|
vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
|
||||||
certName := remoteCert.Details.Name
|
certName := remoteCert.Details.Name
|
||||||
fingerprint, _ := remoteCert.Sha256Sum()
|
fingerprint, _ := remoteCert.Sha256Sum()
|
||||||
issuer := remoteCert.Details.Issuer
|
issuer := remoteCert.Details.Issuer
|
||||||
|
|
||||||
if vpnIP == ip2int(f.certState.certificate.Details.Ips[0].IP) {
|
if vpnIp == f.myVpnIp {
|
||||||
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
|
@ -111,14 +114,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !f.lightHouse.remoteAllowList.Allow(vpnIP, addr.IP) {
|
if !f.lightHouse.remoteAllowList.Allow(vpnIp, addr.IP) {
|
||||||
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
myIndex, err := generateIndex(f.l)
|
myIndex, err := generateIndex(f.l)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
|
@ -130,7 +133,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
ConnectionState: ci,
|
ConnectionState: ci,
|
||||||
localIndexId: myIndex,
|
localIndexId: myIndex,
|
||||||
remoteIndexId: hs.Details.InitiatorIndex,
|
remoteIndexId: hs.Details.InitiatorIndex,
|
||||||
hostId: vpnIP,
|
vpnIp: vpnIp,
|
||||||
HandshakePacket: make(map[uint8][]byte, 0),
|
HandshakePacket: make(map[uint8][]byte, 0),
|
||||||
lastHandshakeTime: hs.Details.Time,
|
lastHandshakeTime: hs.Details.Time,
|
||||||
}
|
}
|
||||||
|
@ -138,7 +141,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
hostinfo.Lock()
|
hostinfo.Lock()
|
||||||
defer hostinfo.Unlock()
|
defer hostinfo.Unlock()
|
||||||
|
|
||||||
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
|
@ -153,7 +156,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
|
|
||||||
hsBytes, err := proto.Marshal(hs)
|
hsBytes, err := proto.Marshal(hs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
|
@ -161,17 +164,17 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2)
|
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
|
||||||
msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes)
|
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
||||||
return
|
return
|
||||||
} else if dKey == nil || eKey == nil {
|
} else if dKey == nil || eKey == nil {
|
||||||
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
|
@ -179,8 +182,8 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:]))
|
hostinfo.HandshakePacket[0] = make([]byte, len(packet[header.Len:]))
|
||||||
copy(hostinfo.HandshakePacket[0], packet[HeaderLen:])
|
copy(hostinfo.HandshakePacket[0], packet[header.Len:])
|
||||||
|
|
||||||
// Regardless of whether you are the sender or receiver, you should arrive here
|
// Regardless of whether you are the sender or receiver, you should arrive here
|
||||||
// and complete standing up the connection.
|
// and complete standing up the connection.
|
||||||
|
@ -195,12 +198,12 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
ci.dKey = NewNebulaCipherState(dKey)
|
ci.dKey = NewNebulaCipherState(dKey)
|
||||||
ci.eKey = NewNebulaCipherState(eKey)
|
ci.eKey = NewNebulaCipherState(eKey)
|
||||||
|
|
||||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnIP)
|
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
|
||||||
hostinfo.SetRemote(addr)
|
hostinfo.SetRemote(addr)
|
||||||
hostinfo.CreateRemoteCIDR(remoteCert)
|
hostinfo.CreateRemoteCIDR(remoteCert)
|
||||||
|
|
||||||
// Only overwrite existing record if we should win the handshake race
|
// Only overwrite existing record if we should win the handshake race
|
||||||
overwrite := vpnIP > ip2int(f.certState.certificate.Details.Ips[0].IP)
|
overwrite := vpnIp > f.myVpnIp
|
||||||
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f)
|
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch err {
|
switch err {
|
||||||
|
@ -214,27 +217,27 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
if existing.SetRemoteIfPreferred(f.hostMap, addr) {
|
if existing.SetRemoteIfPreferred(f.hostMap, addr) {
|
||||||
// Send a test packet to ensure the other side has also switched to
|
// Send a test packet to ensure the other side has also switched to
|
||||||
// the preferred remote
|
// the preferred remote
|
||||||
f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||||
}
|
}
|
||||||
existing.Unlock()
|
existing.Unlock()
|
||||||
hostinfo.Lock()
|
hostinfo.Lock()
|
||||||
|
|
||||||
msg = existing.HandshakePacket[2]
|
msg = existing.HandshakePacket[2]
|
||||||
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
|
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
||||||
err := f.outside.WriteTo(msg, addr)
|
err := f.outside.WriteTo(msg, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||||
WithError(err).Error("Failed to send handshake message")
|
WithError(err).Error("Failed to send handshake message")
|
||||||
} else {
|
} else {
|
||||||
f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||||
Info("Handshake message sent")
|
Info("Handshake message sent")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
case ErrExistingHostInfo:
|
case ErrExistingHostInfo:
|
||||||
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
||||||
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("oldHandshakeTime", existing.lastHandshakeTime).
|
WithField("oldHandshakeTime", existing.lastHandshakeTime).
|
||||||
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
|
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
|
||||||
|
@ -245,22 +248,22 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
Info("Handshake too old")
|
Info("Handshake too old")
|
||||||
|
|
||||||
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
||||||
f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||||
return
|
return
|
||||||
case ErrLocalIndexCollision:
|
case ErrLocalIndexCollision:
|
||||||
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
||||||
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
WithField("localIndex", hostinfo.localIndexId).WithField("collision", IntIp(existing.hostId)).
|
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp).
|
||||||
Error("Failed to add HostInfo due to localIndex collision")
|
Error("Failed to add HostInfo due to localIndex collision")
|
||||||
return
|
return
|
||||||
case ErrExistingHandshake:
|
case ErrExistingHandshake:
|
||||||
// We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish
|
// We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish
|
||||||
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
|
@ -271,7 +274,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
default:
|
default:
|
||||||
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
|
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
|
||||||
// And we forget to update it here
|
// And we forget to update it here
|
||||||
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
|
@ -283,10 +286,10 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do the send
|
// Do the send
|
||||||
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
|
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
||||||
err = f.outside.WriteTo(msg, addr)
|
err = f.outside.WriteTo(msg, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
|
@ -294,7 +297,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
WithError(err).Error("Failed to send handshake")
|
WithError(err).Error("Failed to send handshake")
|
||||||
} else {
|
} else {
|
||||||
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
|
@ -309,7 +312,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool {
|
func ixHandshakeStage2(f *Interface, addr *udp.Addr, hostinfo *HostInfo, packet []byte, h *header.H) bool {
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
// Nothing here to tear down, got a bogus stage 2 packet
|
// Nothing here to tear down, got a bogus stage 2 packet
|
||||||
return true
|
return true
|
||||||
|
@ -318,14 +321,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
hostinfo.Lock()
|
hostinfo.Lock()
|
||||||
defer hostinfo.Unlock()
|
defer hostinfo.Unlock()
|
||||||
|
|
||||||
if !f.lightHouse.remoteAllowList.Allow(hostinfo.hostId, addr.IP) {
|
if !f.lightHouse.remoteAllowList.Allow(hostinfo.vpnIp, addr.IP) {
|
||||||
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
ci := hostinfo.ConnectionState
|
ci := hostinfo.ConnectionState
|
||||||
if ci.ready {
|
if ci.ready {
|
||||||
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
||||||
Info("Handshake is already complete")
|
Info("Handshake is already complete")
|
||||||
|
|
||||||
|
@ -333,16 +336,16 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
if hostinfo.SetRemoteIfPreferred(f.hostMap, addr) {
|
if hostinfo.SetRemoteIfPreferred(f.hostMap, addr) {
|
||||||
// Send a test packet to ensure the other side has also switched to
|
// Send a test packet to ensure the other side has also switched to
|
||||||
// the preferred remote
|
// the preferred remote
|
||||||
f.SendMessageToVpnIp(test, testRequest, hostinfo.hostId, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
f.SendMessageToVpnIp(header.Test, header.TestRequest, hostinfo.vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||||
}
|
}
|
||||||
|
|
||||||
// We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets
|
// We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
|
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
||||||
Error("Failed to call noise.ReadMessage")
|
Error("Failed to call noise.ReadMessage")
|
||||||
|
|
||||||
|
@ -351,7 +354,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
// near future
|
// near future
|
||||||
return false
|
return false
|
||||||
} else if dKey == nil || eKey == nil {
|
} else if dKey == nil || eKey == nil {
|
||||||
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
Error("Noise did not arrive at a key")
|
Error("Noise did not arrive at a key")
|
||||||
|
|
||||||
|
@ -363,7 +366,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
hs := &NebulaHandshake{}
|
hs := &NebulaHandshake{}
|
||||||
err = proto.Unmarshal(msg, hs)
|
err = proto.Unmarshal(msg, hs)
|
||||||
if err != nil || hs.Details == nil {
|
if err != nil || hs.Details == nil {
|
||||||
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
||||||
|
|
||||||
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
|
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
|
||||||
|
@ -372,7 +375,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
|
|
||||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
|
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
Error("Invalid certificate from host")
|
Error("Invalid certificate from host")
|
||||||
|
|
||||||
|
@ -380,14 +383,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
|
vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
|
||||||
certName := remoteCert.Details.Name
|
certName := remoteCert.Details.Name
|
||||||
fingerprint, _ := remoteCert.Sha256Sum()
|
fingerprint, _ := remoteCert.Sha256Sum()
|
||||||
issuer := remoteCert.Details.Issuer
|
issuer := remoteCert.Details.Issuer
|
||||||
|
|
||||||
// Ensure the right host responded
|
// Ensure the right host responded
|
||||||
if vpnIP != hostinfo.hostId {
|
if vpnIp != hostinfo.vpnIp {
|
||||||
f.l.WithField("intendedVpnIp", IntIp(hostinfo.hostId)).WithField("haveVpnIp", IntIp(vpnIP)).
|
f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp).
|
||||||
WithField("udpAddr", addr).WithField("certName", certName).
|
WithField("udpAddr", addr).WithField("certName", certName).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
Info("Incorrect host responded to handshake")
|
Info("Incorrect host responded to handshake")
|
||||||
|
@ -397,7 +400,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
|
|
||||||
// Create a new hostinfo/handshake for the intended vpn ip
|
// Create a new hostinfo/handshake for the intended vpn ip
|
||||||
//TODO: this adds it to the timer wheel in a way that aggressively retries
|
//TODO: this adds it to the timer wheel in a way that aggressively retries
|
||||||
newHostInfo := f.getOrHandshake(hostinfo.hostId)
|
newHostInfo := f.getOrHandshake(hostinfo.vpnIp)
|
||||||
newHostInfo.Lock()
|
newHostInfo.Lock()
|
||||||
|
|
||||||
// Block the current used address
|
// Block the current used address
|
||||||
|
@ -405,9 +408,9 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
newHostInfo.remotes.BlockRemote(addr)
|
newHostInfo.remotes.BlockRemote(addr)
|
||||||
|
|
||||||
// Get the correct remote list for the host we did handshake with
|
// Get the correct remote list for the host we did handshake with
|
||||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnIP)
|
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
|
||||||
|
|
||||||
f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", IntIp(vpnIP)).
|
f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
|
||||||
WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
|
WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
|
||||||
Info("Blocked addresses for handshakes")
|
Info("Blocked addresses for handshakes")
|
||||||
|
|
||||||
|
@ -418,7 +421,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
hostinfo.ConnectionState.queueLock.Unlock()
|
hostinfo.ConnectionState.queueLock.Unlock()
|
||||||
|
|
||||||
// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
|
// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
|
||||||
hostinfo.hostId = vpnIP
|
hostinfo.vpnIp = vpnIp
|
||||||
f.sendCloseTunnel(hostinfo)
|
f.sendCloseTunnel(hostinfo)
|
||||||
newHostInfo.Unlock()
|
newHostInfo.Unlock()
|
||||||
|
|
||||||
|
@ -429,7 +432,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||||
ci.window.Update(f.l, 2)
|
ci.window.Update(f.l, 2)
|
||||||
|
|
||||||
duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
|
duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
|
||||||
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
|
|
|
@ -11,6 +11,9 @@ import (
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -39,7 +42,7 @@ type HandshakeManager struct {
|
||||||
pendingHostMap *HostMap
|
pendingHostMap *HostMap
|
||||||
mainHostMap *HostMap
|
mainHostMap *HostMap
|
||||||
lightHouse *LightHouse
|
lightHouse *LightHouse
|
||||||
outside *udpConn
|
outside *udp.Conn
|
||||||
config HandshakeConfig
|
config HandshakeConfig
|
||||||
OutboundHandshakeTimer *SystemTimerWheel
|
OutboundHandshakeTimer *SystemTimerWheel
|
||||||
messageMetrics *MessageMetrics
|
messageMetrics *MessageMetrics
|
||||||
|
@ -47,18 +50,18 @@ type HandshakeManager struct {
|
||||||
metricTimedOut metrics.Counter
|
metricTimedOut metrics.Counter
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
|
||||||
// can be used to trigger outbound handshake for the given vpnIP
|
// can be used to trigger outbound handshake for the given vpnIp
|
||||||
trigger chan uint32
|
trigger chan iputil.VpnIp
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
|
func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udp.Conn, config HandshakeConfig) *HandshakeManager {
|
||||||
return &HandshakeManager{
|
return &HandshakeManager{
|
||||||
pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges),
|
pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges),
|
||||||
mainHostMap: mainHostMap,
|
mainHostMap: mainHostMap,
|
||||||
lightHouse: lightHouse,
|
lightHouse: lightHouse,
|
||||||
outside: outside,
|
outside: outside,
|
||||||
config: config,
|
config: config,
|
||||||
trigger: make(chan uint32, config.triggerBuffer),
|
trigger: make(chan iputil.VpnIp, config.triggerBuffer),
|
||||||
OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
|
OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
|
||||||
messageMetrics: config.messageMetrics,
|
messageMetrics: config.messageMetrics,
|
||||||
metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
|
metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
|
||||||
|
@ -67,7 +70,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
|
func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) {
|
||||||
clockSource := time.NewTicker(c.config.tryInterval)
|
clockSource := time.NewTicker(c.config.tryInterval)
|
||||||
defer clockSource.Stop()
|
defer clockSource.Stop()
|
||||||
|
|
||||||
|
@ -76,7 +79,7 @@ func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case vpnIP := <-c.trigger:
|
case vpnIP := <-c.trigger:
|
||||||
c.l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
|
c.l.WithField("vpnIp", vpnIP).Debug("HandshakeManager: triggered")
|
||||||
c.handleOutbound(vpnIP, f, true)
|
c.handleOutbound(vpnIP, f, true)
|
||||||
case now := <-clockSource.C:
|
case now := <-clockSource.C:
|
||||||
c.NextOutboundHandshakeTimerTick(now, f)
|
c.NextOutboundHandshakeTimerTick(now, f)
|
||||||
|
@ -84,20 +87,20 @@ func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
|
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) {
|
||||||
c.OutboundHandshakeTimer.advance(now)
|
c.OutboundHandshakeTimer.advance(now)
|
||||||
for {
|
for {
|
||||||
ep := c.OutboundHandshakeTimer.Purge()
|
ep := c.OutboundHandshakeTimer.Purge()
|
||||||
if ep == nil {
|
if ep == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
vpnIP := ep.(uint32)
|
vpnIp := ep.(iputil.VpnIp)
|
||||||
c.handleOutbound(vpnIP, f, false)
|
c.handleOutbound(vpnIp, f, false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseTriggered bool) {
|
func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) {
|
||||||
hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP)
|
hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -115,7 +118,7 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
|
||||||
if !hostinfo.HandshakeReady {
|
if !hostinfo.HandshakeReady {
|
||||||
// There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly
|
// There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly
|
||||||
// Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState
|
// Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState
|
||||||
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
|
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,21 +146,21 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
|
||||||
// Get a remotes object if we don't already have one.
|
// Get a remotes object if we don't already have one.
|
||||||
// This is mainly to protect us as this should never be the case
|
// This is mainly to protect us as this should never be the case
|
||||||
if hostinfo.remotes == nil {
|
if hostinfo.remotes == nil {
|
||||||
hostinfo.remotes = c.lightHouse.QueryCache(vpnIP)
|
hostinfo.remotes = c.lightHouse.QueryCache(vpnIp)
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: this will generate a load of queries for hosts with only 1 ip (i'm not using a lighthouse, static mapped)
|
//TODO: this will generate a load of queries for hosts with only 1 ip (i'm not using a lighthouse, static mapped)
|
||||||
if hostinfo.remotes.Len(c.pendingHostMap.preferredRanges) <= 1 {
|
if hostinfo.remotes.Len(c.pendingHostMap.preferredRanges) <= 1 {
|
||||||
// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
|
// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
|
||||||
// Our vpnIP here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
|
// Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
|
||||||
// the learned public ip for them. Query again to short circuit the promotion counter
|
// the learned public ip for them. Query again to short circuit the promotion counter
|
||||||
c.lightHouse.QueryServer(vpnIP, f)
|
c.lightHouse.QueryServer(vpnIp, f)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send a the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
|
// Send a the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
|
||||||
var sentTo []*udpAddr
|
var sentTo []*udp.Addr
|
||||||
hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udpAddr, _ bool) {
|
hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
|
||||||
c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
|
c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
|
||||||
err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
|
err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(c.l).WithField("udpAddr", addr).
|
hostinfo.logger(c.l).WithField("udpAddr", addr).
|
||||||
|
@ -184,16 +187,16 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
|
||||||
// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
|
// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
|
||||||
if !lighthouseTriggered {
|
if !lighthouseTriggered {
|
||||||
//TODO: feel like we dupe handshake real fast in a tight loop, why?
|
//TODO: feel like we dupe handshake real fast in a tight loop, why?
|
||||||
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
|
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
|
func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
|
||||||
hostinfo := c.pendingHostMap.AddVpnIP(vpnIP)
|
hostinfo := c.pendingHostMap.AddVpnIp(vpnIp)
|
||||||
// We lock here and use an array to insert items to prevent locking the
|
// We lock here and use an array to insert items to prevent locking the
|
||||||
// main receive thread for very long by waiting to add items to the pending map
|
// main receive thread for very long by waiting to add items to the pending map
|
||||||
//TODO: what lock?
|
//TODO: what lock?
|
||||||
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval)
|
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
|
||||||
c.metricInitiated.Inc(1)
|
c.metricInitiated.Inc(1)
|
||||||
|
|
||||||
return hostinfo
|
return hostinfo
|
||||||
|
@ -208,12 +211,12 @@ var (
|
||||||
|
|
||||||
// CheckAndComplete checks for any conflicts in the main and pending hostmap
|
// CheckAndComplete checks for any conflicts in the main and pending hostmap
|
||||||
// before adding hostinfo to main. If err is nil, it was added. Otherwise err will be:
|
// before adding hostinfo to main. If err is nil, it was added. Otherwise err will be:
|
||||||
|
//
|
||||||
// ErrAlreadySeen if we already have an entry in the hostmap that has seen the
|
// ErrAlreadySeen if we already have an entry in the hostmap that has seen the
|
||||||
// exact same handshake packet
|
// exact same handshake packet
|
||||||
//
|
//
|
||||||
// ErrExistingHostInfo if we already have an entry in the hostmap for this
|
// ErrExistingHostInfo if we already have an entry in the hostmap for this
|
||||||
// VpnIP and the new handshake was older than the one we currently have
|
// VpnIp and the new handshake was older than the one we currently have
|
||||||
//
|
//
|
||||||
// ErrLocalIndexCollision if we already have an entry in the main or pending
|
// ErrLocalIndexCollision if we already have an entry in the main or pending
|
||||||
// hostmap for the hostinfo.localIndexId.
|
// hostmap for the hostinfo.localIndexId.
|
||||||
|
@ -224,7 +227,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
|
||||||
defer c.mainHostMap.Unlock()
|
defer c.mainHostMap.Unlock()
|
||||||
|
|
||||||
// Check if we already have a tunnel with this vpn ip
|
// Check if we already have a tunnel with this vpn ip
|
||||||
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
|
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
|
||||||
if found && existingHostInfo != nil {
|
if found && existingHostInfo != nil {
|
||||||
// Is it just a delayed handshake packet?
|
// Is it just a delayed handshake packet?
|
||||||
if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
|
if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
|
||||||
|
@ -252,16 +255,16 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
|
||||||
}
|
}
|
||||||
|
|
||||||
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
|
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
|
||||||
if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId {
|
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp {
|
||||||
// We have a collision, but this can happen since we can't control
|
// We have a collision, but this can happen since we can't control
|
||||||
// the remote ID. Just log about the situation as a note.
|
// the remote ID. Just log about the situation as a note.
|
||||||
hostinfo.logger(c.l).
|
hostinfo.logger(c.l).
|
||||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
|
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
|
||||||
Info("New host shadows existing host remoteIndex")
|
Info("New host shadows existing host remoteIndex")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if we are also handshaking with this vpn ip
|
// Check if we are also handshaking with this vpn ip
|
||||||
pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.hostId]
|
pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.vpnIp]
|
||||||
if found && pendingHostInfo != nil {
|
if found && pendingHostInfo != nil {
|
||||||
if !overwrite {
|
if !overwrite {
|
||||||
// We won, let our pending handshake win
|
// We won, let our pending handshake win
|
||||||
|
@ -278,7 +281,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
|
||||||
|
|
||||||
if existingHostInfo != nil {
|
if existingHostInfo != nil {
|
||||||
// We are going to overwrite this entry, so remove the old references
|
// We are going to overwrite this entry, so remove the old references
|
||||||
delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
|
delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp)
|
||||||
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
|
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
|
||||||
delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
|
delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
|
||||||
}
|
}
|
||||||
|
@ -296,10 +299,10 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
|
||||||
c.mainHostMap.Lock()
|
c.mainHostMap.Lock()
|
||||||
defer c.mainHostMap.Unlock()
|
defer c.mainHostMap.Unlock()
|
||||||
|
|
||||||
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
|
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
|
||||||
if found && existingHostInfo != nil {
|
if found && existingHostInfo != nil {
|
||||||
// We are going to overwrite this entry, so remove the old references
|
// We are going to overwrite this entry, so remove the old references
|
||||||
delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
|
delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp)
|
||||||
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
|
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
|
||||||
delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
|
delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
|
||||||
}
|
}
|
||||||
|
@ -309,7 +312,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
|
||||||
// We have a collision, but this can happen since we can't control
|
// We have a collision, but this can happen since we can't control
|
||||||
// the remote ID. Just log about the situation as a note.
|
// the remote ID. Just log about the situation as a note.
|
||||||
hostinfo.logger(c.l).
|
hostinfo.logger(c.l).
|
||||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
|
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
|
||||||
Info("New host shadows existing host remoteIndex")
|
Info("New host shadows existing host remoteIndex")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,25 +5,29 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_NewHandshakeManagerVpnIP(t *testing.T) {
|
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
ip := ip2int(net.ParseIP("172.1.1.2"))
|
ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
|
||||||
preferredRanges := []*net.IPNet{localrange}
|
preferredRanges := []*net.IPNet{localrange}
|
||||||
mw := &mockEncWriter{}
|
mw := &mockEncWriter{}
|
||||||
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
|
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
|
||||||
|
|
||||||
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udp.Conn{}, defaultHandshakeConfig)
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
blah.NextOutboundHandshakeTimerTick(now, mw)
|
blah.NextOutboundHandshakeTimerTick(now, mw)
|
||||||
|
|
||||||
i := blah.AddVpnIP(ip)
|
i := blah.AddVpnIp(ip)
|
||||||
i.remotes = NewRemoteList()
|
i.remotes = NewRemoteList()
|
||||||
i.HandshakeReady = true
|
i.HandshakeReady = true
|
||||||
|
|
||||||
|
@ -50,24 +54,24 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_NewHandshakeManagerTrigger(t *testing.T) {
|
func Test_NewHandshakeManagerTrigger(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
ip := ip2int(net.ParseIP("172.1.1.2"))
|
ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
|
||||||
preferredRanges := []*net.IPNet{localrange}
|
preferredRanges := []*net.IPNet{localrange}
|
||||||
mw := &mockEncWriter{}
|
mw := &mockEncWriter{}
|
||||||
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
|
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
|
||||||
lh := &LightHouse{addrMap: make(map[uint32]*RemoteList), l: l}
|
lh := &LightHouse{addrMap: make(map[iputil.VpnIp]*RemoteList), l: l}
|
||||||
|
|
||||||
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
|
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig)
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
blah.NextOutboundHandshakeTimerTick(now, mw)
|
blah.NextOutboundHandshakeTimerTick(now, mw)
|
||||||
|
|
||||||
assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
|
assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
|
||||||
|
|
||||||
hi := blah.AddVpnIP(ip)
|
hi := blah.AddVpnIp(ip)
|
||||||
hi.HandshakeReady = true
|
hi.HandshakeReady = true
|
||||||
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
|
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
|
||||||
assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet")
|
assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet")
|
||||||
|
@ -80,7 +84,7 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
|
||||||
// Make sure the trigger doesn't double schedule the timer entry
|
// Make sure the trigger doesn't double schedule the timer entry
|
||||||
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
|
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
|
||||||
|
|
||||||
uaddr := NewUDPAddrFromString("10.1.1.1:4242")
|
uaddr := udp.NewAddrFromString("10.1.1.1:4242")
|
||||||
hi.remotes.unlockedPrependV4(ip, NewIp4AndPort(uaddr.IP, uint32(uaddr.Port)))
|
hi.remotes.unlockedPrependV4(ip, NewIp4AndPort(uaddr.IP, uint32(uaddr.Port)))
|
||||||
|
|
||||||
// We now have remotes but only the first trigger should have pushed things forward
|
// We now have remotes but only the first trigger should have pushed things forward
|
||||||
|
@ -103,6 +107,6 @@ func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
|
||||||
type mockEncWriter struct {
|
type mockEncWriter struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *mockEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
|
func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package nebula
|
package header
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
@ -19,82 +19,78 @@ import (
|
||||||
// |-----------------------------------------------------------------------|
|
// |-----------------------------------------------------------------------|
|
||||||
// | payload... |
|
// | payload... |
|
||||||
|
|
||||||
|
type m map[string]interface{}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
Version uint8 = 1
|
Version uint8 = 1
|
||||||
HeaderLen = 16
|
Len = 16
|
||||||
)
|
)
|
||||||
|
|
||||||
type NebulaMessageType uint8
|
type MessageType uint8
|
||||||
type NebulaMessageSubType uint8
|
type MessageSubType uint8
|
||||||
|
|
||||||
const (
|
const (
|
||||||
handshake NebulaMessageType = 0
|
Handshake MessageType = 0
|
||||||
message NebulaMessageType = 1
|
Message MessageType = 1
|
||||||
recvError NebulaMessageType = 2
|
RecvError MessageType = 2
|
||||||
lightHouse NebulaMessageType = 3
|
LightHouse MessageType = 3
|
||||||
test NebulaMessageType = 4
|
Test MessageType = 4
|
||||||
closeTunnel NebulaMessageType = 5
|
CloseTunnel MessageType = 5
|
||||||
|
|
||||||
//TODO These are deprecated as of 06/12/2018 - NB
|
|
||||||
testRemote NebulaMessageType = 6
|
|
||||||
testRemoteReply NebulaMessageType = 7
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var typeMap = map[NebulaMessageType]string{
|
var typeMap = map[MessageType]string{
|
||||||
handshake: "handshake",
|
Handshake: "handshake",
|
||||||
message: "message",
|
Message: "message",
|
||||||
recvError: "recvError",
|
RecvError: "recvError",
|
||||||
lightHouse: "lightHouse",
|
LightHouse: "lightHouse",
|
||||||
test: "test",
|
Test: "test",
|
||||||
closeTunnel: "closeTunnel",
|
CloseTunnel: "closeTunnel",
|
||||||
|
|
||||||
//TODO These are deprecated as of 06/12/2018 - NB
|
|
||||||
testRemote: "testRemote",
|
|
||||||
testRemoteReply: "testRemoteReply",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
testRequest NebulaMessageSubType = 0
|
TestRequest MessageSubType = 0
|
||||||
testReply NebulaMessageSubType = 1
|
TestReply MessageSubType = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
var eHeaderTooShort = errors.New("header is too short")
|
const (
|
||||||
|
HandshakeIXPSK0 MessageSubType = 0
|
||||||
|
HandshakeXXPSK0 MessageSubType = 1
|
||||||
|
)
|
||||||
|
|
||||||
var subTypeTestMap = map[NebulaMessageSubType]string{
|
var ErrHeaderTooShort = errors.New("header is too short")
|
||||||
testRequest: "testRequest",
|
|
||||||
testReply: "testReply",
|
var subTypeTestMap = map[MessageSubType]string{
|
||||||
|
TestRequest: "testRequest",
|
||||||
|
TestReply: "testReply",
|
||||||
}
|
}
|
||||||
|
|
||||||
var subTypeNoneMap = map[NebulaMessageSubType]string{0: "none"}
|
var subTypeNoneMap = map[MessageSubType]string{0: "none"}
|
||||||
|
|
||||||
var subTypeMap = map[NebulaMessageType]*map[NebulaMessageSubType]string{
|
var subTypeMap = map[MessageType]*map[MessageSubType]string{
|
||||||
message: &subTypeNoneMap,
|
Message: &subTypeNoneMap,
|
||||||
recvError: &subTypeNoneMap,
|
RecvError: &subTypeNoneMap,
|
||||||
lightHouse: &subTypeNoneMap,
|
LightHouse: &subTypeNoneMap,
|
||||||
test: &subTypeTestMap,
|
Test: &subTypeTestMap,
|
||||||
closeTunnel: &subTypeNoneMap,
|
CloseTunnel: &subTypeNoneMap,
|
||||||
handshake: {
|
Handshake: {
|
||||||
handshakeIXPSK0: "ix_psk0",
|
HandshakeIXPSK0: "ix_psk0",
|
||||||
},
|
},
|
||||||
//TODO: these are deprecated
|
|
||||||
testRemote: &subTypeNoneMap,
|
|
||||||
testRemoteReply: &subTypeNoneMap,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Header struct {
|
type H struct {
|
||||||
Version uint8
|
Version uint8
|
||||||
Type NebulaMessageType
|
Type MessageType
|
||||||
Subtype NebulaMessageSubType
|
Subtype MessageSubType
|
||||||
Reserved uint16
|
Reserved uint16
|
||||||
RemoteIndex uint32
|
RemoteIndex uint32
|
||||||
MessageCounter uint64
|
MessageCounter uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
// HeaderEncode uses the provided byte array to encode the provided header values into.
|
// Encode uses the provided byte array to encode the provided header values into.
|
||||||
// Byte array must be capped higher than HeaderLen or this will panic
|
// Byte array must be capped higher than HeaderLen or this will panic
|
||||||
func HeaderEncode(b []byte, v uint8, t uint8, st uint8, ri uint32, c uint64) []byte {
|
func Encode(b []byte, v uint8, t MessageType, st MessageSubType, ri uint32, c uint64) []byte {
|
||||||
b = b[:HeaderLen]
|
b = b[:Len]
|
||||||
b[0] = byte(v<<4 | (t & 0x0f))
|
b[0] = v<<4 | byte(t&0x0f)
|
||||||
b[1] = byte(st)
|
b[1] = byte(st)
|
||||||
binary.BigEndian.PutUint16(b[2:4], 0)
|
binary.BigEndian.PutUint16(b[2:4], 0)
|
||||||
binary.BigEndian.PutUint32(b[4:8], ri)
|
binary.BigEndian.PutUint32(b[4:8], ri)
|
||||||
|
@ -103,7 +99,7 @@ func HeaderEncode(b []byte, v uint8, t uint8, st uint8, ri uint32, c uint64) []b
|
||||||
}
|
}
|
||||||
|
|
||||||
// String creates a readable string representation of a header
|
// String creates a readable string representation of a header
|
||||||
func (h *Header) String() string {
|
func (h *H) String() string {
|
||||||
if h == nil {
|
if h == nil {
|
||||||
return "<nil>"
|
return "<nil>"
|
||||||
}
|
}
|
||||||
|
@ -112,7 +108,7 @@ func (h *Header) String() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON creates a json string representation of a header
|
// MarshalJSON creates a json string representation of a header
|
||||||
func (h *Header) MarshalJSON() ([]byte, error) {
|
func (h *H) MarshalJSON() ([]byte, error) {
|
||||||
return json.Marshal(m{
|
return json.Marshal(m{
|
||||||
"version": h.Version,
|
"version": h.Version,
|
||||||
"type": h.TypeName(),
|
"type": h.TypeName(),
|
||||||
|
@ -124,24 +120,24 @@ func (h *Header) MarshalJSON() ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encode turns header into bytes
|
// Encode turns header into bytes
|
||||||
func (h *Header) Encode(b []byte) ([]byte, error) {
|
func (h *H) Encode(b []byte) ([]byte, error) {
|
||||||
if h == nil {
|
if h == nil {
|
||||||
return nil, errors.New("nil header")
|
return nil, errors.New("nil header")
|
||||||
}
|
}
|
||||||
|
|
||||||
return HeaderEncode(b, h.Version, uint8(h.Type), uint8(h.Subtype), h.RemoteIndex, h.MessageCounter), nil
|
return Encode(b, h.Version, h.Type, h.Subtype, h.RemoteIndex, h.MessageCounter), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse is a helper function to parses given bytes into new Header struct
|
// Parse is a helper function to parses given bytes into new Header struct
|
||||||
func (h *Header) Parse(b []byte) error {
|
func (h *H) Parse(b []byte) error {
|
||||||
if len(b) < HeaderLen {
|
if len(b) < Len {
|
||||||
return eHeaderTooShort
|
return ErrHeaderTooShort
|
||||||
}
|
}
|
||||||
// get upper 4 bytes
|
// get upper 4 bytes
|
||||||
h.Version = uint8((b[0] >> 4) & 0x0f)
|
h.Version = uint8((b[0] >> 4) & 0x0f)
|
||||||
// get lower 4 bytes
|
// get lower 4 bytes
|
||||||
h.Type = NebulaMessageType(b[0] & 0x0f)
|
h.Type = MessageType(b[0] & 0x0f)
|
||||||
h.Subtype = NebulaMessageSubType(b[1])
|
h.Subtype = MessageSubType(b[1])
|
||||||
h.Reserved = binary.BigEndian.Uint16(b[2:4])
|
h.Reserved = binary.BigEndian.Uint16(b[2:4])
|
||||||
h.RemoteIndex = binary.BigEndian.Uint32(b[4:8])
|
h.RemoteIndex = binary.BigEndian.Uint32(b[4:8])
|
||||||
h.MessageCounter = binary.BigEndian.Uint64(b[8:16])
|
h.MessageCounter = binary.BigEndian.Uint64(b[8:16])
|
||||||
|
@ -149,12 +145,12 @@ func (h *Header) Parse(b []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TypeName will transform the headers message type into a human string
|
// TypeName will transform the headers message type into a human string
|
||||||
func (h *Header) TypeName() string {
|
func (h *H) TypeName() string {
|
||||||
return TypeName(h.Type)
|
return TypeName(h.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TypeName will transform a nebula message type into a human string
|
// TypeName will transform a nebula message type into a human string
|
||||||
func TypeName(t NebulaMessageType) string {
|
func TypeName(t MessageType) string {
|
||||||
if n, ok := typeMap[t]; ok {
|
if n, ok := typeMap[t]; ok {
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
@ -163,12 +159,12 @@ func TypeName(t NebulaMessageType) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SubTypeName will transform the headers message sub type into a human string
|
// SubTypeName will transform the headers message sub type into a human string
|
||||||
func (h *Header) SubTypeName() string {
|
func (h *H) SubTypeName() string {
|
||||||
return SubTypeName(h.Type, h.Subtype)
|
return SubTypeName(h.Type, h.Subtype)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SubTypeName will transform a nebula message sub type into a human string
|
// SubTypeName will transform a nebula message sub type into a human string
|
||||||
func SubTypeName(t NebulaMessageType, s NebulaMessageSubType) string {
|
func SubTypeName(t MessageType, s MessageSubType) string {
|
||||||
if n, ok := subTypeMap[t]; ok {
|
if n, ok := subTypeMap[t]; ok {
|
||||||
if x, ok := (*n)[s]; ok {
|
if x, ok := (*n)[s]; ok {
|
||||||
return x
|
return x
|
||||||
|
@ -179,8 +175,8 @@ func SubTypeName(t NebulaMessageType, s NebulaMessageSubType) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHeader turns bytes into a header
|
// NewHeader turns bytes into a header
|
||||||
func NewHeader(b []byte) (*Header, error) {
|
func NewHeader(b []byte) (*H, error) {
|
||||||
h := new(Header)
|
h := new(H)
|
||||||
if err := h.Parse(b); err != nil {
|
if err := h.Parse(b); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
|
@ -0,0 +1,115 @@
|
||||||
|
package header
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type headerTest struct {
|
||||||
|
expectedBytes []byte
|
||||||
|
*H
|
||||||
|
}
|
||||||
|
|
||||||
|
// 0001 0010 00010010
|
||||||
|
var headerBigEndianTests = []headerTest{{
|
||||||
|
expectedBytes: []byte{0x54, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9},
|
||||||
|
// 1010 0000
|
||||||
|
H: &H{
|
||||||
|
// 1111 1+2+4+8 = 15
|
||||||
|
Version: 5,
|
||||||
|
Type: 4,
|
||||||
|
Subtype: 0,
|
||||||
|
Reserved: 0,
|
||||||
|
RemoteIndex: 10,
|
||||||
|
MessageCounter: 9,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncode(t *testing.T) {
|
||||||
|
for _, tt := range headerBigEndianTests {
|
||||||
|
b, err := tt.Encode(make([]byte, Len))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedBytes, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParse(t *testing.T) {
|
||||||
|
for _, tt := range headerBigEndianTests {
|
||||||
|
b := tt.expectedBytes
|
||||||
|
parsedHeader := &H{}
|
||||||
|
parsedHeader.Parse(b)
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(tt.H, parsedHeader) {
|
||||||
|
t.Fatalf("got %#v; want %#v", parsedHeader, tt.H)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypeName(t *testing.T) {
|
||||||
|
assert.Equal(t, "test", TypeName(Test))
|
||||||
|
assert.Equal(t, "test", (&H{Type: Test}).TypeName())
|
||||||
|
|
||||||
|
assert.Equal(t, "unknown", TypeName(99))
|
||||||
|
assert.Equal(t, "unknown", (&H{Type: 99}).TypeName())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubTypeName(t *testing.T) {
|
||||||
|
assert.Equal(t, "testRequest", SubTypeName(Test, TestRequest))
|
||||||
|
assert.Equal(t, "testRequest", (&H{Type: Test, Subtype: TestRequest}).SubTypeName())
|
||||||
|
|
||||||
|
assert.Equal(t, "unknown", SubTypeName(99, TestRequest))
|
||||||
|
assert.Equal(t, "unknown", (&H{Type: 99, Subtype: TestRequest}).SubTypeName())
|
||||||
|
|
||||||
|
assert.Equal(t, "unknown", SubTypeName(Test, 99))
|
||||||
|
assert.Equal(t, "unknown", (&H{Type: Test, Subtype: 99}).SubTypeName())
|
||||||
|
|
||||||
|
assert.Equal(t, "none", SubTypeName(Message, 0))
|
||||||
|
assert.Equal(t, "none", (&H{Type: Message, Subtype: 0}).SubTypeName())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypeMap(t *testing.T) {
|
||||||
|
// Force people to document this stuff
|
||||||
|
assert.Equal(t, map[MessageType]string{
|
||||||
|
Handshake: "handshake",
|
||||||
|
Message: "message",
|
||||||
|
RecvError: "recvError",
|
||||||
|
LightHouse: "lightHouse",
|
||||||
|
Test: "test",
|
||||||
|
CloseTunnel: "closeTunnel",
|
||||||
|
}, typeMap)
|
||||||
|
|
||||||
|
assert.Equal(t, map[MessageType]*map[MessageSubType]string{
|
||||||
|
Message: &subTypeNoneMap,
|
||||||
|
RecvError: &subTypeNoneMap,
|
||||||
|
LightHouse: &subTypeNoneMap,
|
||||||
|
Test: &subTypeTestMap,
|
||||||
|
CloseTunnel: &subTypeNoneMap,
|
||||||
|
Handshake: {
|
||||||
|
HandshakeIXPSK0: "ix_psk0",
|
||||||
|
},
|
||||||
|
}, subTypeMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeader_String(t *testing.T) {
|
||||||
|
assert.Equal(
|
||||||
|
t,
|
||||||
|
"ver=100 type=test subtype=testRequest reserved=0x63 remoteindex=98 messagecounter=97",
|
||||||
|
(&H{100, Test, TestRequest, 99, 98, 97}).String(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeader_MarshalJSON(t *testing.T) {
|
||||||
|
b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(
|
||||||
|
t,
|
||||||
|
"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",
|
||||||
|
string(b),
|
||||||
|
)
|
||||||
|
}
|
119
header_test.go
119
header_test.go
|
@ -1,119 +0,0 @@
|
||||||
package nebula
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
type headerTest struct {
|
|
||||||
expectedBytes []byte
|
|
||||||
*Header
|
|
||||||
}
|
|
||||||
|
|
||||||
// 0001 0010 00010010
|
|
||||||
var headerBigEndianTests = []headerTest{{
|
|
||||||
expectedBytes: []byte{0x54, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9},
|
|
||||||
// 1010 0000
|
|
||||||
Header: &Header{
|
|
||||||
// 1111 1+2+4+8 = 15
|
|
||||||
Version: 5,
|
|
||||||
Type: 4,
|
|
||||||
Subtype: 0,
|
|
||||||
Reserved: 0,
|
|
||||||
RemoteIndex: 10,
|
|
||||||
MessageCounter: 9,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEncode(t *testing.T) {
|
|
||||||
for _, tt := range headerBigEndianTests {
|
|
||||||
b, err := tt.Encode(make([]byte, HeaderLen))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, tt.expectedBytes, b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParse(t *testing.T) {
|
|
||||||
for _, tt := range headerBigEndianTests {
|
|
||||||
b := tt.expectedBytes
|
|
||||||
parsedHeader := &Header{}
|
|
||||||
parsedHeader.Parse(b)
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(tt.Header, parsedHeader) {
|
|
||||||
t.Fatalf("got %#v; want %#v", parsedHeader, tt.Header)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTypeName(t *testing.T) {
|
|
||||||
assert.Equal(t, "test", TypeName(test))
|
|
||||||
assert.Equal(t, "test", (&Header{Type: test}).TypeName())
|
|
||||||
|
|
||||||
assert.Equal(t, "unknown", TypeName(99))
|
|
||||||
assert.Equal(t, "unknown", (&Header{Type: 99}).TypeName())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSubTypeName(t *testing.T) {
|
|
||||||
assert.Equal(t, "testRequest", SubTypeName(test, testRequest))
|
|
||||||
assert.Equal(t, "testRequest", (&Header{Type: test, Subtype: testRequest}).SubTypeName())
|
|
||||||
|
|
||||||
assert.Equal(t, "unknown", SubTypeName(99, testRequest))
|
|
||||||
assert.Equal(t, "unknown", (&Header{Type: 99, Subtype: testRequest}).SubTypeName())
|
|
||||||
|
|
||||||
assert.Equal(t, "unknown", SubTypeName(test, 99))
|
|
||||||
assert.Equal(t, "unknown", (&Header{Type: test, Subtype: 99}).SubTypeName())
|
|
||||||
|
|
||||||
assert.Equal(t, "none", SubTypeName(message, 0))
|
|
||||||
assert.Equal(t, "none", (&Header{Type: message, Subtype: 0}).SubTypeName())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTypeMap(t *testing.T) {
|
|
||||||
// Force people to document this stuff
|
|
||||||
assert.Equal(t, map[NebulaMessageType]string{
|
|
||||||
handshake: "handshake",
|
|
||||||
message: "message",
|
|
||||||
recvError: "recvError",
|
|
||||||
lightHouse: "lightHouse",
|
|
||||||
test: "test",
|
|
||||||
closeTunnel: "closeTunnel",
|
|
||||||
testRemote: "testRemote",
|
|
||||||
testRemoteReply: "testRemoteReply",
|
|
||||||
}, typeMap)
|
|
||||||
|
|
||||||
assert.Equal(t, map[NebulaMessageType]*map[NebulaMessageSubType]string{
|
|
||||||
message: &subTypeNoneMap,
|
|
||||||
recvError: &subTypeNoneMap,
|
|
||||||
lightHouse: &subTypeNoneMap,
|
|
||||||
test: &subTypeTestMap,
|
|
||||||
closeTunnel: &subTypeNoneMap,
|
|
||||||
handshake: {
|
|
||||||
handshakeIXPSK0: "ix_psk0",
|
|
||||||
},
|
|
||||||
testRemote: &subTypeNoneMap,
|
|
||||||
testRemoteReply: &subTypeNoneMap,
|
|
||||||
}, subTypeMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHeader_String(t *testing.T) {
|
|
||||||
assert.Equal(
|
|
||||||
t,
|
|
||||||
"ver=100 type=test subtype=testRequest reserved=0x63 remoteindex=98 messagecounter=97",
|
|
||||||
(&Header{100, test, testRequest, 99, 98, 97}).String(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHeader_MarshalJSON(t *testing.T) {
|
|
||||||
b, err := (&Header{100, test, testRequest, 99, 98, 97}).MarshalJSON()
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(
|
|
||||||
t,
|
|
||||||
"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",
|
|
||||||
string(b),
|
|
||||||
)
|
|
||||||
}
|
|
157
hostmap.go
157
hostmap.go
|
@ -12,6 +12,10 @@ import (
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/slackhq/nebula/cidr"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
//const ProbeLen = 100
|
//const ProbeLen = 100
|
||||||
|
@ -28,10 +32,10 @@ type HostMap struct {
|
||||||
name string
|
name string
|
||||||
Indexes map[uint32]*HostInfo
|
Indexes map[uint32]*HostInfo
|
||||||
RemoteIndexes map[uint32]*HostInfo
|
RemoteIndexes map[uint32]*HostInfo
|
||||||
Hosts map[uint32]*HostInfo
|
Hosts map[iputil.VpnIp]*HostInfo
|
||||||
preferredRanges []*net.IPNet
|
preferredRanges []*net.IPNet
|
||||||
vpnCIDR *net.IPNet
|
vpnCIDR *net.IPNet
|
||||||
unsafeRoutes *CIDRTree
|
unsafeRoutes *cidr.Tree4
|
||||||
metricsEnabled bool
|
metricsEnabled bool
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
@ -39,7 +43,7 @@ type HostMap struct {
|
||||||
type HostInfo struct {
|
type HostInfo struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
|
|
||||||
remote *udpAddr
|
remote *udp.Addr
|
||||||
remotes *RemoteList
|
remotes *RemoteList
|
||||||
promoteCounter uint32
|
promoteCounter uint32
|
||||||
ConnectionState *ConnectionState
|
ConnectionState *ConnectionState
|
||||||
|
@ -51,9 +55,9 @@ type HostInfo struct {
|
||||||
packetStore []*cachedPacket //todo: this is other handshake manager entry
|
packetStore []*cachedPacket //todo: this is other handshake manager entry
|
||||||
remoteIndexId uint32
|
remoteIndexId uint32
|
||||||
localIndexId uint32
|
localIndexId uint32
|
||||||
hostId uint32
|
vpnIp iputil.VpnIp
|
||||||
recvError int
|
recvError int
|
||||||
remoteCidr *CIDRTree
|
remoteCidr *cidr.Tree4
|
||||||
|
|
||||||
// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
|
// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
|
||||||
// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
|
// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
|
||||||
|
@ -66,17 +70,17 @@ type HostInfo struct {
|
||||||
lastHandshakeTime uint64
|
lastHandshakeTime uint64
|
||||||
|
|
||||||
lastRoam time.Time
|
lastRoam time.Time
|
||||||
lastRoamRemote *udpAddr
|
lastRoamRemote *udp.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
type cachedPacket struct {
|
type cachedPacket struct {
|
||||||
messageType NebulaMessageType
|
messageType header.MessageType
|
||||||
messageSubType NebulaMessageSubType
|
messageSubType header.MessageSubType
|
||||||
callback packetCallback
|
callback packetCallback
|
||||||
packet []byte
|
packet []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type packetCallback func(t NebulaMessageType, st NebulaMessageSubType, h *HostInfo, p, nb, out []byte)
|
type packetCallback func(t header.MessageType, st header.MessageSubType, h *HostInfo, p, nb, out []byte)
|
||||||
|
|
||||||
type cachedPacketMetrics struct {
|
type cachedPacketMetrics struct {
|
||||||
sent metrics.Counter
|
sent metrics.Counter
|
||||||
|
@ -84,7 +88,7 @@ type cachedPacketMetrics struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
|
func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
|
||||||
h := map[uint32]*HostInfo{}
|
h := map[iputil.VpnIp]*HostInfo{}
|
||||||
i := map[uint32]*HostInfo{}
|
i := map[uint32]*HostInfo{}
|
||||||
r := map[uint32]*HostInfo{}
|
r := map[uint32]*HostInfo{}
|
||||||
m := HostMap{
|
m := HostMap{
|
||||||
|
@ -94,7 +98,7 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang
|
||||||
Hosts: h,
|
Hosts: h,
|
||||||
preferredRanges: preferredRanges,
|
preferredRanges: preferredRanges,
|
||||||
vpnCIDR: vpnCIDR,
|
vpnCIDR: vpnCIDR,
|
||||||
unsafeRoutes: NewCIDRTree(),
|
unsafeRoutes: cidr.NewTree4(),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
return &m
|
return &m
|
||||||
|
@ -113,9 +117,9 @@ func (hm *HostMap) EmitStats(name string) {
|
||||||
metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen))
|
metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) {
|
func (hm *HostMap) GetIndexByVpnIp(vpnIp iputil.VpnIp) (uint32, error) {
|
||||||
hm.RLock()
|
hm.RLock()
|
||||||
if i, ok := hm.Hosts[vpnIP]; ok {
|
if i, ok := hm.Hosts[vpnIp]; ok {
|
||||||
index := i.localIndexId
|
index := i.localIndexId
|
||||||
hm.RUnlock()
|
hm.RUnlock()
|
||||||
return index, nil
|
return index, nil
|
||||||
|
@ -124,43 +128,43 @@ func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) {
|
||||||
return 0, errors.New("vpn IP not found")
|
return 0, errors.New("vpn IP not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) Add(ip uint32, hostinfo *HostInfo) {
|
func (hm *HostMap) Add(ip iputil.VpnIp, hostinfo *HostInfo) {
|
||||||
hm.Lock()
|
hm.Lock()
|
||||||
hm.Hosts[ip] = hostinfo
|
hm.Hosts[ip] = hostinfo
|
||||||
hm.Unlock()
|
hm.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) AddVpnIP(vpnIP uint32) *HostInfo {
|
func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
|
||||||
h := &HostInfo{}
|
h := &HostInfo{}
|
||||||
hm.RLock()
|
hm.RLock()
|
||||||
if _, ok := hm.Hosts[vpnIP]; !ok {
|
if _, ok := hm.Hosts[vpnIp]; !ok {
|
||||||
hm.RUnlock()
|
hm.RUnlock()
|
||||||
h = &HostInfo{
|
h = &HostInfo{
|
||||||
promoteCounter: 0,
|
promoteCounter: 0,
|
||||||
hostId: vpnIP,
|
vpnIp: vpnIp,
|
||||||
HandshakePacket: make(map[uint8][]byte, 0),
|
HandshakePacket: make(map[uint8][]byte, 0),
|
||||||
}
|
}
|
||||||
hm.Lock()
|
hm.Lock()
|
||||||
hm.Hosts[vpnIP] = h
|
hm.Hosts[vpnIp] = h
|
||||||
hm.Unlock()
|
hm.Unlock()
|
||||||
return h
|
return h
|
||||||
} else {
|
} else {
|
||||||
h = hm.Hosts[vpnIP]
|
h = hm.Hosts[vpnIp]
|
||||||
hm.RUnlock()
|
hm.RUnlock()
|
||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) DeleteVpnIP(vpnIP uint32) {
|
func (hm *HostMap) DeleteVpnIp(vpnIp iputil.VpnIp) {
|
||||||
hm.Lock()
|
hm.Lock()
|
||||||
delete(hm.Hosts, vpnIP)
|
delete(hm.Hosts, vpnIp)
|
||||||
if len(hm.Hosts) == 0 {
|
if len(hm.Hosts) == 0 {
|
||||||
hm.Hosts = map[uint32]*HostInfo{}
|
hm.Hosts = map[iputil.VpnIp]*HostInfo{}
|
||||||
}
|
}
|
||||||
hm.Unlock()
|
hm.Unlock()
|
||||||
|
|
||||||
if hm.l.Level >= logrus.DebugLevel {
|
if hm.l.Level >= logrus.DebugLevel {
|
||||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}).
|
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": vpnIp, "mapTotalSize": len(hm.Hosts)}).
|
||||||
Debug("Hostmap vpnIp deleted")
|
Debug("Hostmap vpnIp deleted")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -174,22 +178,22 @@ func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
|
||||||
|
|
||||||
if hm.l.Level > logrus.DebugLevel {
|
if hm.l.Level > logrus.DebugLevel {
|
||||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
|
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
|
||||||
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
|
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": h.vpnIp}}).
|
||||||
Debug("Hostmap remoteIndex added")
|
Debug("Hostmap remoteIndex added")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) {
|
func (hm *HostMap) AddVpnIpHostInfo(vpnIp iputil.VpnIp, h *HostInfo) {
|
||||||
hm.Lock()
|
hm.Lock()
|
||||||
h.hostId = vpnIP
|
h.vpnIp = vpnIp
|
||||||
hm.Hosts[vpnIP] = h
|
hm.Hosts[vpnIp] = h
|
||||||
hm.Indexes[h.localIndexId] = h
|
hm.Indexes[h.localIndexId] = h
|
||||||
hm.RemoteIndexes[h.remoteIndexId] = h
|
hm.RemoteIndexes[h.remoteIndexId] = h
|
||||||
hm.Unlock()
|
hm.Unlock()
|
||||||
|
|
||||||
if hm.l.Level > logrus.DebugLevel {
|
if hm.l.Level > logrus.DebugLevel {
|
||||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts),
|
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": vpnIp, "mapTotalSize": len(hm.Hosts),
|
||||||
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
|
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "vpnIp": h.vpnIp}}).
|
||||||
Debug("Hostmap vpnIp added")
|
Debug("Hostmap vpnIp added")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -204,9 +208,9 @@ func (hm *HostMap) DeleteIndex(index uint32) {
|
||||||
|
|
||||||
// Check if we have an entry under hostId that matches the same hostinfo
|
// Check if we have an entry under hostId that matches the same hostinfo
|
||||||
// instance. Clean it up as well if we do.
|
// instance. Clean it up as well if we do.
|
||||||
hostinfo2, ok := hm.Hosts[hostinfo.hostId]
|
hostinfo2, ok := hm.Hosts[hostinfo.vpnIp]
|
||||||
if ok && hostinfo2 == hostinfo {
|
if ok && hostinfo2 == hostinfo {
|
||||||
delete(hm.Hosts, hostinfo.hostId)
|
delete(hm.Hosts, hostinfo.vpnIp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
hm.Unlock()
|
hm.Unlock()
|
||||||
|
@ -228,9 +232,9 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) {
|
||||||
// Check if we have an entry under hostId that matches the same hostinfo
|
// Check if we have an entry under hostId that matches the same hostinfo
|
||||||
// instance. Clean it up as well if we do (they might not match in pendingHostmap)
|
// instance. Clean it up as well if we do (they might not match in pendingHostmap)
|
||||||
var hostinfo2 *HostInfo
|
var hostinfo2 *HostInfo
|
||||||
hostinfo2, ok = hm.Hosts[hostinfo.hostId]
|
hostinfo2, ok = hm.Hosts[hostinfo.vpnIp]
|
||||||
if ok && hostinfo2 == hostinfo {
|
if ok && hostinfo2 == hostinfo {
|
||||||
delete(hm.Hosts, hostinfo.hostId)
|
delete(hm.Hosts, hostinfo.vpnIp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
hm.Unlock()
|
hm.Unlock()
|
||||||
|
@ -251,16 +255,16 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
|
||||||
// Check if this same hostId is in the hostmap with a different instance.
|
// Check if this same hostId is in the hostmap with a different instance.
|
||||||
// This could happen if we have an entry in the pending hostmap with different
|
// This could happen if we have an entry in the pending hostmap with different
|
||||||
// index values than the one in the main hostmap.
|
// index values than the one in the main hostmap.
|
||||||
hostinfo2, ok := hm.Hosts[hostinfo.hostId]
|
hostinfo2, ok := hm.Hosts[hostinfo.vpnIp]
|
||||||
if ok && hostinfo2 != hostinfo {
|
if ok && hostinfo2 != hostinfo {
|
||||||
delete(hm.Hosts, hostinfo2.hostId)
|
delete(hm.Hosts, hostinfo2.vpnIp)
|
||||||
delete(hm.Indexes, hostinfo2.localIndexId)
|
delete(hm.Indexes, hostinfo2.localIndexId)
|
||||||
delete(hm.RemoteIndexes, hostinfo2.remoteIndexId)
|
delete(hm.RemoteIndexes, hostinfo2.remoteIndexId)
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(hm.Hosts, hostinfo.hostId)
|
delete(hm.Hosts, hostinfo.vpnIp)
|
||||||
if len(hm.Hosts) == 0 {
|
if len(hm.Hosts) == 0 {
|
||||||
hm.Hosts = map[uint32]*HostInfo{}
|
hm.Hosts = map[iputil.VpnIp]*HostInfo{}
|
||||||
}
|
}
|
||||||
delete(hm.Indexes, hostinfo.localIndexId)
|
delete(hm.Indexes, hostinfo.localIndexId)
|
||||||
if len(hm.Indexes) == 0 {
|
if len(hm.Indexes) == 0 {
|
||||||
|
@ -273,7 +277,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
|
||||||
|
|
||||||
if hm.l.Level >= logrus.DebugLevel {
|
if hm.l.Level >= logrus.DebugLevel {
|
||||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
|
hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
|
||||||
"vpnIp": IntIp(hostinfo.hostId), "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
||||||
Debug("Hostmap hostInfo deleted")
|
Debug("Hostmap hostInfo deleted")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -301,17 +305,17 @@ func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) QueryVpnIP(vpnIp uint32) (*HostInfo, error) {
|
func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) (*HostInfo, error) {
|
||||||
return hm.queryVpnIP(vpnIp, nil)
|
return hm.queryVpnIp(vpnIp, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PromoteBestQueryVpnIP will attempt to lazily switch to the best remote every
|
// PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every
|
||||||
// `PromoteEvery` calls to this function for a given host.
|
// `PromoteEvery` calls to this function for a given host.
|
||||||
func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostInfo, error) {
|
func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) (*HostInfo, error) {
|
||||||
return hm.queryVpnIP(vpnIp, ifce)
|
return hm.queryVpnIp(vpnIp, ifce)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) {
|
func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*HostInfo, error) {
|
||||||
hm.RLock()
|
hm.RLock()
|
||||||
if h, ok := hm.Hosts[vpnIp]; ok {
|
if h, ok := hm.Hosts[vpnIp]; ok {
|
||||||
hm.RUnlock()
|
hm.RUnlock()
|
||||||
|
@ -327,10 +331,10 @@ func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo,
|
||||||
return nil, errors.New("unable to find host")
|
return nil, errors.New("unable to find host")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 {
|
func (hm *HostMap) queryUnsafeRoute(ip iputil.VpnIp) iputil.VpnIp {
|
||||||
r := hm.unsafeRoutes.MostSpecificContains(ip)
|
r := hm.unsafeRoutes.MostSpecificContains(ip)
|
||||||
if r != nil {
|
if r != nil {
|
||||||
return r.(uint32)
|
return r.(iputil.VpnIp)
|
||||||
} else {
|
} else {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
@ -344,13 +348,13 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
|
||||||
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
|
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
hm.Hosts[hostinfo.hostId] = hostinfo
|
hm.Hosts[hostinfo.vpnIp] = hostinfo
|
||||||
hm.Indexes[hostinfo.localIndexId] = hostinfo
|
hm.Indexes[hostinfo.localIndexId] = hostinfo
|
||||||
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
|
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
|
||||||
|
|
||||||
if hm.l.Level >= logrus.DebugLevel {
|
if hm.l.Level >= logrus.DebugLevel {
|
||||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts),
|
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
|
||||||
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}).
|
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
|
||||||
Debug("Hostmap vpnIp added")
|
Debug("Hostmap vpnIp added")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -370,7 +374,7 @@ func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them
|
// Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them
|
||||||
func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) {
|
func (hm *HostMap) Punchy(ctx context.Context, conn *udp.Conn) {
|
||||||
var metricsTxPunchy metrics.Counter
|
var metricsTxPunchy metrics.Counter
|
||||||
if hm.metricsEnabled {
|
if hm.metricsEnabled {
|
||||||
metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
|
metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
|
||||||
|
@ -406,7 +410,7 @@ func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) {
|
||||||
func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
|
func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
|
||||||
for _, r := range *routes {
|
for _, r := range *routes {
|
||||||
hm.l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
|
hm.l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
|
||||||
hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via))
|
hm.unsafeRoutes.AddCIDR(r.route, iputil.Ip2VpnIp(*r.via))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -431,24 +435,24 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
i.remotes.ForEach(preferredRanges, func(addr *udpAddr, preferred bool) {
|
i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) {
|
||||||
if addr == nil || !preferred {
|
if addr == nil || !preferred {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to send a test packet to that host, this should
|
// Try to send a test packet to that host, this should
|
||||||
// cause it to detect a roaming event and switch remotes
|
// cause it to detect a roaming event and switch remotes
|
||||||
ifce.send(test, testRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
ifce.send(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Re query our lighthouses for new remotes occasionally
|
// Re query our lighthouses for new remotes occasionally
|
||||||
if c%ReQueryEvery == 0 && ifce.lightHouse != nil {
|
if c%ReQueryEvery == 0 && ifce.lightHouse != nil {
|
||||||
ifce.lightHouse.QueryServer(i.hostId, ifce)
|
ifce.lightHouse.QueryServer(i.vpnIp, ifce)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
|
func (i *HostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
|
||||||
//TODO: return the error so we can log with more context
|
//TODO: return the error so we can log with more context
|
||||||
if len(i.packetStore) < 100 {
|
if len(i.packetStore) < 100 {
|
||||||
tempPacket := make([]byte, len(packet))
|
tempPacket := make([]byte, len(packet))
|
||||||
|
@ -510,17 +514,17 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *HostInfo) SetRemote(remote *udpAddr) {
|
func (i *HostInfo) SetRemote(remote *udp.Addr) {
|
||||||
// We copy here because we likely got this remote from a source that reuses the object
|
// We copy here because we likely got this remote from a source that reuses the object
|
||||||
if !i.remote.Equals(remote) {
|
if !i.remote.Equals(remote) {
|
||||||
i.remote = remote.Copy()
|
i.remote = remote.Copy()
|
||||||
i.remotes.LearnRemote(i.hostId, remote.Copy())
|
i.remotes.LearnRemote(i.vpnIp, remote.Copy())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
|
// SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
|
||||||
// time on the HostInfo will also be updated.
|
// time on the HostInfo will also be updated.
|
||||||
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udpAddr) bool {
|
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
|
||||||
currentRemote := i.remote
|
currentRemote := i.remote
|
||||||
if currentRemote == nil {
|
if currentRemote == nil {
|
||||||
i.SetRemote(newRemote)
|
i.SetRemote(newRemote)
|
||||||
|
@ -572,7 +576,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteCidr := NewCIDRTree()
|
remoteCidr := cidr.NewTree4()
|
||||||
for _, ip := range c.Details.Ips {
|
for _, ip := range c.Details.Ips {
|
||||||
remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
|
remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
|
||||||
}
|
}
|
||||||
|
@ -588,8 +592,7 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
||||||
return logrus.NewEntry(l)
|
return logrus.NewEntry(l)
|
||||||
}
|
}
|
||||||
|
|
||||||
li := l.WithField("vpnIp", IntIp(i.hostId))
|
li := l.WithField("vpnIp", i.vpnIp)
|
||||||
|
|
||||||
if connState := i.ConnectionState; connState != nil {
|
if connState := i.ConnectionState; connState != nil {
|
||||||
if peerCert := connState.peerCert; peerCert != nil {
|
if peerCert := connState.peerCert; peerCert != nil {
|
||||||
li = li.WithField("certName", peerCert.Details.Name)
|
li = li.WithField("certName", peerCert.Details.Name)
|
||||||
|
@ -599,38 +602,6 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
||||||
return li
|
return li
|
||||||
}
|
}
|
||||||
|
|
||||||
//########################
|
|
||||||
|
|
||||||
/*
|
|
||||||
|
|
||||||
func (hm *HostMap) DebugRemotes(vpnIp uint32) string {
|
|
||||||
s := "\n"
|
|
||||||
for _, h := range hm.Hosts {
|
|
||||||
for _, r := range h.Remotes {
|
|
||||||
s += fmt.Sprintf("%s : %d ## %v\n", r.addr.IP.String(), r.addr.Port, r.probes)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *HostInfo) HandleReply(addr *net.UDPAddr, counter int) {
|
|
||||||
for _, r := range i.Remotes {
|
|
||||||
if r.addr.IP.Equal(addr.IP) && r.addr.Port == addr.Port {
|
|
||||||
r.ProbeReceived(counter)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *HostInfo) Probes() []*Probe {
|
|
||||||
p := []*Probe{}
|
|
||||||
for _, d := range i.Remotes {
|
|
||||||
p = append(p, &Probe{Addr: d.addr, Counter: d.Probe()})
|
|
||||||
}
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Utility functions
|
// Utility functions
|
||||||
|
|
||||||
func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {
|
func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {
|
||||||
|
|
51
inside.go
51
inside.go
|
@ -5,9 +5,13 @@ import (
|
||||||
|
|
||||||
"github.com/flynn/noise"
|
"github.com/flynn/noise"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/firewall"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) {
|
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
err := newPacket(packet, false, fwPacket)
|
err := newPacket(packet, false, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
||||||
|
@ -32,7 +36,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
|
||||||
hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
|
hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
|
f.l.WithField("vpnIp", fwPacket.RemoteIP).
|
||||||
WithField("fwPacket", fwPacket).
|
WithField("fwPacket", fwPacket).
|
||||||
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
|
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
|
||||||
}
|
}
|
||||||
|
@ -45,7 +49,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
|
||||||
// the packet queue.
|
// the packet queue.
|
||||||
ci.queueLock.Lock()
|
ci.queueLock.Lock()
|
||||||
if !ci.ready {
|
if !ci.ready {
|
||||||
hostinfo.cachePacket(f.l, message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
hostinfo.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
||||||
ci.queueLock.Unlock()
|
ci.queueLock.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -54,7 +58,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
|
||||||
|
|
||||||
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
|
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
|
||||||
if dropReason == nil {
|
if dropReason == nil {
|
||||||
f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
|
f.sendNoMetrics(header.Message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
|
||||||
|
|
||||||
} else if f.l.Level >= logrus.DebugLevel {
|
} else if f.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger(f.l).
|
hostinfo.logger(f.l).
|
||||||
|
@ -65,20 +69,21 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
|
||||||
}
|
}
|
||||||
|
|
||||||
// getOrHandshake returns nil if the vpnIp is not routable
|
// getOrHandshake returns nil if the vpnIp is not routable
|
||||||
func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
|
func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
|
||||||
if f.hostMap.vpnCIDR.Contains(int2ip(vpnIp)) == false {
|
//TODO: we can find contains without converting back to bytes
|
||||||
|
if f.hostMap.vpnCIDR.Contains(vpnIp.ToIP()) == false {
|
||||||
vpnIp = f.hostMap.queryUnsafeRoute(vpnIp)
|
vpnIp = f.hostMap.queryUnsafeRoute(vpnIp)
|
||||||
if vpnIp == 0 {
|
if vpnIp == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f)
|
hostinfo, err := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
|
||||||
|
|
||||||
//if err != nil || hostinfo.ConnectionState == nil {
|
//if err != nil || hostinfo.ConnectionState == nil {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIP(vpnIp)
|
hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo = f.handshakeManager.AddVpnIP(vpnIp)
|
hostinfo = f.handshakeManager.AddVpnIp(vpnIp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ci := hostinfo.ConnectionState
|
ci := hostinfo.ConnectionState
|
||||||
|
@ -126,8 +131,8 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
|
||||||
return hostinfo
|
return hostinfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
|
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
|
||||||
fp := &FirewallPacket{}
|
fp := &firewall.Packet{}
|
||||||
err := newPacket(p, false, fp)
|
err := newPacket(p, false, fp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
|
f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
|
||||||
|
@ -145,15 +150,15 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
|
f.sendNoMetrics(header.Message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
|
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
|
||||||
func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
|
func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
|
||||||
hostInfo := f.getOrHandshake(vpnIp)
|
hostInfo := f.getOrHandshake(vpnIp)
|
||||||
if hostInfo == nil {
|
if hostInfo == nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithField("vpnIp", IntIp(vpnIp)).
|
f.l.WithField("vpnIp", vpnIp).
|
||||||
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
|
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -175,16 +180,16 @@ func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
|
func (f *Interface) sendMessageToVpnIp(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
|
||||||
f.send(t, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
|
f.send(t, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) {
|
func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte) {
|
||||||
f.messageMetrics.Tx(t, st, 1)
|
f.messageMetrics.Tx(t, st, 1)
|
||||||
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
|
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte, q int) {
|
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) {
|
||||||
if ci.eKey == nil {
|
if ci.eKey == nil {
|
||||||
//TODO: log warning
|
//TODO: log warning
|
||||||
return
|
return
|
||||||
|
@ -196,18 +201,18 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
|
||||||
c := atomic.AddUint64(&ci.atomicMessageCounter, 1)
|
c := atomic.AddUint64(&ci.atomicMessageCounter, 1)
|
||||||
|
|
||||||
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
|
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
|
||||||
out = HeaderEncode(out, Version, uint8(t), uint8(st), hostinfo.remoteIndexId, c)
|
out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
|
||||||
f.connectionManager.Out(hostinfo.hostId)
|
f.connectionManager.Out(hostinfo.vpnIp)
|
||||||
|
|
||||||
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
|
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
|
||||||
// all our IPs and enable a faster roaming.
|
// all our IPs and enable a faster roaming.
|
||||||
if t != closeTunnel && hostinfo.lastRebindCount != f.rebindCount {
|
if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
|
||||||
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
|
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
|
||||||
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
||||||
f.lightHouse.QueryServer(hostinfo.hostId, f)
|
f.lightHouse.QueryServer(hostinfo.vpnIp, f)
|
||||||
hostinfo.lastRebindCount = f.rebindCount
|
hostinfo.lastRebindCount = f.rebindCount
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter")
|
f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -230,7 +235,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func isMulticast(ip uint32) bool {
|
func isMulticast(ip iputil.VpnIp) bool {
|
||||||
// Class D multicast
|
// Class D multicast
|
||||||
if (((ip >> 24) & 0xff) & 0xf0) == 0xe0 {
|
if (((ip >> 24) & 0xff) & 0xf0) == 0xe0 {
|
||||||
return true
|
return true
|
||||||
|
|
47
interface.go
47
interface.go
|
@ -12,6 +12,10 @@ import (
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/firewall"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
const mtu = 9001
|
const mtu = 9001
|
||||||
|
@ -27,7 +31,7 @@ type Inside interface {
|
||||||
|
|
||||||
type InterfaceConfig struct {
|
type InterfaceConfig struct {
|
||||||
HostMap *HostMap
|
HostMap *HostMap
|
||||||
Outside *udpConn
|
Outside *udp.Conn
|
||||||
Inside Inside
|
Inside Inside
|
||||||
certState *CertState
|
certState *CertState
|
||||||
Cipher string
|
Cipher string
|
||||||
|
@ -39,7 +43,6 @@ type InterfaceConfig struct {
|
||||||
pendingDeletionInterval int
|
pendingDeletionInterval int
|
||||||
DropLocalBroadcast bool
|
DropLocalBroadcast bool
|
||||||
DropMulticast bool
|
DropMulticast bool
|
||||||
UDPBatchSize int
|
|
||||||
routines int
|
routines int
|
||||||
MessageMetrics *MessageMetrics
|
MessageMetrics *MessageMetrics
|
||||||
version string
|
version string
|
||||||
|
@ -52,7 +55,7 @@ type InterfaceConfig struct {
|
||||||
|
|
||||||
type Interface struct {
|
type Interface struct {
|
||||||
hostMap *HostMap
|
hostMap *HostMap
|
||||||
outside *udpConn
|
outside *udp.Conn
|
||||||
inside Inside
|
inside Inside
|
||||||
certState *CertState
|
certState *CertState
|
||||||
cipher string
|
cipher string
|
||||||
|
@ -62,11 +65,10 @@ type Interface struct {
|
||||||
serveDns bool
|
serveDns bool
|
||||||
createTime time.Time
|
createTime time.Time
|
||||||
lightHouse *LightHouse
|
lightHouse *LightHouse
|
||||||
localBroadcast uint32
|
localBroadcast iputil.VpnIp
|
||||||
myVpnIp uint32
|
myVpnIp iputil.VpnIp
|
||||||
dropLocalBroadcast bool
|
dropLocalBroadcast bool
|
||||||
dropMulticast bool
|
dropMulticast bool
|
||||||
udpBatchSize int
|
|
||||||
routines int
|
routines int
|
||||||
caPool *cert.NebulaCAPool
|
caPool *cert.NebulaCAPool
|
||||||
disconnectInvalid bool
|
disconnectInvalid bool
|
||||||
|
@ -77,7 +79,7 @@ type Interface struct {
|
||||||
|
|
||||||
conntrackCacheTimeout time.Duration
|
conntrackCacheTimeout time.Duration
|
||||||
|
|
||||||
writers []*udpConn
|
writers []*udp.Conn
|
||||||
readers []io.ReadWriteCloser
|
readers []io.ReadWriteCloser
|
||||||
|
|
||||||
metricHandshakes metrics.Histogram
|
metricHandshakes metrics.Histogram
|
||||||
|
@ -101,6 +103,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||||
return nil, errors.New("no firewall rules")
|
return nil, errors.New("no firewall rules")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
myVpnIp := iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].IP)
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: c.HostMap,
|
hostMap: c.HostMap,
|
||||||
outside: c.Outside,
|
outside: c.Outside,
|
||||||
|
@ -112,17 +115,16 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||||
handshakeManager: c.HandshakeManager,
|
handshakeManager: c.HandshakeManager,
|
||||||
createTime: time.Now(),
|
createTime: time.Now(),
|
||||||
lightHouse: c.lightHouse,
|
lightHouse: c.lightHouse,
|
||||||
localBroadcast: ip2int(c.certState.certificate.Details.Ips[0].IP) | ^ip2int(c.certState.certificate.Details.Ips[0].Mask),
|
localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].Mask),
|
||||||
dropLocalBroadcast: c.DropLocalBroadcast,
|
dropLocalBroadcast: c.DropLocalBroadcast,
|
||||||
dropMulticast: c.DropMulticast,
|
dropMulticast: c.DropMulticast,
|
||||||
udpBatchSize: c.UDPBatchSize,
|
|
||||||
routines: c.routines,
|
routines: c.routines,
|
||||||
version: c.version,
|
version: c.version,
|
||||||
writers: make([]*udpConn, c.routines),
|
writers: make([]*udp.Conn, c.routines),
|
||||||
readers: make([]io.ReadWriteCloser, c.routines),
|
readers: make([]io.ReadWriteCloser, c.routines),
|
||||||
caPool: c.caPool,
|
caPool: c.caPool,
|
||||||
disconnectInvalid: c.disconnectInvalid,
|
disconnectInvalid: c.disconnectInvalid,
|
||||||
myVpnIp: ip2int(c.certState.certificate.Details.Ips[0].IP),
|
myVpnIp: myVpnIp,
|
||||||
|
|
||||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||||
|
|
||||||
|
@ -190,14 +192,17 @@ func (f *Interface) run() {
|
||||||
func (f *Interface) listenOut(i int) {
|
func (f *Interface) listenOut(i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
var li *udpConn
|
var li *udp.Conn
|
||||||
// TODO clean this up with a coherent interface for each outside connection
|
// TODO clean this up with a coherent interface for each outside connection
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
li = f.writers[i]
|
li = f.writers[i]
|
||||||
} else {
|
} else {
|
||||||
li = f.outside
|
li = f.outside
|
||||||
}
|
}
|
||||||
li.ListenOut(f, i)
|
|
||||||
|
lhh := f.lightHouse.NewRequestHandler()
|
||||||
|
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
|
li.ListenOut(f.readOutsidePackets, lhh.HandleRequest, conntrackCache, i)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||||
|
@ -205,10 +210,10 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||||
|
|
||||||
packet := make([]byte, mtu)
|
packet := make([]byte, mtu)
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
fwPacket := &FirewallPacket{}
|
fwPacket := &firewall.Packet{}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, err := reader.Read(packet)
|
n, err := reader.Read(packet)
|
||||||
|
@ -222,16 +227,16 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) RegisterConfigChangeCallbacks(c *Config) {
|
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
||||||
c.RegisterReloadCallback(f.reloadCA)
|
c.RegisterReloadCallback(f.reloadCA)
|
||||||
c.RegisterReloadCallback(f.reloadCertKey)
|
c.RegisterReloadCallback(f.reloadCertKey)
|
||||||
c.RegisterReloadCallback(f.reloadFirewall)
|
c.RegisterReloadCallback(f.reloadFirewall)
|
||||||
for _, udpConn := range f.writers {
|
for _, udpConn := range f.writers {
|
||||||
c.RegisterReloadCallback(udpConn.reloadConfig)
|
c.RegisterReloadCallback(udpConn.ReloadConfig)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) reloadCA(c *Config) {
|
func (f *Interface) reloadCA(c *config.C) {
|
||||||
// reload and check regardless
|
// reload and check regardless
|
||||||
// todo: need mutex?
|
// todo: need mutex?
|
||||||
newCAs, err := loadCAFromConfig(f.l, c)
|
newCAs, err := loadCAFromConfig(f.l, c)
|
||||||
|
@ -244,7 +249,7 @@ func (f *Interface) reloadCA(c *Config) {
|
||||||
f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed")
|
f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) reloadCertKey(c *Config) {
|
func (f *Interface) reloadCertKey(c *config.C) {
|
||||||
// reload and check in all cases
|
// reload and check in all cases
|
||||||
cs, err := NewCertStateFromConfig(c)
|
cs, err := NewCertStateFromConfig(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -264,7 +269,7 @@ func (f *Interface) reloadCertKey(c *Config) {
|
||||||
f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
|
f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) reloadFirewall(c *Config) {
|
func (f *Interface) reloadFirewall(c *config.C) {
|
||||||
//TODO: need to trigger/detect if the certificate changed too
|
//TODO: need to trigger/detect if the certificate changed too
|
||||||
if c.HasChanged("firewall") == false {
|
if c.HasChanged("firewall") == false {
|
||||||
f.l.Debug("No firewall config change detected")
|
f.l.Debug("No firewall config change detected")
|
||||||
|
@ -307,7 +312,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
||||||
ticker := time.NewTicker(i)
|
ticker := time.NewTicker(i)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
udpStats := NewUDPStatsEmitter(f.writers)
|
udpStats := udp.NewUDPStatsEmitter(f.writers)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
|
|
@ -0,0 +1,66 @@
|
||||||
|
package iputil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type VpnIp uint32
|
||||||
|
|
||||||
|
const maxIPv4StringLen = len("255.255.255.255")
|
||||||
|
|
||||||
|
func (ip VpnIp) String() string {
|
||||||
|
b := make([]byte, maxIPv4StringLen)
|
||||||
|
|
||||||
|
n := ubtoa(b, 0, byte(ip>>24))
|
||||||
|
b[n] = '.'
|
||||||
|
n++
|
||||||
|
|
||||||
|
n += ubtoa(b, n, byte(ip>>16&255))
|
||||||
|
b[n] = '.'
|
||||||
|
n++
|
||||||
|
|
||||||
|
n += ubtoa(b, n, byte(ip>>8&255))
|
||||||
|
b[n] = '.'
|
||||||
|
n++
|
||||||
|
|
||||||
|
n += ubtoa(b, n, byte(ip&255))
|
||||||
|
return string(b[:n])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ip VpnIp) MarshalJSON() ([]byte, error) {
|
||||||
|
return []byte(fmt.Sprintf("\"%s\"", ip.String())), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ip VpnIp) ToIP() net.IP {
|
||||||
|
nip := make(net.IP, 4)
|
||||||
|
binary.BigEndian.PutUint32(nip, uint32(ip))
|
||||||
|
return nip
|
||||||
|
}
|
||||||
|
|
||||||
|
func Ip2VpnIp(ip []byte) VpnIp {
|
||||||
|
if len(ip) == 16 {
|
||||||
|
return VpnIp(binary.BigEndian.Uint32(ip[12:16]))
|
||||||
|
}
|
||||||
|
return VpnIp(binary.BigEndian.Uint32(ip))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ubtoa encodes the string form of the integer v to dst[start:] and
|
||||||
|
// returns the number of bytes written to dst. The caller must ensure
|
||||||
|
// that dst has sufficient length.
|
||||||
|
func ubtoa(dst []byte, start int, v byte) int {
|
||||||
|
if v < 10 {
|
||||||
|
dst[start] = v + '0'
|
||||||
|
return 1
|
||||||
|
} else if v < 100 {
|
||||||
|
dst[start+1] = v%10 + '0'
|
||||||
|
dst[start] = v/10 + '0'
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
|
||||||
|
dst[start+2] = v%10 + '0'
|
||||||
|
dst[start+1] = (v/10)%10 + '0'
|
||||||
|
dst[start] = v/100 + '0'
|
||||||
|
return 3
|
||||||
|
}
|
|
@ -0,0 +1,17 @@
|
||||||
|
package iputil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestVpnIp_String(t *testing.T) {
|
||||||
|
assert.Equal(t, "255.255.255.255", Ip2VpnIp(net.ParseIP("255.255.255.255")).String())
|
||||||
|
assert.Equal(t, "1.255.255.255", Ip2VpnIp(net.ParseIP("1.255.255.255")).String())
|
||||||
|
assert.Equal(t, "1.1.255.255", Ip2VpnIp(net.ParseIP("1.1.255.255")).String())
|
||||||
|
assert.Equal(t, "1.1.1.255", Ip2VpnIp(net.ParseIP("1.1.1.255")).String())
|
||||||
|
assert.Equal(t, "1.1.1.1", Ip2VpnIp(net.ParseIP("1.1.1.1")).String())
|
||||||
|
assert.Equal(t, "0.0.0.0", Ip2VpnIp(net.ParseIP("0.0.0.0")).String())
|
||||||
|
}
|
163
lighthouse.go
163
lighthouse.go
|
@ -12,6 +12,9 @@ import (
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
//TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake?
|
//TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake?
|
||||||
|
@ -23,13 +26,13 @@ type LightHouse struct {
|
||||||
//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
|
//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
|
||||||
sync.RWMutex //Because we concurrently read and write to our maps
|
sync.RWMutex //Because we concurrently read and write to our maps
|
||||||
amLighthouse bool
|
amLighthouse bool
|
||||||
myVpnIp uint32
|
myVpnIp iputil.VpnIp
|
||||||
myVpnZeros uint32
|
myVpnZeros iputil.VpnIp
|
||||||
punchConn *udpConn
|
punchConn *udp.Conn
|
||||||
|
|
||||||
// Local cache of answers from light houses
|
// Local cache of answers from light houses
|
||||||
// map of vpn Ip to answers
|
// map of vpn Ip to answers
|
||||||
addrMap map[uint32]*RemoteList
|
addrMap map[iputil.VpnIp]*RemoteList
|
||||||
|
|
||||||
// filters remote addresses allowed for each host
|
// filters remote addresses allowed for each host
|
||||||
// - When we are a lighthouse, this filters what addresses we store and
|
// - When we are a lighthouse, this filters what addresses we store and
|
||||||
|
@ -42,12 +45,12 @@ type LightHouse struct {
|
||||||
localAllowList *LocalAllowList
|
localAllowList *LocalAllowList
|
||||||
|
|
||||||
// used to trigger the HandshakeManager when we receive HostQueryReply
|
// used to trigger the HandshakeManager when we receive HostQueryReply
|
||||||
handshakeTrigger chan<- uint32
|
handshakeTrigger chan<- iputil.VpnIp
|
||||||
|
|
||||||
// staticList exists to avoid having a bool in each addrMap entry
|
// staticList exists to avoid having a bool in each addrMap entry
|
||||||
// since static should be rare
|
// since static should be rare
|
||||||
staticList map[uint32]struct{}
|
staticList map[iputil.VpnIp]struct{}
|
||||||
lighthouses map[uint32]struct{}
|
lighthouses map[iputil.VpnIp]struct{}
|
||||||
interval int
|
interval int
|
||||||
nebulaPort uint32 // 32 bits because protobuf does not have a uint16
|
nebulaPort uint32 // 32 bits because protobuf does not have a uint16
|
||||||
punchBack bool
|
punchBack bool
|
||||||
|
@ -58,20 +61,16 @@ type LightHouse struct {
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type EncWriter interface {
|
func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, ips []iputil.VpnIp, interval int, nebulaPort uint32, pc *udp.Conn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
|
||||||
SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
|
|
||||||
ones, _ := myVpnIpNet.Mask.Size()
|
ones, _ := myVpnIpNet.Mask.Size()
|
||||||
h := LightHouse{
|
h := LightHouse{
|
||||||
amLighthouse: amLighthouse,
|
amLighthouse: amLighthouse,
|
||||||
myVpnIp: ip2int(myVpnIpNet.IP),
|
myVpnIp: iputil.Ip2VpnIp(myVpnIpNet.IP),
|
||||||
myVpnZeros: uint32(32 - ones),
|
myVpnZeros: iputil.VpnIp(32 - ones),
|
||||||
addrMap: make(map[uint32]*RemoteList),
|
addrMap: make(map[iputil.VpnIp]*RemoteList),
|
||||||
nebulaPort: nebulaPort,
|
nebulaPort: nebulaPort,
|
||||||
lighthouses: make(map[uint32]struct{}),
|
lighthouses: make(map[iputil.VpnIp]struct{}),
|
||||||
staticList: make(map[uint32]struct{}),
|
staticList: make(map[iputil.VpnIp]struct{}),
|
||||||
interval: interval,
|
interval: interval,
|
||||||
punchConn: pc,
|
punchConn: pc,
|
||||||
punchBack: punchBack,
|
punchBack: punchBack,
|
||||||
|
@ -111,13 +110,13 @@ func (lh *LightHouse) SetLocalAllowList(allowList *LocalAllowList) {
|
||||||
func (lh *LightHouse) ValidateLHStaticEntries() error {
|
func (lh *LightHouse) ValidateLHStaticEntries() error {
|
||||||
for lhIP, _ := range lh.lighthouses {
|
for lhIP, _ := range lh.lighthouses {
|
||||||
if _, ok := lh.staticList[lhIP]; !ok {
|
if _, ok := lh.staticList[lhIP]; !ok {
|
||||||
return fmt.Errorf("Lighthouse %s does not have a static_host_map entry", IntIp(lhIP))
|
return fmt.Errorf("Lighthouse %s does not have a static_host_map entry", lhIP)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) Query(ip uint32, f EncWriter) *RemoteList {
|
func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList {
|
||||||
if !lh.IsLighthouseIP(ip) {
|
if !lh.IsLighthouseIP(ip) {
|
||||||
lh.QueryServer(ip, f)
|
lh.QueryServer(ip, f)
|
||||||
}
|
}
|
||||||
|
@ -131,7 +130,7 @@ func (lh *LightHouse) Query(ip uint32, f EncWriter) *RemoteList {
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is asynchronous so no reply should be expected
|
// This is asynchronous so no reply should be expected
|
||||||
func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
|
func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f udp.EncWriter) {
|
||||||
if lh.amLighthouse {
|
if lh.amLighthouse {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -143,7 +142,7 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
|
||||||
// Send a query to the lighthouses and hope for the best next time
|
// Send a query to the lighthouses and hope for the best next time
|
||||||
query, err := proto.Marshal(NewLhQueryByInt(ip))
|
query, err := proto.Marshal(NewLhQueryByInt(ip))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
|
lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -151,11 +150,11 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
for n := range lh.lighthouses {
|
for n := range lh.lighthouses {
|
||||||
f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out)
|
f.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) QueryCache(ip uint32) *RemoteList {
|
func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList {
|
||||||
lh.RLock()
|
lh.RLock()
|
||||||
if v, ok := lh.addrMap[ip]; ok {
|
if v, ok := lh.addrMap[ip]; ok {
|
||||||
lh.RUnlock()
|
lh.RUnlock()
|
||||||
|
@ -172,7 +171,7 @@ func (lh *LightHouse) QueryCache(ip uint32) *RemoteList {
|
||||||
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
|
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
|
||||||
// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp
|
// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp
|
||||||
// If one is found then f() is called with proper locking, f() must return result of n.MarshalTo()
|
// If one is found then f() is called with proper locking, f() must return result of n.MarshalTo()
|
||||||
func (lh *LightHouse) queryAndPrepMessage(vpnIp uint32, f func(*cache) (int, error)) (bool, int, error) {
|
func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (int, error)) (bool, int, error) {
|
||||||
lh.RLock()
|
lh.RLock()
|
||||||
// Do we have an entry in the main cache?
|
// Do we have an entry in the main cache?
|
||||||
if v, ok := lh.addrMap[vpnIp]; ok {
|
if v, ok := lh.addrMap[vpnIp]; ok {
|
||||||
|
@ -195,18 +194,18 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp uint32, f func(*cache) (int, err
|
||||||
return false, 0, nil
|
return false, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
|
func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
|
||||||
// First we check the static mapping
|
// First we check the static mapping
|
||||||
// and do nothing if it is there
|
// and do nothing if it is there
|
||||||
if _, ok := lh.staticList[vpnIP]; ok {
|
if _, ok := lh.staticList[vpnIp]; ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
lh.Lock()
|
lh.Lock()
|
||||||
//l.Debugln(lh.addrMap)
|
//l.Debugln(lh.addrMap)
|
||||||
delete(lh.addrMap, vpnIP)
|
delete(lh.addrMap, vpnIp)
|
||||||
|
|
||||||
if lh.l.Level >= logrus.DebugLevel {
|
if lh.l.Level >= logrus.DebugLevel {
|
||||||
lh.l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP))
|
lh.l.Debugf("deleting %s from lighthouse.", vpnIp)
|
||||||
}
|
}
|
||||||
|
|
||||||
lh.Unlock()
|
lh.Unlock()
|
||||||
|
@ -215,7 +214,7 @@ func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
|
||||||
// AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner
|
// AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner
|
||||||
// We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
|
// We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
|
||||||
// And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
|
// And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
|
||||||
func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) {
|
func (lh *LightHouse) AddStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr) {
|
||||||
lh.Lock()
|
lh.Lock()
|
||||||
am := lh.unlockedGetRemoteList(vpnIp)
|
am := lh.unlockedGetRemoteList(vpnIp)
|
||||||
am.Lock()
|
am.Lock()
|
||||||
|
@ -242,23 +241,23 @@ func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// unlockedGetRemoteList assumes you have the lh lock
|
// unlockedGetRemoteList assumes you have the lh lock
|
||||||
func (lh *LightHouse) unlockedGetRemoteList(vpnIP uint32) *RemoteList {
|
func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
|
||||||
am, ok := lh.addrMap[vpnIP]
|
am, ok := lh.addrMap[vpnIp]
|
||||||
if !ok {
|
if !ok {
|
||||||
am = NewRemoteList()
|
am = NewRemoteList()
|
||||||
lh.addrMap[vpnIP] = am
|
lh.addrMap[vpnIp] = am
|
||||||
}
|
}
|
||||||
return am
|
return am
|
||||||
}
|
}
|
||||||
|
|
||||||
// unlockedShouldAddV4 checks if to is allowed by our allow list
|
// unlockedShouldAddV4 checks if to is allowed by our allow list
|
||||||
func (lh *LightHouse) unlockedShouldAddV4(vpnIp uint32, to *Ip4AndPort) bool {
|
func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool {
|
||||||
allow := lh.remoteAllowList.AllowIpV4(vpnIp, to.Ip)
|
allow := lh.remoteAllowList.AllowIpV4(vpnIp, iputil.VpnIp(to.Ip))
|
||||||
if lh.l.Level >= logrus.TraceLevel {
|
if lh.l.Level >= logrus.TraceLevel {
|
||||||
lh.l.WithField("remoteIp", IntIp(to.Ip)).WithField("allow", allow).Trace("remoteAllowList.Allow")
|
lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, to.Ip) {
|
if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.VpnIp(to.Ip)) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -266,7 +265,7 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp uint32, to *Ip4AndPort) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// unlockedShouldAddV6 checks if to is allowed by our allow list
|
// unlockedShouldAddV6 checks if to is allowed by our allow list
|
||||||
func (lh *LightHouse) unlockedShouldAddV6(vpnIp uint32, to *Ip6AndPort) bool {
|
func (lh *LightHouse) unlockedShouldAddV6(vpnIp iputil.VpnIp, to *Ip6AndPort) bool {
|
||||||
allow := lh.remoteAllowList.AllowIpV6(vpnIp, to.Hi, to.Lo)
|
allow := lh.remoteAllowList.AllowIpV6(vpnIp, to.Hi, to.Lo)
|
||||||
if lh.l.Level >= logrus.TraceLevel {
|
if lh.l.Level >= logrus.TraceLevel {
|
||||||
lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
|
lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
|
||||||
|
@ -287,25 +286,25 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP {
|
||||||
return ip
|
return ip
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) IsLighthouseIP(vpnIP uint32) bool {
|
func (lh *LightHouse) IsLighthouseIP(vpnIp iputil.VpnIp) bool {
|
||||||
if _, ok := lh.lighthouses[vpnIP]; ok {
|
if _, ok := lh.lighthouses[vpnIp]; ok {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLhQueryByInt(VpnIp uint32) *NebulaMeta {
|
func NewLhQueryByInt(VpnIp iputil.VpnIp) *NebulaMeta {
|
||||||
return &NebulaMeta{
|
return &NebulaMeta{
|
||||||
Type: NebulaMeta_HostQuery,
|
Type: NebulaMeta_HostQuery,
|
||||||
Details: &NebulaMetaDetails{
|
Details: &NebulaMetaDetails{
|
||||||
VpnIp: VpnIp,
|
VpnIp: uint32(VpnIp),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort {
|
func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort {
|
||||||
ipp := Ip4AndPort{Port: port}
|
ipp := Ip4AndPort{Port: port}
|
||||||
ipp.Ip = ip2int(ip)
|
ipp.Ip = uint32(iputil.Ip2VpnIp(ip))
|
||||||
return &ipp
|
return &ipp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -317,19 +316,19 @@ func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udpAddr {
|
func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr {
|
||||||
ip := ipp.Ip
|
ip := ipp.Ip
|
||||||
return NewUDPAddr(
|
return udp.NewAddr(
|
||||||
net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)),
|
net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)),
|
||||||
uint16(ipp.Port),
|
uint16(ipp.Port),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udpAddr {
|
func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
|
||||||
return NewUDPAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
|
return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
|
func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) {
|
||||||
if lh.amLighthouse || lh.interval == 0 {
|
if lh.amLighthouse || lh.interval == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -349,12 +348,12 @@ func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) SendUpdate(f EncWriter) {
|
func (lh *LightHouse) SendUpdate(f udp.EncWriter) {
|
||||||
var v4 []*Ip4AndPort
|
var v4 []*Ip4AndPort
|
||||||
var v6 []*Ip6AndPort
|
var v6 []*Ip6AndPort
|
||||||
|
|
||||||
for _, e := range *localIps(lh.l, lh.localAllowList) {
|
for _, e := range *localIps(lh.l, lh.localAllowList) {
|
||||||
if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip2int(ip4)) {
|
if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.Ip2VpnIp(ip4)) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -368,7 +367,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
|
||||||
m := &NebulaMeta{
|
m := &NebulaMeta{
|
||||||
Type: NebulaMeta_HostUpdateNotification,
|
Type: NebulaMeta_HostUpdateNotification,
|
||||||
Details: &NebulaMetaDetails{
|
Details: &NebulaMetaDetails{
|
||||||
VpnIp: lh.myVpnIp,
|
VpnIp: uint32(lh.myVpnIp),
|
||||||
Ip4AndPorts: v4,
|
Ip4AndPorts: v4,
|
||||||
Ip6AndPorts: v6,
|
Ip6AndPorts: v6,
|
||||||
},
|
},
|
||||||
|
@ -385,7 +384,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for vpnIp := range lh.lighthouses {
|
for vpnIp := range lh.lighthouses {
|
||||||
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out)
|
f.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -415,11 +414,11 @@ func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) metricRx(t NebulaMeta_MessageType, i int64) {
|
func (lh *LightHouse) metricRx(t NebulaMeta_MessageType, i int64) {
|
||||||
lh.metrics.Rx(NebulaMessageType(t), 0, i)
|
lh.metrics.Rx(header.MessageType(t), 0, i)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) metricTx(t NebulaMeta_MessageType, i int64) {
|
func (lh *LightHouse) metricTx(t NebulaMeta_MessageType, i int64) {
|
||||||
lh.metrics.Tx(NebulaMessageType(t), 0, i)
|
lh.metrics.Tx(header.MessageType(t), 0, i)
|
||||||
}
|
}
|
||||||
|
|
||||||
// This method is similar to Reset(), but it re-uses the pointer structs
|
// This method is similar to Reset(), but it re-uses the pointer structs
|
||||||
|
@ -436,18 +435,18 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
|
||||||
return lhh.meta
|
return lhh.meta
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, w EncWriter) {
|
func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w udp.EncWriter) {
|
||||||
n := lhh.resetMeta()
|
n := lhh.resetMeta()
|
||||||
err := n.Unmarshal(p)
|
err := n.Unmarshal(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
|
lhh.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr).
|
||||||
Error("Failed to unmarshal lighthouse packet")
|
Error("Failed to unmarshal lighthouse packet")
|
||||||
//TODO: send recv_error?
|
//TODO: send recv_error?
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if n.Details == nil {
|
if n.Details == nil {
|
||||||
lhh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
|
lhh.l.WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr).
|
||||||
Error("Invalid lighthouse update")
|
Error("Invalid lighthouse update")
|
||||||
//TODO: send recv_error?
|
//TODO: send recv_error?
|
||||||
return
|
return
|
||||||
|
@ -471,7 +470,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr *udpAddr, w EncWriter) {
|
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w udp.EncWriter) {
|
||||||
// Exit if we don't answer queries
|
// Exit if we don't answer queries
|
||||||
if !lhh.lh.amLighthouse {
|
if !lhh.lh.amLighthouse {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
|
@ -481,12 +480,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: we can DRY this further
|
//TODO: we can DRY this further
|
||||||
reqVpnIP := n.Details.VpnIp
|
reqVpnIp := n.Details.VpnIp
|
||||||
//TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data
|
//TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data
|
||||||
found, ln, err := lhh.lh.queryAndPrepMessage(n.Details.VpnIp, func(c *cache) (int, error) {
|
found, ln, err := lhh.lh.queryAndPrepMessage(iputil.VpnIp(n.Details.VpnIp), func(c *cache) (int, error) {
|
||||||
n = lhh.resetMeta()
|
n = lhh.resetMeta()
|
||||||
n.Type = NebulaMeta_HostQueryReply
|
n.Type = NebulaMeta_HostQueryReply
|
||||||
n.Details.VpnIp = reqVpnIP
|
n.Details.VpnIp = reqVpnIp
|
||||||
|
|
||||||
lhh.coalesceAnswers(c, n)
|
lhh.coalesceAnswers(c, n)
|
||||||
|
|
||||||
|
@ -498,18 +497,18 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
|
lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host query reply")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
|
lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
|
||||||
w.SendMessageToVpnIp(lightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
||||||
|
|
||||||
// This signals the other side to punch some zero byte udp packets
|
// This signals the other side to punch some zero byte udp packets
|
||||||
found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) {
|
found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) {
|
||||||
n = lhh.resetMeta()
|
n = lhh.resetMeta()
|
||||||
n.Type = NebulaMeta_HostPunchNotification
|
n.Type = NebulaMeta_HostPunchNotification
|
||||||
n.Details.VpnIp = vpnIp
|
n.Details.VpnIp = uint32(vpnIp)
|
||||||
|
|
||||||
lhh.coalesceAnswers(c, n)
|
lhh.coalesceAnswers(c, n)
|
||||||
|
|
||||||
|
@ -521,12 +520,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host was queried for")
|
lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host was queried for")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
|
lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
|
||||||
w.SendMessageToVpnIp(lightHouse, 0, reqVpnIP, lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
w.SendMessageToVpnIp(header.LightHouse, 0, iputil.VpnIp(reqVpnIp), lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
|
func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
|
||||||
|
@ -549,28 +548,29 @@ func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32) {
|
func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp iputil.VpnIp) {
|
||||||
if !lhh.lh.IsLighthouseIP(vpnIp) {
|
if !lhh.lh.IsLighthouseIP(vpnIp) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lhh.lh.Lock()
|
lhh.lh.Lock()
|
||||||
am := lhh.lh.unlockedGetRemoteList(n.Details.VpnIp)
|
am := lhh.lh.unlockedGetRemoteList(iputil.VpnIp(n.Details.VpnIp))
|
||||||
am.Lock()
|
am.Lock()
|
||||||
lhh.lh.Unlock()
|
lhh.lh.Unlock()
|
||||||
|
|
||||||
am.unlockedSetV4(vpnIp, n.Details.VpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
|
certVpnIp := iputil.VpnIp(n.Details.VpnIp)
|
||||||
am.unlockedSetV6(vpnIp, n.Details.VpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
|
am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
|
||||||
|
am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
|
||||||
am.Unlock()
|
am.Unlock()
|
||||||
|
|
||||||
// Non-blocking attempt to trigger, skip if it would block
|
// Non-blocking attempt to trigger, skip if it would block
|
||||||
select {
|
select {
|
||||||
case lhh.lh.handshakeTrigger <- n.Details.VpnIp:
|
case lhh.lh.handshakeTrigger <- iputil.VpnIp(n.Details.VpnIp):
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp uint32) {
|
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp) {
|
||||||
if !lhh.lh.amLighthouse {
|
if !lhh.lh.amLighthouse {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp)
|
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp)
|
||||||
|
@ -579,9 +579,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
|
||||||
}
|
}
|
||||||
|
|
||||||
//Simple check that the host sent this not someone else
|
//Simple check that the host sent this not someone else
|
||||||
if n.Details.VpnIp != vpnIp {
|
if n.Details.VpnIp != uint32(vpnIp) {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
|
lhh.l.WithField("vpnIp", vpnIp).WithField("answer", iputil.VpnIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -591,18 +591,19 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
|
||||||
am.Lock()
|
am.Lock()
|
||||||
lhh.lh.Unlock()
|
lhh.lh.Unlock()
|
||||||
|
|
||||||
am.unlockedSetV4(vpnIp, n.Details.VpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
|
certVpnIp := iputil.VpnIp(n.Details.VpnIp)
|
||||||
am.unlockedSetV6(vpnIp, n.Details.VpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
|
am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
|
||||||
|
am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
|
||||||
am.Unlock()
|
am.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp uint32, w EncWriter) {
|
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w udp.EncWriter) {
|
||||||
if !lhh.lh.IsLighthouseIP(vpnIp) {
|
if !lhh.lh.IsLighthouseIP(vpnIp) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
empty := []byte{0}
|
empty := []byte{0}
|
||||||
punch := func(vpnPeer *udpAddr) {
|
punch := func(vpnPeer *udp.Addr) {
|
||||||
if vpnPeer == nil {
|
if vpnPeer == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -615,7 +616,7 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp u
|
||||||
|
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
|
//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
|
||||||
lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, IntIp(n.Details.VpnIp))
|
lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, iputil.VpnIp(n.Details.VpnIp))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -634,18 +635,18 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp u
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(time.Second * 5)
|
time.Sleep(time.Second * 5)
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp))
|
lhh.l.Debugf("Sending a nebula test packet to vpn ip %s", iputil.VpnIp(n.Details.VpnIp))
|
||||||
}
|
}
|
||||||
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
|
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
|
||||||
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
||||||
// managed by a channel.
|
// managed by a channel.
|
||||||
w.SendMessageToVpnIp(test, testRequest, n.Details.VpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
w.SendMessageToVpnIp(header.Test, header.TestRequest, iputil.VpnIp(n.Details.VpnIp), []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ipMaskContains checks if testIp is contained by ip after applying a cidr
|
// ipMaskContains checks if testIp is contained by ip after applying a cidr
|
||||||
// zeros is 32 - bits from net.IPMask.Size()
|
// zeros is 32 - bits from net.IPMask.Size()
|
||||||
func ipMaskContains(ip uint32, zeros uint32, testIp uint32) bool {
|
func ipMaskContains(ip iputil.VpnIp, zeros iputil.VpnIp, testIp iputil.VpnIp) bool {
|
||||||
return (testIp^ip)>>zeros == 0
|
return (testIp^ip)>>zeros == 0
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,10 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,12 +21,12 @@ func TestOldIPv4Only(t *testing.T) {
|
||||||
var m Ip4AndPort
|
var m Ip4AndPort
|
||||||
err := proto.Unmarshal(b, &m)
|
err := proto.Unmarshal(b, &m)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "10.1.1.1", int2ip(m.GetIp()).String())
|
assert.Equal(t, "10.1.1.1", iputil.VpnIp(m.GetIp()).String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewLhQuery(t *testing.T) {
|
func TestNewLhQuery(t *testing.T) {
|
||||||
myIp := net.ParseIP("192.1.1.1")
|
myIp := net.ParseIP("192.1.1.1")
|
||||||
myIpint := ip2int(myIp)
|
myIpint := iputil.Ip2VpnIp(myIp)
|
||||||
|
|
||||||
// Generating a new lh query should work
|
// Generating a new lh query should work
|
||||||
a := NewLhQueryByInt(myIpint)
|
a := NewLhQueryByInt(myIpint)
|
||||||
|
@ -42,37 +46,37 @@ func TestNewLhQuery(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_lhStaticMapping(t *testing.T) {
|
func Test_lhStaticMapping(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
lh1 := "10.128.0.2"
|
lh1 := "10.128.0.2"
|
||||||
lh1IP := net.ParseIP(lh1)
|
lh1IP := net.ParseIP(lh1)
|
||||||
|
|
||||||
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
|
udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2)
|
||||||
|
|
||||||
meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
||||||
meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)))
|
meh.AddStaticRemote(iputil.Ip2VpnIp(lh1IP), udp.NewAddr(lh1IP, uint16(4242)))
|
||||||
err := meh.ValidateLHStaticEntries()
|
err := meh.ValidateLHStaticEntries()
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
lh2 := "10.128.0.3"
|
lh2 := "10.128.0.3"
|
||||||
lh2IP := net.ParseIP(lh2)
|
lh2IP := net.ParseIP(lh2)
|
||||||
|
|
||||||
meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
|
meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP), iputil.Ip2VpnIp(lh2IP)}, 10, 10003, udpServer, false, 1, false)
|
||||||
meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)))
|
meh.AddStaticRemote(iputil.Ip2VpnIp(lh1IP), udp.NewAddr(lh1IP, uint16(4242)))
|
||||||
err = meh.ValidateLHStaticEntries()
|
err = meh.ValidateLHStaticEntries()
|
||||||
assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
|
assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
lh1 := "10.128.0.2"
|
lh1 := "10.128.0.2"
|
||||||
lh1IP := net.ParseIP(lh1)
|
lh1IP := net.ParseIP(lh1)
|
||||||
|
|
||||||
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
|
udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2)
|
||||||
|
|
||||||
lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
||||||
|
|
||||||
hAddr := NewUDPAddrFromString("4.5.6.7:12345")
|
hAddr := udp.NewAddrFromString("4.5.6.7:12345")
|
||||||
hAddr2 := NewUDPAddrFromString("4.5.6.7:12346")
|
hAddr2 := udp.NewAddrFromString("4.5.6.7:12346")
|
||||||
lh.addrMap[3] = NewRemoteList()
|
lh.addrMap[3] = NewRemoteList()
|
||||||
lh.addrMap[3].unlockedSetV4(
|
lh.addrMap[3].unlockedSetV4(
|
||||||
3,
|
3,
|
||||||
|
@ -81,11 +85,11 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||||
NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
|
NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
|
||||||
NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)),
|
NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)),
|
||||||
},
|
},
|
||||||
func(uint32, *Ip4AndPort) bool { return true },
|
func(iputil.VpnIp, *Ip4AndPort) bool { return true },
|
||||||
)
|
)
|
||||||
|
|
||||||
rAddr := NewUDPAddrFromString("1.2.2.3:12345")
|
rAddr := udp.NewAddrFromString("1.2.2.3:12345")
|
||||||
rAddr2 := NewUDPAddrFromString("1.2.2.3:12346")
|
rAddr2 := udp.NewAddrFromString("1.2.2.3:12346")
|
||||||
lh.addrMap[2] = NewRemoteList()
|
lh.addrMap[2] = NewRemoteList()
|
||||||
lh.addrMap[2].unlockedSetV4(
|
lh.addrMap[2].unlockedSetV4(
|
||||||
3,
|
3,
|
||||||
|
@ -94,7 +98,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||||
NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
|
NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
|
||||||
NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)),
|
NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)),
|
||||||
},
|
},
|
||||||
func(uint32, *Ip4AndPort) bool { return true },
|
func(iputil.VpnIp, *Ip4AndPort) bool { return true },
|
||||||
)
|
)
|
||||||
|
|
||||||
mw := &mockEncWriter{}
|
mw := &mockEncWriter{}
|
||||||
|
@ -133,50 +137,50 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLighthouse_Memory(t *testing.T) {
|
func TestLighthouse_Memory(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
|
|
||||||
myUdpAddr0 := &udpAddr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
|
myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
|
||||||
myUdpAddr1 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
|
myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
|
||||||
myUdpAddr2 := &udpAddr{IP: net.ParseIP("172.16.0.2"), Port: 4242}
|
myUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.2"), Port: 4242}
|
||||||
myUdpAddr3 := &udpAddr{IP: net.ParseIP("100.152.0.2"), Port: 4242}
|
myUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.2"), Port: 4242}
|
||||||
myUdpAddr4 := &udpAddr{IP: net.ParseIP("24.15.0.2"), Port: 4242}
|
myUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.2"), Port: 4242}
|
||||||
myUdpAddr5 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4243}
|
myUdpAddr5 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4243}
|
||||||
myUdpAddr6 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4244}
|
myUdpAddr6 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4244}
|
||||||
myUdpAddr7 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4245}
|
myUdpAddr7 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4245}
|
||||||
myUdpAddr8 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4246}
|
myUdpAddr8 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4246}
|
||||||
myUdpAddr9 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4247}
|
myUdpAddr9 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4247}
|
||||||
myUdpAddr10 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4248}
|
myUdpAddr10 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4248}
|
||||||
myUdpAddr11 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4249}
|
myUdpAddr11 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4249}
|
||||||
myVpnIp := ip2int(net.ParseIP("10.128.0.2"))
|
myVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.2"))
|
||||||
|
|
||||||
theirUdpAddr0 := &udpAddr{IP: net.ParseIP("10.0.0.3"), Port: 4242}
|
theirUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.3"), Port: 4242}
|
||||||
theirUdpAddr1 := &udpAddr{IP: net.ParseIP("192.168.0.3"), Port: 4242}
|
theirUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.3"), Port: 4242}
|
||||||
theirUdpAddr2 := &udpAddr{IP: net.ParseIP("172.16.0.3"), Port: 4242}
|
theirUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.3"), Port: 4242}
|
||||||
theirUdpAddr3 := &udpAddr{IP: net.ParseIP("100.152.0.3"), Port: 4242}
|
theirUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.3"), Port: 4242}
|
||||||
theirUdpAddr4 := &udpAddr{IP: net.ParseIP("24.15.0.3"), Port: 4242}
|
theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242}
|
||||||
theirVpnIp := ip2int(net.ParseIP("10.128.0.3"))
|
theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3"))
|
||||||
|
|
||||||
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
|
udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2)
|
||||||
lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{}, 10, 10003, udpServer, false, 1, false)
|
lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []iputil.VpnIp{}, 10, 10003, udpServer, false, 1, false)
|
||||||
lhh := lh.NewRequestHandler()
|
lhh := lh.NewRequestHandler()
|
||||||
|
|
||||||
// Test that my first update responds with just that
|
// Test that my first update responds with just that
|
||||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr1, myUdpAddr2}, lhh)
|
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr2}, lhh)
|
||||||
r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
|
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
|
||||||
|
|
||||||
// Ensure we don't accumulate addresses
|
// Ensure we don't accumulate addresses
|
||||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr3}, lhh)
|
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr3}, lhh)
|
||||||
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
|
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
|
||||||
|
|
||||||
// Grow it back to 2
|
// Grow it back to 2
|
||||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr1, myUdpAddr4}, lhh)
|
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr4}, lhh)
|
||||||
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
|
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
|
||||||
|
|
||||||
// Update a different host
|
// Update a different host
|
||||||
newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udpAddr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
|
newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
|
||||||
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
|
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
|
||||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
|
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
|
||||||
|
|
||||||
|
@ -189,7 +193,7 @@ func TestLighthouse_Memory(t *testing.T) {
|
||||||
newLHHostUpdate(
|
newLHHostUpdate(
|
||||||
myUdpAddr0,
|
myUdpAddr0,
|
||||||
myVpnIp,
|
myVpnIp,
|
||||||
[]*udpAddr{
|
[]*udp.Addr{
|
||||||
myUdpAddr1,
|
myUdpAddr1,
|
||||||
myUdpAddr2,
|
myUdpAddr2,
|
||||||
myUdpAddr3,
|
myUdpAddr3,
|
||||||
|
@ -212,19 +216,19 @@ func TestLighthouse_Memory(t *testing.T) {
|
||||||
)
|
)
|
||||||
|
|
||||||
// Make sure we won't add ips in our vpn network
|
// Make sure we won't add ips in our vpn network
|
||||||
bad1 := &udpAddr{IP: net.ParseIP("10.128.0.99"), Port: 4242}
|
bad1 := &udp.Addr{IP: net.ParseIP("10.128.0.99"), Port: 4242}
|
||||||
bad2 := &udpAddr{IP: net.ParseIP("10.128.0.100"), Port: 4242}
|
bad2 := &udp.Addr{IP: net.ParseIP("10.128.0.100"), Port: 4242}
|
||||||
good := &udpAddr{IP: net.ParseIP("1.128.0.99"), Port: 4242}
|
good := &udp.Addr{IP: net.ParseIP("1.128.0.99"), Port: 4242}
|
||||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{bad1, bad2, good}, lhh)
|
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{bad1, bad2, good}, lhh)
|
||||||
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
|
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newLHHostRequest(fromAddr *udpAddr, myVpnIp, queryVpnIp uint32, lhh *LightHouseHandler) testLhReply {
|
func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply {
|
||||||
req := &NebulaMeta{
|
req := &NebulaMeta{
|
||||||
Type: NebulaMeta_HostQuery,
|
Type: NebulaMeta_HostQuery,
|
||||||
Details: &NebulaMetaDetails{
|
Details: &NebulaMetaDetails{
|
||||||
VpnIp: queryVpnIp,
|
VpnIp: uint32(queryVpnIp),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -238,17 +242,17 @@ func newLHHostRequest(fromAddr *udpAddr, myVpnIp, queryVpnIp uint32, lhh *LightH
|
||||||
return w.lastReply
|
return w.lastReply
|
||||||
}
|
}
|
||||||
|
|
||||||
func newLHHostUpdate(fromAddr *udpAddr, vpnIp uint32, addrs []*udpAddr, lhh *LightHouseHandler) {
|
func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, lhh *LightHouseHandler) {
|
||||||
req := &NebulaMeta{
|
req := &NebulaMeta{
|
||||||
Type: NebulaMeta_HostUpdateNotification,
|
Type: NebulaMeta_HostUpdateNotification,
|
||||||
Details: &NebulaMetaDetails{
|
Details: &NebulaMetaDetails{
|
||||||
VpnIp: vpnIp,
|
VpnIp: uint32(vpnIp),
|
||||||
Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
|
Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range addrs {
|
for k, v := range addrs {
|
||||||
req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: ip2int(v.IP), Port: uint32(v.Port)}
|
req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: uint32(iputil.Ip2VpnIp(v.IP)), Port: uint32(v.Port)}
|
||||||
}
|
}
|
||||||
|
|
||||||
b, err := req.Marshal()
|
b, err := req.Marshal()
|
||||||
|
@ -327,15 +331,15 @@ func newLHHostUpdate(fromAddr *udpAddr, vpnIp uint32, addrs []*udpAddr, lhh *Lig
|
||||||
//}
|
//}
|
||||||
|
|
||||||
func Test_ipMaskContains(t *testing.T) {
|
func Test_ipMaskContains(t *testing.T) {
|
||||||
assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.0.255"))))
|
assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.0.255"))))
|
||||||
assert.False(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.1.1"))))
|
assert.False(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
|
||||||
assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32, ip2int(net.ParseIP("10.0.1.1"))))
|
assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
|
||||||
}
|
}
|
||||||
|
|
||||||
type testLhReply struct {
|
type testLhReply struct {
|
||||||
nebType NebulaMessageType
|
nebType header.MessageType
|
||||||
nebSubType NebulaMessageSubType
|
nebSubType header.MessageSubType
|
||||||
vpnIp uint32
|
vpnIp iputil.VpnIp
|
||||||
msg *NebulaMeta
|
msg *NebulaMeta
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -343,7 +347,7 @@ type testEncWriter struct {
|
||||||
lastReply testLhReply
|
lastReply testLhReply
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tw *testEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, _, _ []byte) {
|
func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) {
|
||||||
tw.lastReply = testLhReply{
|
tw.lastReply = testLhReply{
|
||||||
nebType: t,
|
nebType: t,
|
||||||
nebSubType: st,
|
nebSubType: st,
|
||||||
|
@ -358,17 +362,17 @@ func (tw *testEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessag
|
||||||
}
|
}
|
||||||
|
|
||||||
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
|
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
|
||||||
func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udpAddr) {
|
func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) {
|
||||||
assert.Len(t, have, len(want))
|
assert.Len(t, have, len(want))
|
||||||
for k, w := range want {
|
for k, w := range want {
|
||||||
if !(have[k].Ip == ip2int(w.IP) && have[k].Port == uint32(w.Port)) {
|
if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == uint32(w.Port)) {
|
||||||
assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have)))
|
assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
|
// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
|
||||||
func assertUdpAddrInArray(t *testing.T, have []*udpAddr, want ...*udpAddr) {
|
func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) {
|
||||||
assert.Len(t, have, len(want))
|
assert.Len(t, have, len(want))
|
||||||
for k, w := range want {
|
for k, w := range want {
|
||||||
if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
|
if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
|
||||||
|
@ -377,8 +381,8 @@ func assertUdpAddrInArray(t *testing.T, have []*udpAddr, want ...*udpAddr) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func translateV4toUdpAddr(ips []*Ip4AndPort) []*udpAddr {
|
func translateV4toUdpAddr(ips []*Ip4AndPort) []*udp.Addr {
|
||||||
addrs := make([]*udpAddr, len(ips))
|
addrs := make([]*udp.Addr, len(ips))
|
||||||
for k, v := range ips {
|
for k, v := range ips {
|
||||||
addrs[k] = NewUDPAddrFromLH4(v)
|
addrs[k] = NewUDPAddrFromLH4(v)
|
||||||
}
|
}
|
||||||
|
|
39
logger.go
39
logger.go
|
@ -2,8 +2,12 @@ package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ContextualError struct {
|
type ContextualError struct {
|
||||||
|
@ -37,3 +41,38 @@ func (ce *ContextualError) Log(lr *logrus.Logger) {
|
||||||
lr.WithFields(ce.Fields).Error(ce.Context)
|
lr.WithFields(ce.Fields).Error(ce.Context)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func configLogger(l *logrus.Logger, c *config.C) error {
|
||||||
|
// set up our logging level
|
||||||
|
logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
|
||||||
|
}
|
||||||
|
l.SetLevel(logLevel)
|
||||||
|
|
||||||
|
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
|
||||||
|
timestampFormat := c.GetString("logging.timestamp_format", "")
|
||||||
|
fullTimestamp := (timestampFormat != "")
|
||||||
|
if timestampFormat == "" {
|
||||||
|
timestampFormat = time.RFC3339
|
||||||
|
}
|
||||||
|
|
||||||
|
logFormat := strings.ToLower(c.GetString("logging.format", "text"))
|
||||||
|
switch logFormat {
|
||||||
|
case "text":
|
||||||
|
l.Formatter = &logrus.TextFormatter{
|
||||||
|
TimestampFormat: timestampFormat,
|
||||||
|
FullTimestamp: fullTimestamp,
|
||||||
|
DisableTimestamp: disableTimestamp,
|
||||||
|
}
|
||||||
|
case "json":
|
||||||
|
l.Formatter = &logrus.JSONFormatter{
|
||||||
|
TimestampFormat: timestampFormat,
|
||||||
|
DisableTimestamp: disableTimestamp,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
139
main.go
139
main.go
|
@ -8,14 +8,16 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/slackhq/nebula/sshd"
|
"github.com/slackhq/nebula/sshd"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type m map[string]interface{}
|
type m map[string]interface{}
|
||||||
|
|
||||||
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) {
|
func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) {
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
|
// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -31,7 +33,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
|
|
||||||
// Print the config if in test, the exit comes later
|
// Print the config if in test, the exit comes later
|
||||||
if configTest {
|
if configTest {
|
||||||
b, err := yaml.Marshal(config.Settings)
|
b, err := yaml.Marshal(c.Settings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -40,33 +42,33 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
l.Println(string(b))
|
l.Println(string(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
err := configLogger(config)
|
err := configLogger(l, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Failed to configure the logger", nil, err)
|
return nil, NewContextualError("Failed to configure the logger", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config.RegisterReloadCallback(func(c *Config) {
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
err := configLogger(c)
|
err := configLogger(l, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("Failed to configure the logger")
|
l.WithError(err).Error("Failed to configure the logger")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
caPool, err := loadCAFromConfig(l, config)
|
caPool, err := loadCAFromConfig(l, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//The errors coming out of loadCA are already nicely formatted
|
//The errors coming out of loadCA are already nicely formatted
|
||||||
return nil, NewContextualError("Failed to load ca from config", nil, err)
|
return nil, NewContextualError("Failed to load ca from config", nil, err)
|
||||||
}
|
}
|
||||||
l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
|
l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
|
||||||
|
|
||||||
cs, err := NewCertStateFromConfig(config)
|
cs, err := NewCertStateFromConfig(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//The errors coming out of NewCertStateFromConfig are already nicely formatted
|
//The errors coming out of NewCertStateFromConfig are already nicely formatted
|
||||||
return nil, NewContextualError("Failed to load certificate from config", nil, err)
|
return nil, NewContextualError("Failed to load certificate from config", nil, err)
|
||||||
}
|
}
|
||||||
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
|
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
|
||||||
|
|
||||||
fw, err := NewFirewallFromConfig(l, cs.certificate, config)
|
fw, err := NewFirewallFromConfig(l, cs.certificate, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Error while loading firewall rules", nil, err)
|
return nil, NewContextualError("Error while loading firewall rules", nil, err)
|
||||||
}
|
}
|
||||||
|
@ -74,20 +76,20 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
|
|
||||||
// TODO: make sure mask is 4 bytes
|
// TODO: make sure mask is 4 bytes
|
||||||
tunCidr := cs.certificate.Details.Ips[0]
|
tunCidr := cs.certificate.Details.Ips[0]
|
||||||
routes, err := parseRoutes(config, tunCidr)
|
routes, err := parseRoutes(c, tunCidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Could not parse tun.routes", nil, err)
|
return nil, NewContextualError("Could not parse tun.routes", nil, err)
|
||||||
}
|
}
|
||||||
unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
|
unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
||||||
wireSSHReload(l, ssh, config)
|
wireSSHReload(l, ssh, c)
|
||||||
var sshStart func()
|
var sshStart func()
|
||||||
if config.GetBool("sshd.enabled", false) {
|
if c.GetBool("sshd.enabled", false) {
|
||||||
sshStart, err = configSSH(l, ssh, config)
|
sshStart, err = configSSH(l, ssh, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Error while configuring the sshd", nil, err)
|
return nil, NewContextualError("Error while configuring the sshd", nil, err)
|
||||||
}
|
}
|
||||||
|
@ -101,7 +103,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
var routines int
|
var routines int
|
||||||
|
|
||||||
// If `routines` is set, use that and ignore the specific values
|
// If `routines` is set, use that and ignore the specific values
|
||||||
if routines = config.GetInt("routines", 0); routines != 0 {
|
if routines = c.GetInt("routines", 0); routines != 0 {
|
||||||
if routines < 1 {
|
if routines < 1 {
|
||||||
routines = 1
|
routines = 1
|
||||||
}
|
}
|
||||||
|
@ -110,8 +112,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// deprecated and undocumented
|
// deprecated and undocumented
|
||||||
tunQueues := config.GetInt("tun.routines", 1)
|
tunQueues := c.GetInt("tun.routines", 1)
|
||||||
udpQueues := config.GetInt("listen.routines", 1)
|
udpQueues := c.GetInt("listen.routines", 1)
|
||||||
if tunQueues > udpQueues {
|
if tunQueues > udpQueues {
|
||||||
routines = tunQueues
|
routines = tunQueues
|
||||||
} else {
|
} else {
|
||||||
|
@ -125,8 +127,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
// EXPERIMENTAL
|
// EXPERIMENTAL
|
||||||
// Intentionally not documented yet while we do more testing and determine
|
// Intentionally not documented yet while we do more testing and determine
|
||||||
// a good default value.
|
// a good default value.
|
||||||
conntrackCacheTimeout := config.GetDuration("firewall.conntrack.routine_cache_timeout", 0)
|
conntrackCacheTimeout := c.GetDuration("firewall.conntrack.routine_cache_timeout", 0)
|
||||||
if routines > 1 && !config.IsSet("firewall.conntrack.routine_cache_timeout") {
|
if routines > 1 && !c.IsSet("firewall.conntrack.routine_cache_timeout") {
|
||||||
// Use a different default if we are running with multiple routines
|
// Use a different default if we are running with multiple routines
|
||||||
conntrackCacheTimeout = 1 * time.Second
|
conntrackCacheTimeout = 1 * time.Second
|
||||||
}
|
}
|
||||||
|
@ -136,30 +138,30 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
|
|
||||||
var tun Inside
|
var tun Inside
|
||||||
if !configTest {
|
if !configTest {
|
||||||
config.CatchHUP(ctx)
|
c.CatchHUP(ctx)
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case config.GetBool("tun.disabled", false):
|
case c.GetBool("tun.disabled", false):
|
||||||
tun = newDisabledTun(tunCidr, config.GetInt("tun.tx_queue", 500), config.GetBool("stats.message_metrics", false), l)
|
tun = newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
||||||
case tunFd != nil:
|
case tunFd != nil:
|
||||||
tun, err = newTunFromFd(
|
tun, err = newTunFromFd(
|
||||||
l,
|
l,
|
||||||
*tunFd,
|
*tunFd,
|
||||||
tunCidr,
|
tunCidr,
|
||||||
config.GetInt("tun.mtu", DEFAULT_MTU),
|
c.GetInt("tun.mtu", DEFAULT_MTU),
|
||||||
routes,
|
routes,
|
||||||
unsafeRoutes,
|
unsafeRoutes,
|
||||||
config.GetInt("tun.tx_queue", 500),
|
c.GetInt("tun.tx_queue", 500),
|
||||||
)
|
)
|
||||||
default:
|
default:
|
||||||
tun, err = newTun(
|
tun, err = newTun(
|
||||||
l,
|
l,
|
||||||
config.GetString("tun.dev", ""),
|
c.GetString("tun.dev", ""),
|
||||||
tunCidr,
|
tunCidr,
|
||||||
config.GetInt("tun.mtu", DEFAULT_MTU),
|
c.GetInt("tun.mtu", DEFAULT_MTU),
|
||||||
routes,
|
routes,
|
||||||
unsafeRoutes,
|
unsafeRoutes,
|
||||||
config.GetInt("tun.tx_queue", 500),
|
c.GetInt("tun.tx_queue", 500),
|
||||||
routines > 1,
|
routines > 1,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -176,16 +178,16 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// set up our UDP listener
|
// set up our UDP listener
|
||||||
udpConns := make([]*udpConn, routines)
|
udpConns := make([]*udp.Conn, routines)
|
||||||
port := config.GetInt("listen.port", 0)
|
port := c.GetInt("listen.port", 0)
|
||||||
|
|
||||||
if !configTest {
|
if !configTest {
|
||||||
for i := 0; i < routines; i++ {
|
for i := 0; i < routines; i++ {
|
||||||
udpServer, err := NewListener(l, config.GetString("listen.host", "0.0.0.0"), port, routines > 1)
|
udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||||
}
|
}
|
||||||
udpServer.reloadConfig(config)
|
udpServer.ReloadConfig(c)
|
||||||
udpConns[i] = udpServer
|
udpConns[i] = udpServer
|
||||||
|
|
||||||
// If port is dynamic, discover it
|
// If port is dynamic, discover it
|
||||||
|
@ -201,7 +203,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
|
|
||||||
// Set up my internal host map
|
// Set up my internal host map
|
||||||
var preferredRanges []*net.IPNet
|
var preferredRanges []*net.IPNet
|
||||||
rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{})
|
rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
|
||||||
// First, check if 'preferred_ranges' is set and fallback to 'local_range'
|
// First, check if 'preferred_ranges' is set and fallback to 'local_range'
|
||||||
if len(rawPreferredRanges) > 0 {
|
if len(rawPreferredRanges) > 0 {
|
||||||
for _, rawPreferredRange := range rawPreferredRanges {
|
for _, rawPreferredRange := range rawPreferredRanges {
|
||||||
|
@ -216,7 +218,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
// local_range was superseded by preferred_ranges. If it is still present,
|
// local_range was superseded by preferred_ranges. If it is still present,
|
||||||
// merge the local_range setting into preferred_ranges. We will probably
|
// merge the local_range setting into preferred_ranges. We will probably
|
||||||
// deprecate local_range and remove in the future.
|
// deprecate local_range and remove in the future.
|
||||||
rawLocalRange := config.GetString("local_range", "")
|
rawLocalRange := c.GetString("local_range", "")
|
||||||
if rawLocalRange != "" {
|
if rawLocalRange != "" {
|
||||||
_, localRange, err := net.ParseCIDR(rawLocalRange)
|
_, localRange, err := net.ParseCIDR(rawLocalRange)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -240,7 +242,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
|
hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
|
||||||
|
|
||||||
hostMap.addUnsafeRoutes(&unsafeRoutes)
|
hostMap.addUnsafeRoutes(&unsafeRoutes)
|
||||||
hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
|
hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)
|
||||||
|
|
||||||
l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
|
l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
|
||||||
|
|
||||||
|
@ -249,26 +251,26 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
go hostMap.Promoter(config.GetInt("promoter.interval"))
|
go hostMap.Promoter(config.GetInt("promoter.interval"))
|
||||||
*/
|
*/
|
||||||
|
|
||||||
punchy := NewPunchyFromConfig(config)
|
punchy := NewPunchyFromConfig(c)
|
||||||
if punchy.Punch && !configTest {
|
if punchy.Punch && !configTest {
|
||||||
l.Info("UDP hole punching enabled")
|
l.Info("UDP hole punching enabled")
|
||||||
go hostMap.Punchy(ctx, udpConns[0])
|
go hostMap.Punchy(ctx, udpConns[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
amLighthouse := config.GetBool("lighthouse.am_lighthouse", false)
|
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
|
||||||
|
|
||||||
// fatal if am_lighthouse is enabled but we are using an ephemeral port
|
// fatal if am_lighthouse is enabled but we are using an ephemeral port
|
||||||
if amLighthouse && (config.GetInt("listen.port", 0) == 0) {
|
if amLighthouse && (c.GetInt("listen.port", 0) == 0) {
|
||||||
return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
|
return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// warn if am_lighthouse is enabled but upstream lighthouses exists
|
// warn if am_lighthouse is enabled but upstream lighthouses exists
|
||||||
rawLighthouseHosts := config.GetStringSlice("lighthouse.hosts", []string{})
|
rawLighthouseHosts := c.GetStringSlice("lighthouse.hosts", []string{})
|
||||||
if amLighthouse && len(rawLighthouseHosts) != 0 {
|
if amLighthouse && len(rawLighthouseHosts) != 0 {
|
||||||
l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
|
l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
|
||||||
}
|
}
|
||||||
|
|
||||||
lighthouseHosts := make([]uint32, len(rawLighthouseHosts))
|
lighthouseHosts := make([]iputil.VpnIp, len(rawLighthouseHosts))
|
||||||
for i, host := range rawLighthouseHosts {
|
for i, host := range rawLighthouseHosts {
|
||||||
ip := net.ParseIP(host)
|
ip := net.ParseIP(host)
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
|
@ -277,7 +279,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
if !tunCidr.Contains(ip) {
|
if !tunCidr.Contains(ip) {
|
||||||
return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
|
return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
|
||||||
}
|
}
|
||||||
lighthouseHosts[i] = ip2int(ip)
|
lighthouseHosts[i] = iputil.Ip2VpnIp(ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
lightHouse := NewLightHouse(
|
lightHouse := NewLightHouse(
|
||||||
|
@ -286,47 +288,48 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
tunCidr,
|
tunCidr,
|
||||||
lighthouseHosts,
|
lighthouseHosts,
|
||||||
//TODO: change to a duration
|
//TODO: change to a duration
|
||||||
config.GetInt("lighthouse.interval", 10),
|
c.GetInt("lighthouse.interval", 10),
|
||||||
uint32(port),
|
uint32(port),
|
||||||
udpConns[0],
|
udpConns[0],
|
||||||
punchy.Respond,
|
punchy.Respond,
|
||||||
punchy.Delay,
|
punchy.Delay,
|
||||||
config.GetBool("stats.lighthouse_metrics", false),
|
c.GetBool("stats.lighthouse_metrics", false),
|
||||||
)
|
)
|
||||||
|
|
||||||
remoteAllowList, err := config.GetRemoteAllowList("lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges")
|
remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
|
return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
|
||||||
}
|
}
|
||||||
lightHouse.SetRemoteAllowList(remoteAllowList)
|
lightHouse.SetRemoteAllowList(remoteAllowList)
|
||||||
|
|
||||||
localAllowList, err := config.GetLocalAllowList("lighthouse.local_allow_list")
|
localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
|
return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
|
||||||
}
|
}
|
||||||
lightHouse.SetLocalAllowList(localAllowList)
|
lightHouse.SetLocalAllowList(localAllowList)
|
||||||
|
|
||||||
//TODO: Move all of this inside functions in lighthouse.go
|
//TODO: Move all of this inside functions in lighthouse.go
|
||||||
for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
|
for k, v := range c.GetMap("static_host_map", map[interface{}]interface{}{}) {
|
||||||
vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
|
ip := net.ParseIP(fmt.Sprintf("%v", k))
|
||||||
if !tunCidr.Contains(vpnIp) {
|
vpnIp := iputil.Ip2VpnIp(ip)
|
||||||
|
if !tunCidr.Contains(ip) {
|
||||||
return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
|
return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
|
||||||
}
|
}
|
||||||
vals, ok := v.([]interface{})
|
vals, ok := v.([]interface{})
|
||||||
if ok {
|
if ok {
|
||||||
for _, v := range vals {
|
for _, v := range vals {
|
||||||
ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v))
|
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||||
}
|
}
|
||||||
lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
|
lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v))
|
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||||
}
|
}
|
||||||
lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
|
lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -336,16 +339,16 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
}
|
}
|
||||||
|
|
||||||
var messageMetrics *MessageMetrics
|
var messageMetrics *MessageMetrics
|
||||||
if config.GetBool("stats.message_metrics", false) {
|
if c.GetBool("stats.message_metrics", false) {
|
||||||
messageMetrics = newMessageMetrics()
|
messageMetrics = newMessageMetrics()
|
||||||
} else {
|
} else {
|
||||||
messageMetrics = newMessageMetricsOnlyRecvError()
|
messageMetrics = newMessageMetricsOnlyRecvError()
|
||||||
}
|
}
|
||||||
|
|
||||||
handshakeConfig := HandshakeConfig{
|
handshakeConfig := HandshakeConfig{
|
||||||
tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
|
tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
|
||||||
retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries),
|
retries: c.GetInt("handshakes.retries", DefaultHandshakeRetries),
|
||||||
triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
|
triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
|
||||||
|
|
||||||
messageMetrics: messageMetrics,
|
messageMetrics: messageMetrics,
|
||||||
}
|
}
|
||||||
|
@ -358,36 +361,35 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
//handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{})
|
//handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{})
|
||||||
|
|
||||||
serveDns := false
|
serveDns := false
|
||||||
if config.GetBool("lighthouse.serve_dns", false) {
|
if c.GetBool("lighthouse.serve_dns", false) {
|
||||||
if config.GetBool("lighthouse.am_lighthouse", false) {
|
if c.GetBool("lighthouse.am_lighthouse", false) {
|
||||||
serveDns = true
|
serveDns = true
|
||||||
} else {
|
} else {
|
||||||
l.Warn("DNS server refusing to run because this host is not a lighthouse.")
|
l.Warn("DNS server refusing to run because this host is not a lighthouse.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
checkInterval := config.GetInt("timers.connection_alive_interval", 5)
|
checkInterval := c.GetInt("timers.connection_alive_interval", 5)
|
||||||
pendingDeletionInterval := config.GetInt("timers.pending_deletion_interval", 10)
|
pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
|
||||||
ifConfig := &InterfaceConfig{
|
ifConfig := &InterfaceConfig{
|
||||||
HostMap: hostMap,
|
HostMap: hostMap,
|
||||||
Inside: tun,
|
Inside: tun,
|
||||||
Outside: udpConns[0],
|
Outside: udpConns[0],
|
||||||
certState: cs,
|
certState: cs,
|
||||||
Cipher: config.GetString("cipher", "aes"),
|
Cipher: c.GetString("cipher", "aes"),
|
||||||
Firewall: fw,
|
Firewall: fw,
|
||||||
ServeDns: serveDns,
|
ServeDns: serveDns,
|
||||||
HandshakeManager: handshakeManager,
|
HandshakeManager: handshakeManager,
|
||||||
lightHouse: lightHouse,
|
lightHouse: lightHouse,
|
||||||
checkInterval: checkInterval,
|
checkInterval: checkInterval,
|
||||||
pendingDeletionInterval: pendingDeletionInterval,
|
pendingDeletionInterval: pendingDeletionInterval,
|
||||||
DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false),
|
DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false),
|
||||||
DropMulticast: config.GetBool("tun.drop_multicast", false),
|
DropMulticast: c.GetBool("tun.drop_multicast", false),
|
||||||
UDPBatchSize: config.GetInt("listen.batch", 64),
|
|
||||||
routines: routines,
|
routines: routines,
|
||||||
MessageMetrics: messageMetrics,
|
MessageMetrics: messageMetrics,
|
||||||
version: buildVersion,
|
version: buildVersion,
|
||||||
caPool: caPool,
|
caPool: caPool,
|
||||||
disconnectInvalid: config.GetBool("pki.disconnect_invalid", false),
|
disconnectInvalid: c.GetBool("pki.disconnect_invalid", false),
|
||||||
|
|
||||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||||
l: l,
|
l: l,
|
||||||
|
@ -413,7 +415,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
// I don't want to make this initial commit too far-reaching though
|
// I don't want to make this initial commit too far-reaching though
|
||||||
ifce.writers = udpConns
|
ifce.writers = udpConns
|
||||||
|
|
||||||
ifce.RegisterConfigChangeCallbacks(config)
|
ifce.RegisterConfigChangeCallbacks(c)
|
||||||
|
|
||||||
go handshakeManager.Run(ctx, ifce)
|
go handshakeManager.Run(ctx, ifce)
|
||||||
go lightHouse.LhUpdateWorker(ctx, ifce)
|
go lightHouse.LhUpdateWorker(ctx, ifce)
|
||||||
|
@ -421,7 +423,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
|
|
||||||
// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
|
// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
|
||||||
// a context so that they can exit when the context is Done.
|
// a context so that they can exit when the context is Done.
|
||||||
statsStart, err := startStats(l, config, buildVersion, configTest)
|
statsStart, err := startStats(l, c, buildVersion, configTest)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Failed to start stats emitter", nil, err)
|
return nil, NewContextualError("Failed to start stats emitter", nil, err)
|
||||||
}
|
}
|
||||||
|
@ -431,7 +434,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: check if we _should_ be emitting stats
|
//TODO: check if we _should_ be emitting stats
|
||||||
go ifce.emitStats(ctx, config.GetDuration("stats.interval", time.Second*10))
|
go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10))
|
||||||
|
|
||||||
attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
|
attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
|
||||||
|
|
||||||
|
@ -439,7 +442,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||||
var dnsStart func()
|
var dnsStart func()
|
||||||
if amLighthouse && serveDns {
|
if amLighthouse && serveDns {
|
||||||
l.Debugln("Starting dns server")
|
l.Debugln("Starting dns server")
|
||||||
dnsStart = dnsMain(l, hostMap, config)
|
dnsStart = dnsMain(l, hostMap, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Control{ifce, l, cancel, sshStart, statsStart, dnsStart}, nil
|
return &Control{ifce, l, cancel, sshStart, statsStart, dnsStart}, nil
|
||||||
|
|
|
@ -4,8 +4,11 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//TODO: this can probably move into the header package
|
||||||
|
|
||||||
type MessageMetrics struct {
|
type MessageMetrics struct {
|
||||||
rx [][]metrics.Counter
|
rx [][]metrics.Counter
|
||||||
tx [][]metrics.Counter
|
tx [][]metrics.Counter
|
||||||
|
@ -14,7 +17,7 @@ type MessageMetrics struct {
|
||||||
txUnknown metrics.Counter
|
txUnknown metrics.Counter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MessageMetrics) Rx(t NebulaMessageType, s NebulaMessageSubType, i int64) {
|
func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) {
|
||||||
if m != nil {
|
if m != nil {
|
||||||
if t >= 0 && int(t) < len(m.rx) && s >= 0 && int(s) < len(m.rx[t]) {
|
if t >= 0 && int(t) < len(m.rx) && s >= 0 && int(s) < len(m.rx[t]) {
|
||||||
m.rx[t][s].Inc(i)
|
m.rx[t][s].Inc(i)
|
||||||
|
@ -23,7 +26,7 @@ func (m *MessageMetrics) Rx(t NebulaMessageType, s NebulaMessageSubType, i int64
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func (m *MessageMetrics) Tx(t NebulaMessageType, s NebulaMessageSubType, i int64) {
|
func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int64) {
|
||||||
if m != nil {
|
if m != nil {
|
||||||
if t >= 0 && int(t) < len(m.tx) && s >= 0 && int(s) < len(m.tx[t]) {
|
if t >= 0 && int(t) < len(m.tx) && s >= 0 && int(s) < len(m.tx[t]) {
|
||||||
m.tx[t][s].Inc(i)
|
m.tx[t][s].Inc(i)
|
||||||
|
|
118
outside.go
118
outside.go
|
@ -10,6 +10,10 @@ import (
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/slackhq/nebula/firewall"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,8 +21,8 @@ const (
|
||||||
minFwPacketLen = 4
|
minFwPacketLen = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, lhh *LightHouseHandler, nb []byte, q int, localCache ConntrackCache) {
|
func (f *Interface) readOutsidePackets(addr *udp.Addr, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
err := header.Parse(packet)
|
err := h.Parse(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO: best if we return this and let caller log
|
// TODO: best if we return this and let caller log
|
||||||
// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
|
// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
|
||||||
|
@ -32,30 +36,30 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
|
||||||
//l.Error("in packet ", header, packet[HeaderLen:])
|
//l.Error("in packet ", header, packet[HeaderLen:])
|
||||||
|
|
||||||
// verify if we've seen this index before, otherwise respond to the handshake initiation
|
// verify if we've seen this index before, otherwise respond to the handshake initiation
|
||||||
hostinfo, err := f.hostMap.QueryIndex(header.RemoteIndex)
|
hostinfo, err := f.hostMap.QueryIndex(h.RemoteIndex)
|
||||||
|
|
||||||
var ci *ConnectionState
|
var ci *ConnectionState
|
||||||
if err == nil {
|
if err == nil {
|
||||||
ci = hostinfo.ConnectionState
|
ci = hostinfo.ConnectionState
|
||||||
}
|
}
|
||||||
|
|
||||||
switch header.Type {
|
switch h.Type {
|
||||||
case message:
|
case header.Message:
|
||||||
if !f.handleEncrypted(ci, addr, header) {
|
if !f.handleEncrypted(ci, addr, h) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb, q, localCache)
|
f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache)
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
// Fallthrough to the bottom to record incoming traffic
|
||||||
|
|
||||||
case lightHouse:
|
case header.LightHouse:
|
||||||
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
if !f.handleEncrypted(ci, addr, header) {
|
if !f.handleEncrypted(ci, addr, h) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
|
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
|
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("packet", packet).
|
WithField("packet", packet).
|
||||||
|
@ -66,17 +70,17 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lhh.HandleRequest(addr, hostinfo.hostId, d, f)
|
lhf(addr, hostinfo.vpnIp, d, f)
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
// Fallthrough to the bottom to record incoming traffic
|
||||||
|
|
||||||
case test:
|
case header.Test:
|
||||||
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
if !f.handleEncrypted(ci, addr, header) {
|
if !f.handleEncrypted(ci, addr, h) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
|
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
|
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("packet", packet).
|
WithField("packet", packet).
|
||||||
|
@ -87,11 +91,11 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if header.Subtype == testRequest {
|
if h.Subtype == header.TestRequest {
|
||||||
// This testRequest might be from TryPromoteBest, so we should roam
|
// This testRequest might be from TryPromoteBest, so we should roam
|
||||||
// to the new IP address before responding
|
// to the new IP address before responding
|
||||||
f.handleHostRoaming(hostinfo, addr)
|
f.handleHostRoaming(hostinfo, addr)
|
||||||
f.send(test, testReply, ci, hostinfo, hostinfo.remote, d, nb, out)
|
f.send(header.Test, header.TestReply, ci, hostinfo, hostinfo.remote, d, nb, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
// Fallthrough to the bottom to record incoming traffic
|
||||||
|
@ -99,19 +103,19 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
|
||||||
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
|
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
|
||||||
// are unauthenticated
|
// are unauthenticated
|
||||||
|
|
||||||
case handshake:
|
case header.Handshake:
|
||||||
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
HandleIncomingHandshake(f, addr, packet, header, hostinfo)
|
HandleIncomingHandshake(f, addr, packet, h, hostinfo)
|
||||||
return
|
return
|
||||||
|
|
||||||
case recvError:
|
case header.RecvError:
|
||||||
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
f.handleRecvError(addr, header)
|
f.handleRecvError(addr, h)
|
||||||
return
|
return
|
||||||
|
|
||||||
case closeTunnel:
|
case header.CloseTunnel:
|
||||||
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
if !f.handleEncrypted(ci, addr, header) {
|
if !f.handleEncrypted(ci, addr, h) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -122,22 +126,22 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
|
||||||
return
|
return
|
||||||
|
|
||||||
default:
|
default:
|
||||||
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr)
|
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.handleHostRoaming(hostinfo, addr)
|
f.handleHostRoaming(hostinfo, addr)
|
||||||
|
|
||||||
f.connectionManager.In(hostinfo.hostId)
|
f.connectionManager.In(hostinfo.vpnIp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
|
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
|
||||||
func (f *Interface) closeTunnel(hostInfo *HostInfo, hasHostMapLock bool) {
|
func (f *Interface) closeTunnel(hostInfo *HostInfo, hasHostMapLock bool) {
|
||||||
//TODO: this would be better as a single function in ConnectionManager that handled locks appropriately
|
//TODO: this would be better as a single function in ConnectionManager that handled locks appropriately
|
||||||
f.connectionManager.ClearIP(hostInfo.hostId)
|
f.connectionManager.ClearIP(hostInfo.vpnIp)
|
||||||
f.connectionManager.ClearPendingDeletion(hostInfo.hostId)
|
f.connectionManager.ClearPendingDeletion(hostInfo.vpnIp)
|
||||||
f.lightHouse.DeleteVpnIP(hostInfo.hostId)
|
f.lightHouse.DeleteVpnIp(hostInfo.vpnIp)
|
||||||
|
|
||||||
if hasHostMapLock {
|
if hasHostMapLock {
|
||||||
f.hostMap.unlockedDeleteHostInfo(hostInfo)
|
f.hostMap.unlockedDeleteHostInfo(hostInfo)
|
||||||
|
@ -148,12 +152,12 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo, hasHostMapLock bool) {
|
||||||
|
|
||||||
// sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote
|
// sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote
|
||||||
func (f *Interface) sendCloseTunnel(h *HostInfo) {
|
func (f *Interface) sendCloseTunnel(h *HostInfo) {
|
||||||
f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
f.send(header.CloseTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
|
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) {
|
||||||
if hostDidRoam(hostinfo.remote, addr) {
|
if !hostinfo.remote.Equals(addr) {
|
||||||
if !f.lightHouse.remoteAllowList.Allow(hostinfo.hostId, addr.IP) {
|
if !f.lightHouse.remoteAllowList.Allow(hostinfo.vpnIp, addr.IP) {
|
||||||
hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
|
hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -175,11 +179,11 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *Header) bool {
|
func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udp.Addr, h *header.H) bool {
|
||||||
// If connectionstate exists and the replay protector allows, process packet
|
// If connectionstate exists and the replay protector allows, process packet
|
||||||
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
|
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
|
||||||
if ci == nil || !ci.window.Check(f.l, header.MessageCounter) {
|
if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
|
||||||
f.sendRecvError(addr, header.RemoteIndex)
|
f.sendRecvError(addr, h.RemoteIndex)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -187,7 +191,7 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *
|
||||||
}
|
}
|
||||||
|
|
||||||
// newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
|
// newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
|
||||||
func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
|
func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
|
||||||
// Do we at least have an ipv4 header worth of data?
|
// Do we at least have an ipv4 header worth of data?
|
||||||
if len(data) < ipv4.HeaderLen {
|
if len(data) < ipv4.HeaderLen {
|
||||||
return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
|
return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
|
||||||
|
@ -215,7 +219,7 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
|
||||||
|
|
||||||
// Accounting for a variable header length, do we have enough data for our src/dst tuples?
|
// Accounting for a variable header length, do we have enough data for our src/dst tuples?
|
||||||
minLen := ihl
|
minLen := ihl
|
||||||
if !fp.Fragment && fp.Protocol != fwProtoICMP {
|
if !fp.Fragment && fp.Protocol != firewall.ProtoICMP {
|
||||||
minLen += minFwPacketLen
|
minLen += minFwPacketLen
|
||||||
}
|
}
|
||||||
if len(data) < minLen {
|
if len(data) < minLen {
|
||||||
|
@ -224,9 +228,9 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
|
||||||
|
|
||||||
// Firewall packets are locally oriented
|
// Firewall packets are locally oriented
|
||||||
if incoming {
|
if incoming {
|
||||||
fp.RemoteIP = binary.BigEndian.Uint32(data[12:16])
|
fp.RemoteIP = iputil.Ip2VpnIp(data[12:16])
|
||||||
fp.LocalIP = binary.BigEndian.Uint32(data[16:20])
|
fp.LocalIP = iputil.Ip2VpnIp(data[16:20])
|
||||||
if fp.Fragment || fp.Protocol == fwProtoICMP {
|
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
|
||||||
fp.RemotePort = 0
|
fp.RemotePort = 0
|
||||||
fp.LocalPort = 0
|
fp.LocalPort = 0
|
||||||
} else {
|
} else {
|
||||||
|
@ -234,9 +238,9 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
|
||||||
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
|
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
fp.LocalIP = binary.BigEndian.Uint32(data[12:16])
|
fp.LocalIP = iputil.Ip2VpnIp(data[12:16])
|
||||||
fp.RemoteIP = binary.BigEndian.Uint32(data[16:20])
|
fp.RemoteIP = iputil.Ip2VpnIp(data[16:20])
|
||||||
if fp.Fragment || fp.Protocol == fwProtoICMP {
|
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
|
||||||
fp.RemotePort = 0
|
fp.RemotePort = 0
|
||||||
fp.LocalPort = 0
|
fp.LocalPort = 0
|
||||||
} else {
|
} else {
|
||||||
|
@ -248,15 +252,15 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, header *Header, nb []byte) ([]byte, error) {
|
func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, h *header.H, nb []byte) ([]byte, error) {
|
||||||
var err error
|
var err error
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], mc, nb)
|
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], mc, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hostinfo.ConnectionState.window.Update(f.l, mc) {
|
if !hostinfo.ConnectionState.window.Update(f.l, mc) {
|
||||||
hostinfo.logger(f.l).WithField("header", header).
|
hostinfo.logger(f.l).WithField("header", h).
|
||||||
Debugln("dropping out of window packet")
|
Debugln("dropping out of window packet")
|
||||||
return nil, errors.New("out of window packet")
|
return nil, errors.New("out of window packet")
|
||||||
}
|
}
|
||||||
|
@ -264,10 +268,10 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte, q int, localCache ConntrackCache) {
|
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
|
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
||||||
//TODO: maybe after build 64 is out? 06/14/2018 - NB
|
//TODO: maybe after build 64 is out? 06/14/2018 - NB
|
||||||
|
@ -298,18 +302,18 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.connectionManager.In(hostinfo.hostId)
|
f.connectionManager.In(hostinfo.vpnIp)
|
||||||
_, err = f.readers[q].Write(out)
|
_, err = f.readers[q].Write(out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Failed to write to tun")
|
f.l.WithError(err).Error("Failed to write to tun")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
|
func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) {
|
||||||
f.messageMetrics.Tx(recvError, 0, 1)
|
f.messageMetrics.Tx(header.RecvError, 0, 1)
|
||||||
|
|
||||||
//TODO: this should be a signed message so we can trust that we should drop the index
|
//TODO: this should be a signed message so we can trust that we should drop the index
|
||||||
b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0)
|
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
|
||||||
f.outside.WriteTo(b, endpoint)
|
f.outside.WriteTo(b, endpoint)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithField("index", index).
|
f.l.WithField("index", index).
|
||||||
|
@ -318,7 +322,7 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
|
func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithField("index", h.RemoteIndex).
|
f.l.WithField("index", h.RemoteIndex).
|
||||||
WithField("udpAddr", addr).
|
WithField("udpAddr", addr).
|
||||||
|
|
|
@ -4,12 +4,14 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/firewall"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_newPacket(t *testing.T) {
|
func Test_newPacket(t *testing.T) {
|
||||||
p := &FirewallPacket{}
|
p := &firewall.Packet{}
|
||||||
|
|
||||||
// length fail
|
// length fail
|
||||||
err := newPacket([]byte{0, 1}, true, p)
|
err := newPacket([]byte{0, 1}, true, p)
|
||||||
|
@ -44,7 +46,7 @@ func Test_newPacket(t *testing.T) {
|
||||||
Src: net.IPv4(10, 0, 0, 1),
|
Src: net.IPv4(10, 0, 0, 1),
|
||||||
Dst: net.IPv4(10, 0, 0, 2),
|
Dst: net.IPv4(10, 0, 0, 2),
|
||||||
Options: []byte{0, 1, 0, 2},
|
Options: []byte{0, 1, 0, 2},
|
||||||
Protocol: fwProtoTCP,
|
Protocol: firewall.ProtoTCP,
|
||||||
}
|
}
|
||||||
|
|
||||||
b, _ = h.Marshal()
|
b, _ = h.Marshal()
|
||||||
|
@ -52,9 +54,9 @@ func Test_newPacket(t *testing.T) {
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
|
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, p.Protocol, uint8(fwProtoTCP))
|
assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
|
||||||
assert.Equal(t, p.LocalIP, ip2int(net.IPv4(10, 0, 0, 2)))
|
assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
|
||||||
assert.Equal(t, p.RemoteIP, ip2int(net.IPv4(10, 0, 0, 1)))
|
assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
|
||||||
assert.Equal(t, p.RemotePort, uint16(3))
|
assert.Equal(t, p.RemotePort, uint16(3))
|
||||||
assert.Equal(t, p.LocalPort, uint16(4))
|
assert.Equal(t, p.LocalPort, uint16(4))
|
||||||
|
|
||||||
|
@ -74,8 +76,8 @@ func Test_newPacket(t *testing.T) {
|
||||||
|
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, p.Protocol, uint8(2))
|
assert.Equal(t, p.Protocol, uint8(2))
|
||||||
assert.Equal(t, p.LocalIP, ip2int(net.IPv4(10, 0, 0, 1)))
|
assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
|
||||||
assert.Equal(t, p.RemoteIP, ip2int(net.IPv4(10, 0, 0, 2)))
|
assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
|
||||||
assert.Equal(t, p.RemotePort, uint16(6))
|
assert.Equal(t, p.RemotePort, uint16(6))
|
||||||
assert.Equal(t, p.LocalPort, uint16(5))
|
assert.Equal(t, p.LocalPort, uint16(5))
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
|
)
|
||||||
|
|
||||||
type Punchy struct {
|
type Punchy struct {
|
||||||
Punch bool
|
Punch bool
|
||||||
|
@ -8,7 +12,7 @@ type Punchy struct {
|
||||||
Delay time.Duration
|
Delay time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPunchyFromConfig(c *Config) *Punchy {
|
func NewPunchyFromConfig(c *config.C) *Punchy {
|
||||||
p := &Punchy{}
|
p := &Punchy{}
|
||||||
|
|
||||||
if c.IsSet("punchy.punch") {
|
if c.IsSet("punchy.punch") {
|
||||||
|
|
|
@ -4,12 +4,14 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewPunchyFromConfig(t *testing.T) {
|
func TestNewPunchyFromConfig(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
c := NewConfig(l)
|
c := config.NewC(l)
|
||||||
|
|
||||||
// Test defaults
|
// Test defaults
|
||||||
p := NewPunchyFromConfig(c)
|
p := NewPunchyFromConfig(c)
|
||||||
|
|
|
@ -5,14 +5,17 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// forEachFunc is used to benefit folks that want to do work inside the lock
|
// forEachFunc is used to benefit folks that want to do work inside the lock
|
||||||
type forEachFunc func(addr *udpAddr, preferred bool)
|
type forEachFunc func(addr *udp.Addr, preferred bool)
|
||||||
|
|
||||||
// The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
|
// The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
|
||||||
type checkFuncV4 func(vpnIp uint32, to *Ip4AndPort) bool
|
type checkFuncV4 func(vpnIp iputil.VpnIp, to *Ip4AndPort) bool
|
||||||
type checkFuncV6 func(vpnIp uint32, to *Ip6AndPort) bool
|
type checkFuncV6 func(vpnIp iputil.VpnIp, to *Ip6AndPort) bool
|
||||||
|
|
||||||
// CacheMap is a struct that better represents the lighthouse cache for humans
|
// CacheMap is a struct that better represents the lighthouse cache for humans
|
||||||
// The string key is the owners vpnIp
|
// The string key is the owners vpnIp
|
||||||
|
@ -21,8 +24,8 @@ type CacheMap map[string]*Cache
|
||||||
// Cache is the other part of CacheMap to better represent the lighthouse cache for humans
|
// Cache is the other part of CacheMap to better represent the lighthouse cache for humans
|
||||||
// We don't reason about ipv4 vs ipv6 here
|
// We don't reason about ipv4 vs ipv6 here
|
||||||
type Cache struct {
|
type Cache struct {
|
||||||
Learned []*udpAddr `json:"learned,omitempty"`
|
Learned []*udp.Addr `json:"learned,omitempty"`
|
||||||
Reported []*udpAddr `json:"reported,omitempty"`
|
Reported []*udp.Addr `json:"reported,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion
|
//TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion
|
||||||
|
@ -53,16 +56,16 @@ type RemoteList struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
|
|
||||||
// A deduplicated set of addresses. Any accessor should lock beforehand.
|
// A deduplicated set of addresses. Any accessor should lock beforehand.
|
||||||
addrs []*udpAddr
|
addrs []*udp.Addr
|
||||||
|
|
||||||
// These are maps to store v4 and v6 addresses per lighthouse
|
// These are maps to store v4 and v6 addresses per lighthouse
|
||||||
// Map key is the vpnIp of the person that told us about this the cached entries underneath.
|
// Map key is the vpnIp of the person that told us about this the cached entries underneath.
|
||||||
// For learned addresses, this is the vpnIp that sent the packet
|
// For learned addresses, this is the vpnIp that sent the packet
|
||||||
cache map[uint32]*cache
|
cache map[iputil.VpnIp]*cache
|
||||||
|
|
||||||
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
|
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
|
||||||
// They should not be tried again during a handshake
|
// They should not be tried again during a handshake
|
||||||
badRemotes []*udpAddr
|
badRemotes []*udp.Addr
|
||||||
|
|
||||||
// A flag that the cache may have changed and addrs needs to be rebuilt
|
// A flag that the cache may have changed and addrs needs to be rebuilt
|
||||||
shouldRebuild bool
|
shouldRebuild bool
|
||||||
|
@ -71,8 +74,8 @@ type RemoteList struct {
|
||||||
// NewRemoteList creates a new empty RemoteList
|
// NewRemoteList creates a new empty RemoteList
|
||||||
func NewRemoteList() *RemoteList {
|
func NewRemoteList() *RemoteList {
|
||||||
return &RemoteList{
|
return &RemoteList{
|
||||||
addrs: make([]*udpAddr, 0),
|
addrs: make([]*udp.Addr, 0),
|
||||||
cache: make(map[uint32]*cache),
|
cache: make(map[iputil.VpnIp]*cache),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,7 +101,7 @@ func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc)
|
||||||
|
|
||||||
// CopyAddrs locks and makes a deep copy of the deduplicated address list
|
// CopyAddrs locks and makes a deep copy of the deduplicated address list
|
||||||
// The deduplication work may need to occur here, so you must pass preferredRanges
|
// The deduplication work may need to occur here, so you must pass preferredRanges
|
||||||
func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr {
|
func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -107,7 +110,7 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr {
|
||||||
|
|
||||||
r.RLock()
|
r.RLock()
|
||||||
defer r.RUnlock()
|
defer r.RUnlock()
|
||||||
c := make([]*udpAddr, len(r.addrs))
|
c := make([]*udp.Addr, len(r.addrs))
|
||||||
for i, v := range r.addrs {
|
for i, v := range r.addrs {
|
||||||
c[i] = v.Copy()
|
c[i] = v.Copy()
|
||||||
}
|
}
|
||||||
|
@ -118,7 +121,7 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr {
|
||||||
// Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
|
// Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
|
||||||
// It will mark the deduplicated address list as dirty, so do not call it unless new information is available
|
// It will mark the deduplicated address list as dirty, so do not call it unless new information is available
|
||||||
//TODO: this needs to support the allow list list
|
//TODO: this needs to support the allow list list
|
||||||
func (r *RemoteList) LearnRemote(ownerVpnIp uint32, addr *udpAddr) {
|
func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) {
|
||||||
r.Lock()
|
r.Lock()
|
||||||
defer r.Unlock()
|
defer r.Unlock()
|
||||||
if v4 := addr.IP.To4(); v4 != nil {
|
if v4 := addr.IP.To4(); v4 != nil {
|
||||||
|
@ -139,8 +142,8 @@ func (r *RemoteList) CopyCache() *CacheMap {
|
||||||
c := cm[vpnIp]
|
c := cm[vpnIp]
|
||||||
if c == nil {
|
if c == nil {
|
||||||
c = &Cache{
|
c = &Cache{
|
||||||
Learned: make([]*udpAddr, 0),
|
Learned: make([]*udp.Addr, 0),
|
||||||
Reported: make([]*udpAddr, 0),
|
Reported: make([]*udp.Addr, 0),
|
||||||
}
|
}
|
||||||
cm[vpnIp] = c
|
cm[vpnIp] = c
|
||||||
}
|
}
|
||||||
|
@ -148,7 +151,7 @@ func (r *RemoteList) CopyCache() *CacheMap {
|
||||||
}
|
}
|
||||||
|
|
||||||
for owner, mc := range r.cache {
|
for owner, mc := range r.cache {
|
||||||
c := getOrMake(IntIp(owner).String())
|
c := getOrMake(owner.String())
|
||||||
|
|
||||||
if mc.v4 != nil {
|
if mc.v4 != nil {
|
||||||
if mc.v4.learned != nil {
|
if mc.v4.learned != nil {
|
||||||
|
@ -175,7 +178,7 @@ func (r *RemoteList) CopyCache() *CacheMap {
|
||||||
}
|
}
|
||||||
|
|
||||||
// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
|
// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
|
||||||
func (r *RemoteList) BlockRemote(bad *udpAddr) {
|
func (r *RemoteList) BlockRemote(bad *udp.Addr) {
|
||||||
r.Lock()
|
r.Lock()
|
||||||
defer r.Unlock()
|
defer r.Unlock()
|
||||||
|
|
||||||
|
@ -192,11 +195,11 @@ func (r *RemoteList) BlockRemote(bad *udpAddr) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list
|
// CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list
|
||||||
func (r *RemoteList) CopyBlockedRemotes() []*udpAddr {
|
func (r *RemoteList) CopyBlockedRemotes() []*udp.Addr {
|
||||||
r.RLock()
|
r.RLock()
|
||||||
defer r.RUnlock()
|
defer r.RUnlock()
|
||||||
|
|
||||||
c := make([]*udpAddr, len(r.badRemotes))
|
c := make([]*udp.Addr, len(r.badRemotes))
|
||||||
for i, v := range r.badRemotes {
|
for i, v := range r.badRemotes {
|
||||||
c[i] = v.Copy()
|
c[i] = v.Copy()
|
||||||
}
|
}
|
||||||
|
@ -228,7 +231,7 @@ func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
|
// unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
|
||||||
func (r *RemoteList) unlockedIsBad(remote *udpAddr) bool {
|
func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool {
|
||||||
for _, v := range r.badRemotes {
|
for _, v := range r.badRemotes {
|
||||||
if v.Equals(remote) {
|
if v.Equals(remote) {
|
||||||
return true
|
return true
|
||||||
|
@ -239,14 +242,14 @@ func (r *RemoteList) unlockedIsBad(remote *udpAddr) bool {
|
||||||
|
|
||||||
// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
|
// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
|
||||||
// deduplicated address list as dirty
|
// deduplicated address list as dirty
|
||||||
func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp uint32, to *Ip4AndPort) {
|
func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
|
||||||
r.shouldRebuild = true
|
r.shouldRebuild = true
|
||||||
r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
|
r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
|
||||||
}
|
}
|
||||||
|
|
||||||
// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
|
// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
|
||||||
// and marks the deduplicated address list as dirty
|
// and marks the deduplicated address list as dirty
|
||||||
func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, vpnIp uint32, to []*Ip4AndPort, check checkFuncV4) {
|
func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip4AndPort, check checkFuncV4) {
|
||||||
r.shouldRebuild = true
|
r.shouldRebuild = true
|
||||||
c := r.unlockedGetOrMakeV4(ownerVpnIp)
|
c := r.unlockedGetOrMakeV4(ownerVpnIp)
|
||||||
|
|
||||||
|
@ -263,7 +266,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, vpnIp uint32, to []*Ip4And
|
||||||
|
|
||||||
// unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
|
// unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
|
||||||
// This is only useful for establishing static hosts
|
// This is only useful for establishing static hosts
|
||||||
func (r *RemoteList) unlockedPrependV4(ownerVpnIp uint32, to *Ip4AndPort) {
|
func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
|
||||||
r.shouldRebuild = true
|
r.shouldRebuild = true
|
||||||
c := r.unlockedGetOrMakeV4(ownerVpnIp)
|
c := r.unlockedGetOrMakeV4(ownerVpnIp)
|
||||||
|
|
||||||
|
@ -276,14 +279,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp uint32, to *Ip4AndPort) {
|
||||||
|
|
||||||
// unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
|
// unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
|
||||||
// deduplicated address list as dirty
|
// deduplicated address list as dirty
|
||||||
func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp uint32, to *Ip6AndPort) {
|
func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
|
||||||
r.shouldRebuild = true
|
r.shouldRebuild = true
|
||||||
r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
|
r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
|
||||||
}
|
}
|
||||||
|
|
||||||
// unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
|
// unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
|
||||||
// and marks the deduplicated address list as dirty
|
// and marks the deduplicated address list as dirty
|
||||||
func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, vpnIp uint32, to []*Ip6AndPort, check checkFuncV6) {
|
func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip6AndPort, check checkFuncV6) {
|
||||||
r.shouldRebuild = true
|
r.shouldRebuild = true
|
||||||
c := r.unlockedGetOrMakeV6(ownerVpnIp)
|
c := r.unlockedGetOrMakeV6(ownerVpnIp)
|
||||||
|
|
||||||
|
@ -300,7 +303,7 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, vpnIp uint32, to []*Ip6And
|
||||||
|
|
||||||
// unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
|
// unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
|
||||||
// This is only useful for establishing static hosts
|
// This is only useful for establishing static hosts
|
||||||
func (r *RemoteList) unlockedPrependV6(ownerVpnIp uint32, to *Ip6AndPort) {
|
func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
|
||||||
r.shouldRebuild = true
|
r.shouldRebuild = true
|
||||||
c := r.unlockedGetOrMakeV6(ownerVpnIp)
|
c := r.unlockedGetOrMakeV6(ownerVpnIp)
|
||||||
|
|
||||||
|
@ -313,7 +316,7 @@ func (r *RemoteList) unlockedPrependV6(ownerVpnIp uint32, to *Ip6AndPort) {
|
||||||
|
|
||||||
// unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
|
// unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
|
||||||
// The caller must dirty the learned address cache if required
|
// The caller must dirty the learned address cache if required
|
||||||
func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp uint32) *cacheV4 {
|
func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 {
|
||||||
am := r.cache[ownerVpnIp]
|
am := r.cache[ownerVpnIp]
|
||||||
if am == nil {
|
if am == nil {
|
||||||
am = &cache{}
|
am = &cache{}
|
||||||
|
@ -328,7 +331,7 @@ func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp uint32) *cacheV4 {
|
||||||
|
|
||||||
// unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
|
// unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
|
||||||
// The caller must dirty the learned address cache if required
|
// The caller must dirty the learned address cache if required
|
||||||
func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp uint32) *cacheV6 {
|
func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp iputil.VpnIp) *cacheV6 {
|
||||||
am := r.cache[ownerVpnIp]
|
am := r.cache[ownerVpnIp]
|
||||||
if am == nil {
|
if am == nil {
|
||||||
am = &cache{}
|
am = &cache{}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,18 +14,18 @@ func TestRemoteList_Rebuild(t *testing.T) {
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
[]*Ip4AndPort{
|
[]*Ip4AndPort{
|
||||||
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is duped
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is duped
|
||||||
{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
|
||||||
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is duped
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is duped
|
||||||
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is duped
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is duped
|
||||||
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is a dupe
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is a dupe
|
||||||
{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
|
||||||
{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
|
||||||
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe
|
||||||
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // almost dupe of 0 with a diff port
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // almost dupe of 0 with a diff port
|
||||||
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is a dupe
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is a dupe
|
||||||
},
|
},
|
||||||
func(uint32, *Ip4AndPort) bool { return true },
|
func(iputil.VpnIp, *Ip4AndPort) bool { return true },
|
||||||
)
|
)
|
||||||
|
|
||||||
rl.unlockedSetV6(
|
rl.unlockedSetV6(
|
||||||
|
@ -37,7 +38,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
|
||||||
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
|
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
|
||||||
NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe
|
NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe
|
||||||
},
|
},
|
||||||
func(uint32, *Ip6AndPort) bool { return true },
|
func(iputil.VpnIp, *Ip6AndPort) bool { return true },
|
||||||
)
|
)
|
||||||
|
|
||||||
rl.Rebuild([]*net.IPNet{})
|
rl.Rebuild([]*net.IPNet{})
|
||||||
|
@ -106,16 +107,16 @@ func BenchmarkFullRebuild(b *testing.B) {
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
[]*Ip4AndPort{
|
[]*Ip4AndPort{
|
||||||
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
|
||||||
{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
|
||||||
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
|
||||||
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
|
||||||
{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
|
||||||
{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
|
||||||
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe
|
||||||
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
|
||||||
},
|
},
|
||||||
func(uint32, *Ip4AndPort) bool { return true },
|
func(iputil.VpnIp, *Ip4AndPort) bool { return true },
|
||||||
)
|
)
|
||||||
|
|
||||||
rl.unlockedSetV6(
|
rl.unlockedSetV6(
|
||||||
|
@ -127,7 +128,7 @@ func BenchmarkFullRebuild(b *testing.B) {
|
||||||
NewIp6AndPort(net.ParseIP("1:100::1"), 1),
|
NewIp6AndPort(net.ParseIP("1:100::1"), 1),
|
||||||
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
|
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
|
||||||
},
|
},
|
||||||
func(uint32, *Ip6AndPort) bool { return true },
|
func(iputil.VpnIp, *Ip6AndPort) bool { return true },
|
||||||
)
|
)
|
||||||
|
|
||||||
b.Run("no preferred", func(b *testing.B) {
|
b.Run("no preferred", func(b *testing.B) {
|
||||||
|
@ -171,16 +172,16 @@ func BenchmarkSortRebuild(b *testing.B) {
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
[]*Ip4AndPort{
|
[]*Ip4AndPort{
|
||||||
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
|
||||||
{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
|
||||||
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
|
||||||
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
|
||||||
{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
|
||||||
{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
|
||||||
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe
|
||||||
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
|
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
|
||||||
},
|
},
|
||||||
func(uint32, *Ip4AndPort) bool { return true },
|
func(iputil.VpnIp, *Ip4AndPort) bool { return true },
|
||||||
)
|
)
|
||||||
|
|
||||||
rl.unlockedSetV6(
|
rl.unlockedSetV6(
|
||||||
|
@ -192,7 +193,7 @@ func BenchmarkSortRebuild(b *testing.B) {
|
||||||
NewIp6AndPort(net.ParseIP("1:100::1"), 1),
|
NewIp6AndPort(net.ParseIP("1:100::1"), 1),
|
||||||
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
|
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
|
||||||
},
|
},
|
||||||
func(uint32, *Ip6AndPort) bool { return true },
|
func(iputil.VpnIp, *Ip6AndPort) bool { return true },
|
||||||
)
|
)
|
||||||
|
|
||||||
b.Run("no preferred", func(b *testing.B) {
|
b.Run("no preferred", func(b *testing.B) {
|
||||||
|
|
56
ssh.go
56
ssh.go
|
@ -15,7 +15,11 @@ import (
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/slackhq/nebula/sshd"
|
"github.com/slackhq/nebula/sshd"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
type sshListHostMapFlags struct {
|
type sshListHostMapFlags struct {
|
||||||
|
@ -45,8 +49,8 @@ type sshCreateTunnelFlags struct {
|
||||||
Address string
|
Address string
|
||||||
}
|
}
|
||||||
|
|
||||||
func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) {
|
func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
|
||||||
c.RegisterReloadCallback(func(c *Config) {
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
if c.GetBool("sshd.enabled", false) {
|
if c.GetBool("sshd.enabled", false) {
|
||||||
sshRun, err := configSSH(l, ssh, c)
|
sshRun, err := configSSH(l, ssh, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -66,7 +70,7 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) {
|
||||||
// updates the passed-in SSHServer. On success, it returns a function
|
// updates the passed-in SSHServer. On success, it returns a function
|
||||||
// that callers may invoke to run the configured ssh server. On
|
// that callers may invoke to run the configured ssh server. On
|
||||||
// failure, it returns nil, error.
|
// failure, it returns nil, error.
|
||||||
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) (func(), error) {
|
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
|
||||||
//TODO conntrack list
|
//TODO conntrack list
|
||||||
//TODO print firewall rules or hash?
|
//TODO print firewall rules or hash?
|
||||||
|
|
||||||
|
@ -351,7 +355,7 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
|
||||||
|
|
||||||
hm := listHostMap(hostMap)
|
hm := listHostMap(hostMap)
|
||||||
sort.Slice(hm, func(i, j int) bool {
|
sort.Slice(hm, func(i, j int) bool {
|
||||||
return bytes.Compare(hm[i].VpnIP, hm[j].VpnIP) < 0
|
return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0
|
||||||
})
|
})
|
||||||
|
|
||||||
if fs.Json || fs.Pretty {
|
if fs.Json || fs.Pretty {
|
||||||
|
@ -368,7 +372,7 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
for _, v := range hm {
|
for _, v := range hm {
|
||||||
err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, v.RemoteAddrs))
|
err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, v.RemoteAddrs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -386,7 +390,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
|
||||||
}
|
}
|
||||||
|
|
||||||
type lighthouseInfo struct {
|
type lighthouseInfo struct {
|
||||||
VpnIP net.IP `json:"vpnIp"`
|
VpnIp string `json:"vpnIp"`
|
||||||
Addrs *CacheMap `json:"addrs"`
|
Addrs *CacheMap `json:"addrs"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -395,7 +399,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
|
||||||
x := 0
|
x := 0
|
||||||
for k, v := range lightHouse.addrMap {
|
for k, v := range lightHouse.addrMap {
|
||||||
addrMap[x] = lighthouseInfo{
|
addrMap[x] = lighthouseInfo{
|
||||||
VpnIP: int2ip(k),
|
VpnIp: k.String(),
|
||||||
Addrs: v.CopyCache(),
|
Addrs: v.CopyCache(),
|
||||||
}
|
}
|
||||||
x++
|
x++
|
||||||
|
@ -403,7 +407,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
|
||||||
lightHouse.RUnlock()
|
lightHouse.RUnlock()
|
||||||
|
|
||||||
sort.Slice(addrMap, func(i, j int) bool {
|
sort.Slice(addrMap, func(i, j int) bool {
|
||||||
return bytes.Compare(addrMap[i].VpnIP, addrMap[j].VpnIP) < 0
|
return strings.Compare(addrMap[i].VpnIp, addrMap[j].VpnIp) < 0
|
||||||
})
|
})
|
||||||
|
|
||||||
if fs.Json || fs.Pretty {
|
if fs.Json || fs.Pretty {
|
||||||
|
@ -424,7 +428,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, string(b)))
|
err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, string(b)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -470,7 +474,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
|
||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnIp := ip2int(parsedIp)
|
vpnIp := iputil.Ip2VpnIp(parsedIp)
|
||||||
if vpnIp == 0 {
|
if vpnIp == 0 {
|
||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
|
@ -499,19 +503,19 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
|
||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnIp := ip2int(parsedIp)
|
vpnIp := iputil.Ip2VpnIp(parsedIp)
|
||||||
if vpnIp == 0 {
|
if vpnIp == 0 {
|
||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
|
hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
|
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
if !flags.LocalOnly {
|
if !flags.LocalOnly {
|
||||||
ifce.send(
|
ifce.send(
|
||||||
closeTunnel,
|
header.CloseTunnel,
|
||||||
0,
|
0,
|
||||||
hostInfo.ConnectionState,
|
hostInfo.ConnectionState,
|
||||||
hostInfo,
|
hostInfo,
|
||||||
|
@ -542,30 +546,30 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
|
||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnIp := ip2int(parsedIp)
|
vpnIp := iputil.Ip2VpnIp(parsedIp)
|
||||||
if vpnIp == 0 {
|
if vpnIp == 0 {
|
||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
hostInfo, _ := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
|
hostInfo, _ := ifce.hostMap.QueryVpnIp(vpnIp)
|
||||||
if hostInfo != nil {
|
if hostInfo != nil {
|
||||||
return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
|
return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
|
||||||
}
|
}
|
||||||
|
|
||||||
hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIP(uint32(vpnIp))
|
hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
|
||||||
if hostInfo != nil {
|
if hostInfo != nil {
|
||||||
return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
|
return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
|
||||||
}
|
}
|
||||||
|
|
||||||
var addr *udpAddr
|
var addr *udp.Addr
|
||||||
if flags.Address != "" {
|
if flags.Address != "" {
|
||||||
addr = NewUDPAddrFromString(flags.Address)
|
addr = udp.NewAddrFromString(flags.Address)
|
||||||
if addr == nil {
|
if addr == nil {
|
||||||
return w.WriteLine("Address could not be parsed")
|
return w.WriteLine("Address could not be parsed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hostInfo = ifce.handshakeManager.AddVpnIP(vpnIp)
|
hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp)
|
||||||
if addr != nil {
|
if addr != nil {
|
||||||
hostInfo.SetRemote(addr)
|
hostInfo.SetRemote(addr)
|
||||||
}
|
}
|
||||||
|
@ -589,7 +593,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
|
||||||
return w.WriteLine("No address was provided")
|
return w.WriteLine("No address was provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := NewUDPAddrFromString(flags.Address)
|
addr := udp.NewAddrFromString(flags.Address)
|
||||||
if addr == nil {
|
if addr == nil {
|
||||||
return w.WriteLine("Address could not be parsed")
|
return w.WriteLine("Address could not be parsed")
|
||||||
}
|
}
|
||||||
|
@ -599,12 +603,12 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
|
||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnIp := ip2int(parsedIp)
|
vpnIp := iputil.Ip2VpnIp(parsedIp)
|
||||||
if vpnIp == 0 {
|
if vpnIp == 0 {
|
||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
|
hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
|
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
|
||||||
}
|
}
|
||||||
|
@ -680,12 +684,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
|
||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnIp := ip2int(parsedIp)
|
vpnIp := iputil.Ip2VpnIp(parsedIp)
|
||||||
if vpnIp == 0 {
|
if vpnIp == 0 {
|
||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
|
hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
|
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
|
||||||
}
|
}
|
||||||
|
@ -742,12 +746,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
|
||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnIp := ip2int(parsedIp)
|
vpnIp := iputil.Ip2VpnIp(parsedIp)
|
||||||
if vpnIp == 0 {
|
if vpnIp == 0 {
|
||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
hostInfo, err := ifce.hostMap.QueryVpnIP(vpnIp)
|
hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
|
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
|
||||||
}
|
}
|
||||||
|
|
7
stats.go
7
stats.go
|
@ -15,12 +15,13 @@ import (
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
// startStats initializes stats from config. On success, if any futher work
|
// startStats initializes stats from config. On success, if any futher work
|
||||||
// is needed to serve stats, it returns a func to handle that work. If no
|
// is needed to serve stats, it returns a func to handle that work. If no
|
||||||
// work is needed, it'll return nil. On failure, it returns nil, error.
|
// work is needed, it'll return nil. On failure, it returns nil, error.
|
||||||
func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest bool) (func(), error) {
|
func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) {
|
||||||
mType := c.GetString("stats.type", "")
|
mType := c.GetString("stats.type", "")
|
||||||
if mType == "" || mType == "none" {
|
if mType == "" || mType == "none" {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -57,7 +58,7 @@ func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest boo
|
||||||
return startFn, nil
|
return startFn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error {
|
func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTest bool) error {
|
||||||
proto := c.GetString("stats.protocol", "tcp")
|
proto := c.GetString("stats.protocol", "tcp")
|
||||||
host := c.GetString("stats.host", "")
|
host := c.GetString("stats.host", "")
|
||||||
if host == "" {
|
if host == "" {
|
||||||
|
@ -77,7 +78,7 @@ func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVersion string, configTest bool) (func(), error) {
|
func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) {
|
||||||
namespace := c.GetString("stats.namespace", "")
|
namespace := c.GetString("stats.namespace", "")
|
||||||
subsystem := c.GetString("stats.subsystem", "")
|
subsystem := c.GetString("stats.subsystem", "")
|
||||||
|
|
||||||
|
|
12
timeout.go
12
timeout.go
|
@ -2,12 +2,14 @@ package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/firewall"
|
||||||
)
|
)
|
||||||
|
|
||||||
// How many timer objects should be cached
|
// How many timer objects should be cached
|
||||||
const timerCacheMax = 50000
|
const timerCacheMax = 50000
|
||||||
|
|
||||||
var emptyFWPacket = FirewallPacket{}
|
var emptyFWPacket = firewall.Packet{}
|
||||||
|
|
||||||
type TimerWheel struct {
|
type TimerWheel struct {
|
||||||
// Current tick
|
// Current tick
|
||||||
|
@ -42,7 +44,7 @@ type TimeoutList struct {
|
||||||
|
|
||||||
// Represents an item within a tick
|
// Represents an item within a tick
|
||||||
type TimeoutItem struct {
|
type TimeoutItem struct {
|
||||||
Packet FirewallPacket
|
Packet firewall.Packet
|
||||||
Next *TimeoutItem
|
Next *TimeoutItem
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,8 +75,8 @@ func NewTimerWheel(min, max time.Duration) *TimerWheel {
|
||||||
return &tw
|
return &tw
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add will add a FirewallPacket to the wheel in it's proper timeout
|
// Add will add a firewall.Packet to the wheel in it's proper timeout
|
||||||
func (tw *TimerWheel) Add(v FirewallPacket, timeout time.Duration) *TimeoutItem {
|
func (tw *TimerWheel) Add(v firewall.Packet, timeout time.Duration) *TimeoutItem {
|
||||||
// Check and see if we should progress the tick
|
// Check and see if we should progress the tick
|
||||||
tw.advance(time.Now())
|
tw.advance(time.Now())
|
||||||
|
|
||||||
|
@ -103,7 +105,7 @@ func (tw *TimerWheel) Add(v FirewallPacket, timeout time.Duration) *TimeoutItem
|
||||||
return ti
|
return ti
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tw *TimerWheel) Purge() (FirewallPacket, bool) {
|
func (tw *TimerWheel) Purge() (firewall.Packet, bool) {
|
||||||
if tw.expired.Head == nil {
|
if tw.expired.Head == nil {
|
||||||
return emptyFWPacket, false
|
return emptyFWPacket, false
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,8 @@ package nebula
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// How many timer objects should be cached
|
// How many timer objects should be cached
|
||||||
|
@ -43,7 +45,7 @@ type SystemTimeoutList struct {
|
||||||
|
|
||||||
// Represents an item within a tick
|
// Represents an item within a tick
|
||||||
type SystemTimeoutItem struct {
|
type SystemTimeoutItem struct {
|
||||||
Item uint32
|
Item iputil.VpnIp
|
||||||
Next *SystemTimeoutItem
|
Next *SystemTimeoutItem
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,7 +76,7 @@ func NewSystemTimerWheel(min, max time.Duration) *SystemTimerWheel {
|
||||||
return &tw
|
return &tw
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tw *SystemTimerWheel) Add(v uint32, timeout time.Duration) *SystemTimeoutItem {
|
func (tw *SystemTimerWheel) Add(v iputil.VpnIp, timeout time.Duration) *SystemTimeoutItem {
|
||||||
tw.lock.Lock()
|
tw.lock.Lock()
|
||||||
defer tw.lock.Unlock()
|
defer tw.lock.Unlock()
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -51,7 +52,7 @@ func TestSystemTimerWheel_findWheel(t *testing.T) {
|
||||||
func TestSystemTimerWheel_Add(t *testing.T) {
|
func TestSystemTimerWheel_Add(t *testing.T) {
|
||||||
tw := NewSystemTimerWheel(time.Second, time.Second*10)
|
tw := NewSystemTimerWheel(time.Second, time.Second*10)
|
||||||
|
|
||||||
fp1 := ip2int(net.ParseIP("1.2.3.4"))
|
fp1 := iputil.Ip2VpnIp(net.ParseIP("1.2.3.4"))
|
||||||
tw.Add(fp1, time.Second*1)
|
tw.Add(fp1, time.Second*1)
|
||||||
|
|
||||||
// Make sure we set head and tail properly
|
// Make sure we set head and tail properly
|
||||||
|
@ -62,7 +63,7 @@ func TestSystemTimerWheel_Add(t *testing.T) {
|
||||||
assert.Nil(t, tw.wheel[2].Tail.Next)
|
assert.Nil(t, tw.wheel[2].Tail.Next)
|
||||||
|
|
||||||
// Make sure we only modify head
|
// Make sure we only modify head
|
||||||
fp2 := ip2int(net.ParseIP("1.2.3.4"))
|
fp2 := iputil.Ip2VpnIp(net.ParseIP("1.2.3.4"))
|
||||||
tw.Add(fp2, time.Second*1)
|
tw.Add(fp2, time.Second*1)
|
||||||
assert.Equal(t, fp2, tw.wheel[2].Head.Item)
|
assert.Equal(t, fp2, tw.wheel[2].Head.Item)
|
||||||
assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item)
|
assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item)
|
||||||
|
@ -85,7 +86,7 @@ func TestSystemTimerWheel_Purge(t *testing.T) {
|
||||||
assert.NotNil(t, tw.lastTick)
|
assert.NotNil(t, tw.lastTick)
|
||||||
assert.Equal(t, 0, tw.current)
|
assert.Equal(t, 0, tw.current)
|
||||||
|
|
||||||
fps := []uint32{9, 10, 11, 12}
|
fps := []iputil.VpnIp{9, 10, 11, 12}
|
||||||
|
|
||||||
//fp1 := ip2int(net.ParseIP("1.2.3.4"))
|
//fp1 := ip2int(net.ParseIP("1.2.3.4"))
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -50,7 +51,7 @@ func TestTimerWheel_findWheel(t *testing.T) {
|
||||||
func TestTimerWheel_Add(t *testing.T) {
|
func TestTimerWheel_Add(t *testing.T) {
|
||||||
tw := NewTimerWheel(time.Second, time.Second*10)
|
tw := NewTimerWheel(time.Second, time.Second*10)
|
||||||
|
|
||||||
fp1 := FirewallPacket{}
|
fp1 := firewall.Packet{}
|
||||||
tw.Add(fp1, time.Second*1)
|
tw.Add(fp1, time.Second*1)
|
||||||
|
|
||||||
// Make sure we set head and tail properly
|
// Make sure we set head and tail properly
|
||||||
|
@ -61,7 +62,7 @@ func TestTimerWheel_Add(t *testing.T) {
|
||||||
assert.Nil(t, tw.wheel[2].Tail.Next)
|
assert.Nil(t, tw.wheel[2].Tail.Next)
|
||||||
|
|
||||||
// Make sure we only modify head
|
// Make sure we only modify head
|
||||||
fp2 := FirewallPacket{}
|
fp2 := firewall.Packet{}
|
||||||
tw.Add(fp2, time.Second*1)
|
tw.Add(fp2, time.Second*1)
|
||||||
assert.Equal(t, fp2, tw.wheel[2].Head.Packet)
|
assert.Equal(t, fp2, tw.wheel[2].Head.Packet)
|
||||||
assert.Equal(t, fp1, tw.wheel[2].Head.Next.Packet)
|
assert.Equal(t, fp1, tw.wheel[2].Head.Next.Packet)
|
||||||
|
@ -84,7 +85,7 @@ func TestTimerWheel_Purge(t *testing.T) {
|
||||||
assert.NotNil(t, tw.lastTick)
|
assert.NotNil(t, tw.lastTick)
|
||||||
assert.Equal(t, 0, tw.current)
|
assert.Equal(t, 0, tw.current)
|
||||||
|
|
||||||
fps := []FirewallPacket{
|
fps := []firewall.Packet{
|
||||||
{LocalIP: 1},
|
{LocalIP: 1},
|
||||||
{LocalIP: 2},
|
{LocalIP: 2},
|
||||||
{LocalIP: 3},
|
{LocalIP: 3},
|
||||||
|
|
|
@ -4,6 +4,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
const DEFAULT_MTU = 1300
|
const DEFAULT_MTU = 1300
|
||||||
|
@ -14,10 +16,10 @@ type route struct {
|
||||||
via *net.IP
|
via *net.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseRoutes(config *Config, network *net.IPNet) ([]route, error) {
|
func parseRoutes(c *config.C, network *net.IPNet) ([]route, error) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
r := config.Get("tun.routes")
|
r := c.Get("tun.routes")
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return []route{}, nil
|
return []route{}, nil
|
||||||
}
|
}
|
||||||
|
@ -84,10 +86,10 @@ func parseRoutes(config *Config, network *net.IPNet) ([]route, error) {
|
||||||
return routes, nil
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) {
|
func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]route, error) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
r := config.Get("tun.unsafe_routes")
|
r := c.Get("tun.unsafe_routes")
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return []route{}, nil
|
return []route{}, nil
|
||||||
}
|
}
|
||||||
|
@ -110,7 +112,7 @@ func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) {
|
||||||
|
|
||||||
rMtu, ok := m["mtu"]
|
rMtu, ok := m["mtu"]
|
||||||
if !ok {
|
if !ok {
|
||||||
rMtu = config.GetInt("tun.mtu", DEFAULT_MTU)
|
rMtu = c.GetInt("tun.mtu", DEFAULT_MTU)
|
||||||
}
|
}
|
||||||
|
|
||||||
mtu, ok := rMtu.(int)
|
mtu, ok := rMtu.(int)
|
||||||
|
|
10
tun_test.go
10
tun_test.go
|
@ -5,12 +5,14 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_parseRoutes(t *testing.T) {
|
func Test_parseRoutes(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
c := NewConfig(l)
|
c := config.NewC(l)
|
||||||
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
||||||
|
|
||||||
// test no routes config
|
// test no routes config
|
||||||
|
@ -105,8 +107,8 @@ func Test_parseRoutes(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_parseUnsafeRoutes(t *testing.T) {
|
func Test_parseUnsafeRoutes(t *testing.T) {
|
||||||
l := NewTestLogger()
|
l := util.NewTestLogger()
|
||||||
c := NewConfig(l)
|
c := config.NewC(l)
|
||||||
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
||||||
|
|
||||||
// test no routes config
|
// test no routes config
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/slackhq/nebula/firewall"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
)
|
||||||
|
|
||||||
|
const MTU = 9001
|
||||||
|
|
||||||
|
type EncReader func(
|
||||||
|
addr *Addr,
|
||||||
|
out []byte,
|
||||||
|
packet []byte,
|
||||||
|
header *header.H,
|
||||||
|
fwPacket *firewall.Packet,
|
||||||
|
lhh LightHouseHandlerFunc,
|
||||||
|
nb []byte,
|
||||||
|
q int,
|
||||||
|
localCache firewall.ConntrackCache,
|
||||||
|
)
|
|
@ -0,0 +1,14 @@
|
||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
)
|
||||||
|
|
||||||
|
//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
|
||||||
|
|
||||||
|
type EncWriter interface {
|
||||||
|
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
|
||||||
|
}
|
||||||
|
|
||||||
|
type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter)
|
|
@ -1,4 +1,4 @@
|
||||||
package nebula
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
@ -7,32 +7,34 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
type udpAddr struct {
|
type m map[string]interface{}
|
||||||
|
|
||||||
|
type Addr struct {
|
||||||
IP net.IP
|
IP net.IP
|
||||||
Port uint16
|
Port uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUDPAddr(ip net.IP, port uint16) *udpAddr {
|
func NewAddr(ip net.IP, port uint16) *Addr {
|
||||||
addr := udpAddr{IP: make([]byte, net.IPv6len), Port: port}
|
addr := Addr{IP: make([]byte, net.IPv6len), Port: port}
|
||||||
copy(addr.IP, ip.To16())
|
copy(addr.IP, ip.To16())
|
||||||
return &addr
|
return &addr
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUDPAddrFromString(s string) *udpAddr {
|
func NewAddrFromString(s string) *Addr {
|
||||||
ip, port, err := parseIPAndPort(s)
|
ip, port, err := ParseIPAndPort(s)
|
||||||
//TODO: handle err
|
//TODO: handle err
|
||||||
_ = err
|
_ = err
|
||||||
return &udpAddr{IP: ip.To16(), Port: port}
|
return &Addr{IP: ip.To16(), Port: port}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ua *udpAddr) Equals(t *udpAddr) bool {
|
func (ua *Addr) Equals(t *Addr) bool {
|
||||||
if t == nil || ua == nil {
|
if t == nil || ua == nil {
|
||||||
return t == nil && ua == nil
|
return t == nil && ua == nil
|
||||||
}
|
}
|
||||||
return ua.IP.Equal(t.IP) && ua.Port == t.Port
|
return ua.IP.Equal(t.IP) && ua.Port == t.Port
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ua *udpAddr) String() string {
|
func (ua *Addr) String() string {
|
||||||
if ua == nil {
|
if ua == nil {
|
||||||
return "<nil>"
|
return "<nil>"
|
||||||
}
|
}
|
||||||
|
@ -40,7 +42,7 @@ func (ua *udpAddr) String() string {
|
||||||
return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port))
|
return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ua *udpAddr) MarshalJSON() ([]byte, error) {
|
func (ua *Addr) MarshalJSON() ([]byte, error) {
|
||||||
if ua == nil {
|
if ua == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -48,12 +50,12 @@ func (ua *udpAddr) MarshalJSON() ([]byte, error) {
|
||||||
return json.Marshal(m{"ip": ua.IP, "port": ua.Port})
|
return json.Marshal(m{"ip": ua.IP, "port": ua.Port})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ua *udpAddr) Copy() *udpAddr {
|
func (ua *Addr) Copy() *Addr {
|
||||||
if ua == nil {
|
if ua == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
nu := udpAddr{
|
nu := Addr{
|
||||||
Port: ua.Port,
|
Port: ua.Port,
|
||||||
IP: make(net.IP, len(ua.IP)),
|
IP: make(net.IP, len(ua.IP)),
|
||||||
}
|
}
|
||||||
|
@ -62,7 +64,7 @@ func (ua *udpAddr) Copy() *udpAddr {
|
||||||
return &nu
|
return &nu
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseIPAndPort(s string) (net.IP, uint16, error) {
|
func ParseIPAndPort(s string) (net.IP, uint16, error) {
|
||||||
rIp, sPort, err := net.SplitHostPort(s)
|
rIp, sPort, err := net.SplitHostPort(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
|
@ -1,7 +1,7 @@
|
||||||
//go:build !e2e_testing
|
//go:build !e2e_testing
|
||||||
// +build !e2e_testing
|
// +build !e2e_testing
|
||||||
|
|
||||||
package nebula
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -34,6 +34,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) Rebind() error {
|
func (u *Conn) Rebind() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
|
@ -1,7 +1,7 @@
|
||||||
//go:build !e2e_testing
|
//go:build !e2e_testing
|
||||||
// +build !e2e_testing
|
// +build !e2e_testing
|
||||||
|
|
||||||
package nebula
|
package udp
|
||||||
|
|
||||||
// Darwin support is primarily implemented in udp_generic, besides NewListenConfig
|
// Darwin support is primarily implemented in udp_generic, besides NewListenConfig
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ func NewListenConfig(multi bool) net.ListenConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) Rebind() error {
|
func (u *Conn) Rebind() error {
|
||||||
file, err := u.File()
|
file, err := u.File()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
|
@ -1,7 +1,7 @@
|
||||||
//go:build !e2e_testing
|
//go:build !e2e_testing
|
||||||
// +build !e2e_testing
|
// +build !e2e_testing
|
||||||
|
|
||||||
package nebula
|
package udp
|
||||||
|
|
||||||
// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig
|
// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig
|
||||||
|
|
||||||
|
@ -36,6 +36,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) Rebind() error {
|
func (u *Conn) Rebind() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
|
@ -5,7 +5,7 @@
|
||||||
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
|
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
|
||||||
// means it can be used on platforms like Darwin and Windows.
|
// means it can be used on platforms like Darwin and Windows.
|
||||||
|
|
||||||
package nebula
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -13,36 +13,39 @@ import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/firewall"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
)
|
)
|
||||||
|
|
||||||
type udpConn struct {
|
type Conn struct {
|
||||||
*net.UDPConn
|
*net.UDPConn
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
|
func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) {
|
||||||
lc := NewListenConfig(multi)
|
lc := NewListenConfig(multi)
|
||||||
pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port))
|
pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if uc, ok := pc.(*net.UDPConn); ok {
|
if uc, ok := pc.(*net.UDPConn); ok {
|
||||||
return &udpConn{UDPConn: uc, l: l}, nil
|
return &Conn{UDPConn: uc, l: l}, nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
|
return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error {
|
func (uc *Conn) WriteTo(b []byte, addr *Addr) error {
|
||||||
_, err := uc.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)})
|
_, err := uc.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uc *udpConn) LocalAddr() (*udpAddr, error) {
|
func (uc *Conn) LocalAddr() (*Addr, error) {
|
||||||
a := uc.UDPConn.LocalAddr()
|
a := uc.UDPConn.LocalAddr()
|
||||||
|
|
||||||
switch v := a.(type) {
|
switch v := a.(type) {
|
||||||
case *net.UDPAddr:
|
case *net.UDPAddr:
|
||||||
addr := &udpAddr{IP: make([]byte, len(v.IP))}
|
addr := &Addr{IP: make([]byte, len(v.IP))}
|
||||||
copy(addr.IP, v.IP)
|
copy(addr.IP, v.IP)
|
||||||
addr.Port = uint16(v.Port)
|
addr.Port = uint16(v.Port)
|
||||||
return addr, nil
|
return addr, nil
|
||||||
|
@ -52,11 +55,11 @@ func (uc *udpConn) LocalAddr() (*udpAddr, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) reloadConfig(c *Config) {
|
func (u *Conn) ReloadConfig(c *config.C) {
|
||||||
// TODO
|
// TODO
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUDPStatsEmitter(udpConns []*udpConn) func() {
|
func NewUDPStatsEmitter(udpConns []*Conn) func() {
|
||||||
// No UDP stats for non-linux
|
// No UDP stats for non-linux
|
||||||
return func() {}
|
return func() {}
|
||||||
}
|
}
|
||||||
|
@ -65,32 +68,24 @@ type rawMessage struct {
|
||||||
Len uint32
|
Len uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) ListenOut(f *Interface, q int) {
|
func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
|
||||||
plaintext := make([]byte, mtu)
|
plaintext := make([]byte, MTU)
|
||||||
buffer := make([]byte, mtu)
|
buffer := make([]byte, MTU)
|
||||||
header := &Header{}
|
h := &header.H{}
|
||||||
fwPacket := &FirewallPacket{}
|
fwPacket := &firewall.Packet{}
|
||||||
udpAddr := &udpAddr{IP: make([]byte, 16)}
|
udpAddr := &Addr{IP: make([]byte, 16)}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
lhh := f.lightHouse.NewRequestHandler()
|
|
||||||
|
|
||||||
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// Just read one packet at a time
|
// Just read one packet at a time
|
||||||
n, rua, err := u.ReadFromUDP(buffer)
|
n, rua, err := u.ReadFromUDP(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Failed to read packets")
|
u.l.WithError(err).Error("Failed to read packets")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
udpAddr.IP = rua.IP
|
udpAddr.IP = rua.IP
|
||||||
udpAddr.Port = uint16(rua.Port)
|
udpAddr.Port = uint16(rua.Port)
|
||||||
f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get(f.l))
|
r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {
|
|
||||||
return !addr.Equals(newaddr)
|
|
||||||
}
|
|
|
@ -1,7 +1,7 @@
|
||||||
//go:build !android && !e2e_testing
|
//go:build !android && !e2e_testing
|
||||||
// +build !android,!e2e_testing
|
// +build !android,!e2e_testing
|
||||||
|
|
||||||
package nebula
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
@ -12,14 +12,18 @@ import (
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/firewall"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
//TODO: make it support reload as best you can!
|
//TODO: make it support reload as best you can!
|
||||||
|
|
||||||
type udpConn struct {
|
type Conn struct {
|
||||||
sysFd int
|
sysFd int
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
batch int
|
||||||
}
|
}
|
||||||
|
|
||||||
var x int
|
var x int
|
||||||
|
@ -41,7 +45,7 @@ const (
|
||||||
|
|
||||||
type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
|
type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
|
func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) {
|
||||||
syscall.ForkLock.RLock()
|
syscall.ForkLock.RLock()
|
||||||
fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
|
fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -73,36 +77,36 @@ func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, e
|
||||||
//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
|
//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
|
||||||
//l.Println(v, err)
|
//l.Println(v, err)
|
||||||
|
|
||||||
return &udpConn{sysFd: fd, l: l}, err
|
return &Conn{sysFd: fd, l: l, batch: batch}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) Rebind() error {
|
func (u *Conn) Rebind() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) SetRecvBuffer(n int) error {
|
func (u *Conn) SetRecvBuffer(n int) error {
|
||||||
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
|
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) SetSendBuffer(n int) error {
|
func (u *Conn) SetSendBuffer(n int) error {
|
||||||
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n)
|
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) GetRecvBuffer() (int, error) {
|
func (u *Conn) GetRecvBuffer() (int, error) {
|
||||||
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF)
|
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) GetSendBuffer() (int, error) {
|
func (u *Conn) GetSendBuffer() (int, error) {
|
||||||
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
|
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) LocalAddr() (*udpAddr, error) {
|
func (u *Conn) LocalAddr() (*Addr, error) {
|
||||||
sa, err := unix.Getsockname(u.sysFd)
|
sa, err := unix.Getsockname(u.sysFd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := &udpAddr{}
|
addr := &Addr{}
|
||||||
switch sa := sa.(type) {
|
switch sa := sa.(type) {
|
||||||
case *unix.SockaddrInet4:
|
case *unix.SockaddrInet4:
|
||||||
addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16()
|
addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16()
|
||||||
|
@ -115,25 +119,21 @@ func (u *udpConn) LocalAddr() (*udpAddr, error) {
|
||||||
return addr, nil
|
return addr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) ListenOut(f *Interface, q int) {
|
func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
|
||||||
plaintext := make([]byte, mtu)
|
plaintext := make([]byte, MTU)
|
||||||
header := &Header{}
|
h := &header.H{}
|
||||||
fwPacket := &FirewallPacket{}
|
fwPacket := &firewall.Packet{}
|
||||||
udpAddr := &udpAddr{}
|
udpAddr := &Addr{}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
lhh := f.lightHouse.NewRequestHandler()
|
|
||||||
|
|
||||||
//TODO: should we track this?
|
//TODO: should we track this?
|
||||||
//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
|
//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
|
||||||
msgs, buffers, names := u.PrepareRawMessages(f.udpBatchSize)
|
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
||||||
read := u.ReadMulti
|
read := u.ReadMulti
|
||||||
if f.udpBatchSize == 1 {
|
if u.batch == 1 {
|
||||||
read = u.ReadSingle
|
read = u.ReadSingle
|
||||||
}
|
}
|
||||||
|
|
||||||
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, err := read(msgs)
|
n, err := read(msgs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -145,12 +145,12 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
udpAddr.IP = names[i][8:24]
|
udpAddr.IP = names[i][8:24]
|
||||||
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
|
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
|
||||||
f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l))
|
r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) ReadSingle(msgs []rawMessage) (int, error) {
|
func (u *Conn) ReadSingle(msgs []rawMessage) (int, error) {
|
||||||
for {
|
for {
|
||||||
n, _, err := unix.Syscall6(
|
n, _, err := unix.Syscall6(
|
||||||
unix.SYS_RECVMSG,
|
unix.SYS_RECVMSG,
|
||||||
|
@ -171,7 +171,7 @@ func (u *udpConn) ReadSingle(msgs []rawMessage) (int, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) ReadMulti(msgs []rawMessage) (int, error) {
|
func (u *Conn) ReadMulti(msgs []rawMessage) (int, error) {
|
||||||
for {
|
for {
|
||||||
n, _, err := unix.Syscall6(
|
n, _, err := unix.Syscall6(
|
||||||
unix.SYS_RECVMMSG,
|
unix.SYS_RECVMMSG,
|
||||||
|
@ -191,7 +191,7 @@ func (u *udpConn) ReadMulti(msgs []rawMessage) (int, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error {
|
func (u *Conn) WriteTo(b []byte, addr *Addr) error {
|
||||||
|
|
||||||
var rsa unix.RawSockaddrInet6
|
var rsa unix.RawSockaddrInet6
|
||||||
rsa.Family = unix.AF_INET6
|
rsa.Family = unix.AF_INET6
|
||||||
|
@ -221,7 +221,7 @@ func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) reloadConfig(c *Config) {
|
func (u *Conn) ReloadConfig(c *config.C) {
|
||||||
b := c.GetInt("listen.read_buffer", 0)
|
b := c.GetInt("listen.read_buffer", 0)
|
||||||
if b > 0 {
|
if b > 0 {
|
||||||
err := u.SetRecvBuffer(b)
|
err := u.SetRecvBuffer(b)
|
||||||
|
@ -253,7 +253,7 @@ func (u *udpConn) reloadConfig(c *Config) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) getMemInfo(meminfo *_SK_MEMINFO) error {
|
func (u *Conn) getMemInfo(meminfo *_SK_MEMINFO) error {
|
||||||
var vallen uint32 = 4 * _SK_MEMINFO_VARS
|
var vallen uint32 = 4 * _SK_MEMINFO_VARS
|
||||||
_, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)
|
_, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)
|
||||||
if err != 0 {
|
if err != 0 {
|
||||||
|
@ -262,7 +262,7 @@ func (u *udpConn) getMemInfo(meminfo *_SK_MEMINFO) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUDPStatsEmitter(udpConns []*udpConn) func() {
|
func NewUDPStatsEmitter(udpConns []*Conn) func() {
|
||||||
// Check if our kernel supports SO_MEMINFO before registering the gauges
|
// Check if our kernel supports SO_MEMINFO before registering the gauges
|
||||||
var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge
|
var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge
|
||||||
var meminfo _SK_MEMINFO
|
var meminfo _SK_MEMINFO
|
||||||
|
@ -293,7 +293,3 @@ func NewUDPStatsEmitter(udpConns []*udpConn) func() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {
|
|
||||||
return !addr.Equals(newaddr)
|
|
||||||
}
|
|
|
@ -4,7 +4,7 @@
|
||||||
// +build !android
|
// +build !android
|
||||||
// +build !e2e_testing
|
// +build !e2e_testing
|
||||||
|
|
||||||
package nebula
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
@ -30,13 +30,13 @@ type rawMessage struct {
|
||||||
Len uint32
|
Len uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||||
msgs := make([]rawMessage, n)
|
msgs := make([]rawMessage, n)
|
||||||
buffers := make([][]byte, n)
|
buffers := make([][]byte, n)
|
||||||
names := make([][]byte, n)
|
names := make([][]byte, n)
|
||||||
|
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
buffers[i] = make([]byte, mtu)
|
buffers[i] = make([]byte, MTU)
|
||||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
//TODO: this is still silly, no need for an array
|
//TODO: this is still silly, no need for an array
|
|
@ -4,7 +4,7 @@
|
||||||
// +build !android
|
// +build !android
|
||||||
// +build !e2e_testing
|
// +build !e2e_testing
|
||||||
|
|
||||||
package nebula
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
@ -33,13 +33,13 @@ type rawMessage struct {
|
||||||
Pad0 [4]byte
|
Pad0 [4]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||||
msgs := make([]rawMessage, n)
|
msgs := make([]rawMessage, n)
|
||||||
buffers := make([][]byte, n)
|
buffers := make([][]byte, n)
|
||||||
names := make([][]byte, n)
|
names := make([][]byte, n)
|
||||||
|
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
buffers[i] = make([]byte, mtu)
|
buffers[i] = make([]byte, MTU)
|
||||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
//TODO: this is still silly, no need for an array
|
//TODO: this is still silly, no need for an array
|
|
@ -1,16 +1,19 @@
|
||||||
//go:build e2e_testing
|
//go:build e2e_testing
|
||||||
// +build e2e_testing
|
// +build e2e_testing
|
||||||
|
|
||||||
package nebula
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/firewall"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UdpPacket struct {
|
type Packet struct {
|
||||||
ToIp net.IP
|
ToIp net.IP
|
||||||
ToPort uint16
|
ToPort uint16
|
||||||
FromIp net.IP
|
FromIp net.IP
|
||||||
|
@ -18,8 +21,8 @@ type UdpPacket struct {
|
||||||
Data []byte
|
Data []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *UdpPacket) Copy() *UdpPacket {
|
func (u *Packet) Copy() *Packet {
|
||||||
n := &UdpPacket{
|
n := &Packet{
|
||||||
ToIp: make(net.IP, len(u.ToIp)),
|
ToIp: make(net.IP, len(u.ToIp)),
|
||||||
ToPort: u.ToPort,
|
ToPort: u.ToPort,
|
||||||
FromIp: make(net.IP, len(u.FromIp)),
|
FromIp: make(net.IP, len(u.FromIp)),
|
||||||
|
@ -33,20 +36,20 @@ func (u *UdpPacket) Copy() *UdpPacket {
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
type udpConn struct {
|
type Conn struct {
|
||||||
addr *udpAddr
|
Addr *Addr
|
||||||
|
|
||||||
rxPackets chan *UdpPacket // Packets to receive into nebula
|
RxPackets chan *Packet // Packets to receive into nebula
|
||||||
txPackets chan *UdpPacket // Packets transmitted outside by nebula
|
TxPackets chan *Packet // Packets transmitted outside by nebula
|
||||||
|
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip string, port int, _ bool) (*udpConn, error) {
|
func NewListener(l *logrus.Logger, ip string, port int, _ bool, _ int) (*Conn, error) {
|
||||||
return &udpConn{
|
return &Conn{
|
||||||
addr: &udpAddr{net.ParseIP(ip), uint16(port)},
|
Addr: &Addr{net.ParseIP(ip), uint16(port)},
|
||||||
rxPackets: make(chan *UdpPacket, 1),
|
RxPackets: make(chan *Packet, 1),
|
||||||
txPackets: make(chan *UdpPacket, 1),
|
TxPackets: make(chan *Packet, 1),
|
||||||
l: l,
|
l: l,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -54,8 +57,8 @@ func NewListener(l *logrus.Logger, ip string, port int, _ bool) (*udpConn, error
|
||||||
// Send will place a UdpPacket onto the receive queue for nebula to consume
|
// Send will place a UdpPacket onto the receive queue for nebula to consume
|
||||||
// this is an encrypted packet or a handshake message in most cases
|
// this is an encrypted packet or a handshake message in most cases
|
||||||
// packets were transmitted from another nebula node, you can send them with Tun.Send
|
// packets were transmitted from another nebula node, you can send them with Tun.Send
|
||||||
func (u *udpConn) Send(packet *UdpPacket) {
|
func (u *Conn) Send(packet *Packet) {
|
||||||
h := &Header{}
|
h := &header.H{}
|
||||||
if err := h.Parse(packet.Data); err != nil {
|
if err := h.Parse(packet.Data); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -63,19 +66,19 @@ func (u *udpConn) Send(packet *UdpPacket) {
|
||||||
WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
|
WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
|
||||||
WithField("dataLen", len(packet.Data)).
|
WithField("dataLen", len(packet.Data)).
|
||||||
Info("UDP receiving injected packet")
|
Info("UDP receiving injected packet")
|
||||||
u.rxPackets <- packet
|
u.RxPackets <- packet
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get will pull a UdpPacket from the transmit queue
|
// Get will pull a UdpPacket from the transmit queue
|
||||||
// nebula meant to send this message on the network, it will be encrypted
|
// nebula meant to send this message on the network, it will be encrypted
|
||||||
// packets were ingested from the tun side (in most cases), you can send them with Tun.Send
|
// packets were ingested from the tun side (in most cases), you can send them with Tun.Send
|
||||||
func (u *udpConn) Get(block bool) *UdpPacket {
|
func (u *Conn) Get(block bool) *Packet {
|
||||||
if block {
|
if block {
|
||||||
return <-u.txPackets
|
return <-u.TxPackets
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case p := <-u.txPackets:
|
case p := <-u.TxPackets:
|
||||||
return p
|
return p
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
|
@ -86,56 +89,49 @@ func (u *udpConn) Get(block bool) *UdpPacket {
|
||||||
// Below this is boilerplate implementation to make nebula actually work
|
// Below this is boilerplate implementation to make nebula actually work
|
||||||
//********************************************************************************************************************//
|
//********************************************************************************************************************//
|
||||||
|
|
||||||
func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error {
|
func (u *Conn) WriteTo(b []byte, addr *Addr) error {
|
||||||
p := &UdpPacket{
|
p := &Packet{
|
||||||
Data: make([]byte, len(b), len(b)),
|
Data: make([]byte, len(b), len(b)),
|
||||||
FromIp: make([]byte, 16),
|
FromIp: make([]byte, 16),
|
||||||
FromPort: u.addr.Port,
|
FromPort: u.Addr.Port,
|
||||||
ToIp: make([]byte, 16),
|
ToIp: make([]byte, 16),
|
||||||
ToPort: addr.Port,
|
ToPort: addr.Port,
|
||||||
}
|
}
|
||||||
|
|
||||||
copy(p.Data, b)
|
copy(p.Data, b)
|
||||||
copy(p.ToIp, addr.IP.To16())
|
copy(p.ToIp, addr.IP.To16())
|
||||||
copy(p.FromIp, u.addr.IP.To16())
|
copy(p.FromIp, u.Addr.IP.To16())
|
||||||
|
|
||||||
u.txPackets <- p
|
u.TxPackets <- p
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) ListenOut(f *Interface, q int) {
|
func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
|
||||||
plaintext := make([]byte, mtu)
|
plaintext := make([]byte, MTU)
|
||||||
header := &Header{}
|
h := &header.H{}
|
||||||
fwPacket := &FirewallPacket{}
|
fwPacket := &firewall.Packet{}
|
||||||
ua := &udpAddr{IP: make([]byte, 16)}
|
ua := &Addr{IP: make([]byte, 16)}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
lhh := f.lightHouse.NewRequestHandler()
|
|
||||||
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
p := <-u.rxPackets
|
p := <-u.RxPackets
|
||||||
ua.Port = p.FromPort
|
ua.Port = p.FromPort
|
||||||
copy(ua.IP, p.FromIp.To16())
|
copy(ua.IP, p.FromIp.To16())
|
||||||
f.readOutsidePackets(ua, plaintext[:0], p.Data, header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l))
|
r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) reloadConfig(*Config) {}
|
func (u *Conn) ReloadConfig(*config.C) {}
|
||||||
|
|
||||||
func NewUDPStatsEmitter(_ []*udpConn) func() {
|
func NewUDPStatsEmitter(_ []*Conn) func() {
|
||||||
// No UDP stats for non-linux
|
// No UDP stats for non-linux
|
||||||
return func() {}
|
return func() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) LocalAddr() (*udpAddr, error) {
|
func (u *Conn) LocalAddr() (*Addr, error) {
|
||||||
return u.addr, nil
|
return u.Addr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) Rebind() error {
|
func (u *Conn) Rebind() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {
|
|
||||||
return !addr.Equals(newaddr)
|
|
||||||
}
|
|
|
@ -1,7 +1,7 @@
|
||||||
//go:build !e2e_testing
|
//go:build !e2e_testing
|
||||||
// +build !e2e_testing
|
// +build !e2e_testing
|
||||||
|
|
||||||
package nebula
|
package udp
|
||||||
|
|
||||||
// Windows support is primarily implemented in udp_generic, besides NewListenConfig
|
// Windows support is primarily implemented in udp_generic, besides NewListenConfig
|
||||||
|
|
||||||
|
@ -24,6 +24,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) Rebind() error {
|
func (u *Conn) Rebind() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package nebula
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
@ -17,13 +17,12 @@ func NewTestLogger() *logrus.Logger {
|
||||||
}
|
}
|
||||||
|
|
||||||
switch v {
|
switch v {
|
||||||
case "1":
|
|
||||||
// This is the default level but we are being explicit
|
|
||||||
l.SetLevel(logrus.InfoLevel)
|
|
||||||
case "2":
|
case "2":
|
||||||
l.SetLevel(logrus.DebugLevel)
|
l.SetLevel(logrus.DebugLevel)
|
||||||
case "3":
|
case "3":
|
||||||
l.SetLevel(logrus.TraceLevel)
|
l.SetLevel(logrus.TraceLevel)
|
||||||
|
default:
|
||||||
|
l.SetLevel(logrus.InfoLevel)
|
||||||
}
|
}
|
||||||
|
|
||||||
return l
|
return l
|
Loading…
Reference in New Issue