More like a library (#279)
This commit is contained in:
parent
6238f1550b
commit
68e3e84fdc
57
cert/cert.go
57
cert/cert.go
|
@ -468,6 +468,63 @@ func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) {
|
|||
return json.Marshal(jc)
|
||||
}
|
||||
|
||||
//func (nc *NebulaCertificate) Copy() *NebulaCertificate {
|
||||
// r, err := nc.Marshal()
|
||||
// if err != nil {
|
||||
// //TODO
|
||||
// return nil
|
||||
// }
|
||||
//
|
||||
// c, err := UnmarshalNebulaCertificate(r)
|
||||
// return c
|
||||
//}
|
||||
|
||||
func (nc *NebulaCertificate) Copy() *NebulaCertificate {
|
||||
c := &NebulaCertificate{
|
||||
Details: NebulaCertificateDetails{
|
||||
Name: nc.Details.Name,
|
||||
Groups: make([]string, len(nc.Details.Groups)),
|
||||
Ips: make([]*net.IPNet, len(nc.Details.Ips)),
|
||||
Subnets: make([]*net.IPNet, len(nc.Details.Subnets)),
|
||||
NotBefore: nc.Details.NotBefore,
|
||||
NotAfter: nc.Details.NotAfter,
|
||||
PublicKey: make([]byte, len(nc.Details.PublicKey)),
|
||||
IsCA: nc.Details.IsCA,
|
||||
Issuer: nc.Details.Issuer,
|
||||
InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)),
|
||||
},
|
||||
Signature: make([]byte, len(nc.Signature)),
|
||||
}
|
||||
|
||||
copy(c.Signature, nc.Signature)
|
||||
copy(c.Details.Groups, nc.Details.Groups)
|
||||
copy(c.Details.PublicKey, nc.Details.PublicKey)
|
||||
|
||||
for i, p := range nc.Details.Ips {
|
||||
c.Details.Ips[i] = &net.IPNet{
|
||||
IP: make(net.IP, len(p.IP)),
|
||||
Mask: make(net.IPMask, len(p.Mask)),
|
||||
}
|
||||
copy(c.Details.Ips[i].IP, p.IP)
|
||||
copy(c.Details.Ips[i].Mask, p.Mask)
|
||||
}
|
||||
|
||||
for i, p := range nc.Details.Subnets {
|
||||
c.Details.Subnets[i] = &net.IPNet{
|
||||
IP: make(net.IP, len(p.IP)),
|
||||
Mask: make(net.IPMask, len(p.Mask)),
|
||||
}
|
||||
copy(c.Details.Subnets[i].IP, p.IP)
|
||||
copy(c.Details.Subnets[i].Mask, p.Mask)
|
||||
}
|
||||
|
||||
for g := range nc.Details.InvertedGroups {
|
||||
c.Details.InvertedGroups[g] = struct{}{}
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool {
|
||||
for _, net := range rootIps {
|
||||
if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) {
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
|
@ -487,6 +488,17 @@ func TestMarshalingNebulaCertificateConsistency(t *testing.T) {
|
|||
assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
|
||||
}
|
||||
|
||||
func TestNebulaCertificate_Copy(t *testing.T) {
|
||||
ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
assert.Nil(t, err)
|
||||
|
||||
c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
assert.Nil(t, err)
|
||||
cc := c.Copy()
|
||||
|
||||
util.AssertDeepCopyEqual(t, c, cc)
|
||||
}
|
||||
|
||||
func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if before.IsZero() {
|
||||
|
@ -499,10 +511,11 @@ func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
|
|||
nc := &NebulaCertificate{
|
||||
Details: NebulaCertificateDetails{
|
||||
Name: "test ca",
|
||||
NotBefore: before,
|
||||
NotAfter: after,
|
||||
NotBefore: time.Unix(before.Unix(), 0),
|
||||
NotAfter: time.Unix(after.Unix(), 0),
|
||||
PublicKey: pub,
|
||||
IsCA: true,
|
||||
InvertedGroups: make(map[string]struct{}),
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -544,17 +557,17 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
|
|||
|
||||
if len(ips) == 0 {
|
||||
ips = []*net.IPNet{
|
||||
{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
|
||||
{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
|
||||
{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
|
||||
{IP: net.ParseIP("10.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
|
||||
{IP: net.ParseIP("10.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
|
||||
{IP: net.ParseIP("10.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
|
||||
}
|
||||
}
|
||||
|
||||
if len(subnets) == 0 {
|
||||
subnets = []*net.IPNet{
|
||||
{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
|
||||
{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
|
||||
{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
|
||||
{IP: net.ParseIP("9.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
|
||||
{IP: net.ParseIP("9.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
|
||||
{IP: net.ParseIP("9.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -566,11 +579,12 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
|
|||
Ips: ips,
|
||||
Subnets: subnets,
|
||||
Groups: groups,
|
||||
NotBefore: before,
|
||||
NotAfter: after,
|
||||
NotBefore: time.Unix(before.Unix(), 0),
|
||||
NotAfter: time.Unix(after.Unix(), 0),
|
||||
PublicKey: pub,
|
||||
IsCA: false,
|
||||
Issuer: issuer,
|
||||
InvertedGroups: make(map[string]struct{}),
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
@ -55,7 +55,7 @@ func main() {
|
|||
|
||||
l := logrus.New()
|
||||
l.Out = os.Stdout
|
||||
err = nebula.Main(config, *configTest, true, Build, l, nil, nil)
|
||||
c, err := nebula.Main(config, *configTest, Build, l, nil)
|
||||
|
||||
switch v := err.(type) {
|
||||
case nebula.ContextualError:
|
||||
|
@ -66,5 +66,10 @@ func main() {
|
|||
os.Exit(1)
|
||||
}
|
||||
|
||||
if !*configTest {
|
||||
c.Start()
|
||||
c.ShutdownBlock()
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
|
|
@ -14,21 +14,16 @@ import (
|
|||
var logger service.Logger
|
||||
|
||||
type program struct {
|
||||
exit chan struct{}
|
||||
configPath *string
|
||||
configTest *bool
|
||||
build string
|
||||
control *nebula.Control
|
||||
}
|
||||
|
||||
func (p *program) Start(s service.Service) error {
|
||||
logger.Info("Nebula service starting.")
|
||||
p.exit = make(chan struct{})
|
||||
// Start should not block.
|
||||
go p.run()
|
||||
return nil
|
||||
}
|
||||
logger.Info("Nebula service starting.")
|
||||
|
||||
func (p *program) run() error {
|
||||
config := nebula.NewConfig()
|
||||
err := config.Load(*p.configPath)
|
||||
if err != nil {
|
||||
|
@ -37,17 +32,22 @@ func (p *program) run() error {
|
|||
|
||||
l := logrus.New()
|
||||
l.Out = os.Stdout
|
||||
return nebula.Main(config, *p.configTest, true, Build, l, nil, nil)
|
||||
p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.control.Start()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *program) Stop(s service.Service) error {
|
||||
logger.Info("Nebula service stopping.")
|
||||
close(p.exit)
|
||||
p.control.Stop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {
|
||||
|
||||
if *configPath == "" {
|
||||
ex, err := os.Executable()
|
||||
if err != nil {
|
||||
|
|
|
@ -49,7 +49,7 @@ func main() {
|
|||
|
||||
l := logrus.New()
|
||||
l.Out = os.Stdout
|
||||
err = nebula.Main(config, *configTest, true, Build, l, nil, nil)
|
||||
c, err := nebula.Main(config, *configTest, Build, l, nil)
|
||||
|
||||
switch v := err.(type) {
|
||||
case nebula.ContextualError:
|
||||
|
@ -60,5 +60,10 @@ func main() {
|
|||
os.Exit(1)
|
||||
}
|
||||
|
||||
if !*configTest {
|
||||
c.Start()
|
||||
c.ShutdownBlock()
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,169 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
)
|
||||
|
||||
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
||||
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
||||
|
||||
type Control struct {
|
||||
f *Interface
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
type ControlHostInfo struct {
|
||||
VpnIP net.IP `json:"vpnIp"`
|
||||
LocalIndex uint32 `json:"localIndex"`
|
||||
RemoteIndex uint32 `json:"remoteIndex"`
|
||||
RemoteAddrs []udpAddr `json:"remoteAddrs"`
|
||||
CachedPackets int `json:"cachedPackets"`
|
||||
Cert *cert.NebulaCertificate `json:"cert"`
|
||||
MessageCounter uint64 `json:"messageCounter"`
|
||||
CurrentRemote udpAddr `json:"currentRemote"`
|
||||
}
|
||||
|
||||
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
|
||||
func (c *Control) Start() {
|
||||
c.f.run()
|
||||
}
|
||||
|
||||
// Stop signals nebula to shutdown, returns after the shutdown is complete
|
||||
func (c *Control) Stop() {
|
||||
//TODO: stop tun and udp routines, the lock on hostMap effectively does that though
|
||||
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
|
||||
c.f.hostMap.Lock()
|
||||
for _, h := range c.f.hostMap.Hosts {
|
||||
if h.ConnectionState.ready {
|
||||
c.f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
||||
c.l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
|
||||
Debug("Sending close tunnel message")
|
||||
}
|
||||
}
|
||||
c.f.hostMap.Unlock()
|
||||
c.l.Info("Goodbye")
|
||||
}
|
||||
|
||||
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
||||
func (c *Control) ShutdownBlock() {
|
||||
sigChan := make(chan os.Signal)
|
||||
signal.Notify(sigChan, syscall.SIGTERM)
|
||||
signal.Notify(sigChan, syscall.SIGINT)
|
||||
|
||||
rawSig := <-sigChan
|
||||
sig := rawSig.String()
|
||||
c.l.WithField("signal", sig).Info("Caught signal, shutting down")
|
||||
c.Stop()
|
||||
}
|
||||
|
||||
// RebindUDPServer asks the UDP listener to rebind it's listener. Mainly used on mobile clients when interfaces change
|
||||
func (c *Control) RebindUDPServer() {
|
||||
_ = c.f.outside.Rebind()
|
||||
}
|
||||
|
||||
// ListHostmap returns details about the actual or pending (handshaking) hostmap
|
||||
func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
|
||||
var hm *HostMap
|
||||
if pendingMap {
|
||||
hm = c.f.handshakeManager.pendingHostMap
|
||||
} else {
|
||||
hm = c.f.hostMap
|
||||
}
|
||||
|
||||
hm.RLock()
|
||||
hosts := make([]ControlHostInfo, len(hm.Hosts))
|
||||
i := 0
|
||||
for _, v := range hm.Hosts {
|
||||
hosts[i] = copyHostInfo(v)
|
||||
i++
|
||||
}
|
||||
hm.RUnlock()
|
||||
|
||||
return hosts
|
||||
}
|
||||
|
||||
// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found
|
||||
func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInfo {
|
||||
var hm *HostMap
|
||||
if pending {
|
||||
hm = c.f.handshakeManager.pendingHostMap
|
||||
} else {
|
||||
hm = c.f.hostMap
|
||||
}
|
||||
|
||||
h, err := hm.QueryVpnIP(vpnIP)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch := copyHostInfo(h)
|
||||
return &ch
|
||||
}
|
||||
|
||||
// SetRemoteForTunnel forces a tunnel to use a specific remote
|
||||
func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInfo {
|
||||
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
hostInfo.SetRemote(addr.Copy())
|
||||
ch := copyHostInfo(hostInfo)
|
||||
return &ch
|
||||
}
|
||||
|
||||
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
|
||||
func (c *Control) CloseTunnel(vpnIP uint32, localOnly bool) bool {
|
||||
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !localOnly {
|
||||
c.f.send(
|
||||
closeTunnel,
|
||||
0,
|
||||
hostInfo.ConnectionState,
|
||||
hostInfo,
|
||||
hostInfo.remote,
|
||||
[]byte{},
|
||||
make([]byte, 12, 12),
|
||||
make([]byte, mtu),
|
||||
)
|
||||
}
|
||||
|
||||
c.f.closeTunnel(hostInfo)
|
||||
return true
|
||||
}
|
||||
|
||||
func copyHostInfo(h *HostInfo) ControlHostInfo {
|
||||
addrs := h.RemoteUDPAddrs()
|
||||
chi := ControlHostInfo{
|
||||
VpnIP: int2ip(h.hostId),
|
||||
LocalIndex: h.localIndexId,
|
||||
RemoteIndex: h.remoteIndexId,
|
||||
RemoteAddrs: make([]udpAddr, len(addrs), len(addrs)),
|
||||
CachedPackets: len(h.packetStore),
|
||||
MessageCounter: *h.ConnectionState.messageCounter,
|
||||
}
|
||||
|
||||
if c := h.GetCert(); c != nil {
|
||||
chi.Cert = c.Copy()
|
||||
}
|
||||
|
||||
if h.remote != nil {
|
||||
chi.CurrentRemote = *h.remote
|
||||
}
|
||||
|
||||
for i, addr := range addrs {
|
||||
chi.RemoteAddrs[i] = addr.Copy()
|
||||
}
|
||||
|
||||
return chi
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestControl_GetHostInfoByVpnIP(t *testing.T) {
|
||||
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
|
||||
// To properly ensure we are not exposing core memory to the caller
|
||||
hm := NewHostMap("test", &net.IPNet{}, make([]*net.IPNet, 0))
|
||||
remote1 := NewUDPAddr(100, 4444)
|
||||
remote2 := NewUDPAddr(101, 4444)
|
||||
ipNet := net.IPNet{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
}
|
||||
|
||||
ipNet2 := net.IPNet{
|
||||
IP: net.IPv4(1, 2, 3, 5),
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
}
|
||||
|
||||
crt := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "test",
|
||||
Ips: []*net.IPNet{&ipNet},
|
||||
Subnets: []*net.IPNet{},
|
||||
Groups: []string{"default-group"},
|
||||
NotBefore: time.Unix(1, 0),
|
||||
NotAfter: time.Unix(2, 0),
|
||||
PublicKey: []byte{5, 6, 7, 8},
|
||||
IsCA: false,
|
||||
Issuer: "the-issuer",
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||
},
|
||||
Signature: []byte{1, 2, 1, 2, 1, 3},
|
||||
}
|
||||
counter := uint64(0)
|
||||
|
||||
remotes := []*HostInfoDest{NewHostInfoDest(remote1), NewHostInfoDest(remote2)}
|
||||
hm.Add(ip2int(ipNet.IP), &HostInfo{
|
||||
remote: remote1,
|
||||
Remotes: remotes,
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: crt,
|
||||
messageCounter: &counter,
|
||||
},
|
||||
remoteIndexId: 200,
|
||||
localIndexId: 201,
|
||||
hostId: ip2int(ipNet.IP),
|
||||
})
|
||||
|
||||
hm.Add(ip2int(ipNet2.IP), &HostInfo{
|
||||
remote: remote1,
|
||||
Remotes: remotes,
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: nil,
|
||||
messageCounter: &counter,
|
||||
},
|
||||
remoteIndexId: 200,
|
||||
localIndexId: 201,
|
||||
hostId: ip2int(ipNet2.IP),
|
||||
})
|
||||
|
||||
c := Control{
|
||||
f: &Interface{
|
||||
hostMap: hm,
|
||||
},
|
||||
l: logrus.New(),
|
||||
}
|
||||
|
||||
thi := c.GetHostInfoByVpnIP(ip2int(ipNet.IP), false)
|
||||
|
||||
expectedInfo := ControlHostInfo{
|
||||
VpnIP: net.IPv4(1, 2, 3, 4).To4(),
|
||||
LocalIndex: 201,
|
||||
RemoteIndex: 200,
|
||||
RemoteAddrs: []udpAddr{*remote1, *remote2},
|
||||
CachedPackets: 0,
|
||||
Cert: crt.Copy(),
|
||||
MessageCounter: 0,
|
||||
CurrentRemote: *NewUDPAddr(100, 4444),
|
||||
}
|
||||
|
||||
// Make sure we don't have any unexpected fields
|
||||
assertFields(t, []string{"VpnIP", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
|
||||
util.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
||||
|
||||
// Make sure we don't panic if the host info doesn't have a cert yet
|
||||
assert.NotPanics(t, func() {
|
||||
thi = c.GetHostInfoByVpnIP(ip2int(ipNet2.IP), false)
|
||||
})
|
||||
}
|
||||
|
||||
func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
|
||||
val := reflect.ValueOf(actualStruct).Elem()
|
||||
fields := make([]string, val.NumField())
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
fields[i] = val.Type().Field(i).Name
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, fields)
|
||||
}
|
10
firewall.go
10
firewall.go
|
@ -221,11 +221,17 @@ func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, er
|
|||
|
||||
// AddRule properly creates the in memory rule structure for a firewall table.
|
||||
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
|
||||
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
|
||||
// https://github.com/golang/go/issues/14131
|
||||
sIp := ""
|
||||
if ip != nil {
|
||||
sIp = ip.String()
|
||||
}
|
||||
|
||||
// We need this rule string because we generate a hash. Removing this will break firewall reload.
|
||||
ruleString := fmt.Sprintf(
|
||||
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s",
|
||||
incoming, proto, startPort, endPort, groups, host, ip, caName, caSha,
|
||||
incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha,
|
||||
)
|
||||
f.rules += ruleString + "\n"
|
||||
|
||||
|
@ -233,7 +239,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
|||
if !incoming {
|
||||
direction = "outgoing"
|
||||
}
|
||||
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": ip, "caName": caName, "caSha": caSha}).
|
||||
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
|
||||
Info("Firewall rule added")
|
||||
|
||||
var (
|
||||
|
|
2
go.mod
2
go.mod
|
@ -22,7 +22,7 @@ require (
|
|||
github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563
|
||||
github.com/sirupsen/logrus v1.4.2
|
||||
github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b
|
||||
github.com/stretchr/testify v1.4.0
|
||||
github.com/stretchr/testify v1.6.1
|
||||
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a
|
||||
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
|
||||
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975
|
||||
|
|
8
go.sum
8
go.sum
|
@ -103,8 +103,8 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
|
|||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a h1:Bt1IVPhiCDMqwGrc2nnbIN4QKvJGx6SK2NzWBmW00ao=
|
||||
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
|
||||
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
|
||||
|
@ -112,8 +112,6 @@ github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17
|
|||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY=
|
||||
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 h1:ULYEB3JvPRE/IfO+9uO7vKV/xzVTO7XPAwm8xbf4w2g=
|
||||
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975 h1:/Tl7pH94bvbAAHBdZJT947M/+gp0+CqQXDtMRC0fseo=
|
||||
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
|
@ -154,3 +152,5 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
|||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
|
||||
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
|
17
interface.go
17
interface.go
|
@ -35,7 +35,10 @@ type InterfaceConfig struct {
|
|||
DropLocalBroadcast bool
|
||||
DropMulticast bool
|
||||
UDPBatchSize int
|
||||
udpQueues int
|
||||
tunQueues int
|
||||
MessageMetrics *MessageMetrics
|
||||
version string
|
||||
}
|
||||
|
||||
type Interface struct {
|
||||
|
@ -54,6 +57,8 @@ type Interface struct {
|
|||
dropLocalBroadcast bool
|
||||
dropMulticast bool
|
||||
udpBatchSize int
|
||||
udpQueues int
|
||||
tunQueues int
|
||||
version string
|
||||
|
||||
metricHandshakes metrics.Histogram
|
||||
|
@ -89,6 +94,9 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
|||
dropLocalBroadcast: c.DropLocalBroadcast,
|
||||
dropMulticast: c.DropMulticast,
|
||||
udpBatchSize: c.UDPBatchSize,
|
||||
udpQueues: c.udpQueues,
|
||||
tunQueues: c.tunQueues,
|
||||
version: c.version,
|
||||
|
||||
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||
messageMetrics: c.MessageMetrics,
|
||||
|
@ -99,29 +107,28 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
|||
return ifce, nil
|
||||
}
|
||||
|
||||
func (f *Interface) Run(tunRoutines, udpRoutines int, buildVersion string) {
|
||||
func (f *Interface) run() {
|
||||
// actually turn on tun dev
|
||||
if err := f.inside.Activate(); err != nil {
|
||||
l.Fatal(err)
|
||||
}
|
||||
|
||||
f.version = buildVersion
|
||||
addr, err := f.outside.LocalAddr()
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Failed to get udp listen address")
|
||||
}
|
||||
|
||||
l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
|
||||
WithField("build", buildVersion).WithField("udpAddr", addr).
|
||||
WithField("build", f.version).WithField("udpAddr", addr).
|
||||
Info("Nebula interface is active")
|
||||
|
||||
// Launch n queues to read packets from udp
|
||||
for i := 0; i < udpRoutines; i++ {
|
||||
for i := 0; i < f.udpQueues; i++ {
|
||||
go f.listenOut(i)
|
||||
}
|
||||
|
||||
// Launch n queues to read packets from tun dev
|
||||
for i := 0; i < tunRoutines; i++ {
|
||||
for i := 0; i < f.tunQueues; i++ {
|
||||
go f.listenIn(i)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
@ -15,10 +17,16 @@ func NewContextualError(msg string, fields map[string]interface{}, realError err
|
|||
}
|
||||
|
||||
func (ce ContextualError) Error() string {
|
||||
if ce.RealError == nil {
|
||||
return ce.Context
|
||||
}
|
||||
return ce.RealError.Error()
|
||||
}
|
||||
|
||||
func (ce ContextualError) Unwrap() error {
|
||||
if ce.RealError == nil {
|
||||
return errors.New(ce.Context)
|
||||
}
|
||||
return ce.RealError
|
||||
}
|
||||
|
||||
|
|
123
main.go
123
main.go
|
@ -4,11 +4,8 @@ import (
|
|||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -21,12 +18,7 @@ var l = logrus.New()
|
|||
|
||||
type m map[string]interface{}
|
||||
|
||||
type CommandRequest struct {
|
||||
Command string
|
||||
Callback chan error
|
||||
}
|
||||
|
||||
func Main(config *Config, configTest bool, block bool, buildVersion string, logger *logrus.Logger, tunFd *int, commandChan <-chan CommandRequest) error {
|
||||
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) {
|
||||
l = logger
|
||||
l.Formatter = &logrus.TextFormatter{
|
||||
FullTimestamp: true,
|
||||
|
@ -36,7 +28,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
|||
if configTest {
|
||||
b, err := yaml.Marshal(config.Settings)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Print the final config
|
||||
|
@ -45,7 +37,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
|||
|
||||
err := configLogger(config)
|
||||
if err != nil {
|
||||
return NewContextualError("Failed to configure the logger", nil, err)
|
||||
return nil, NewContextualError("Failed to configure the logger", nil, err)
|
||||
}
|
||||
|
||||
config.RegisterReloadCallback(func(c *Config) {
|
||||
|
@ -59,20 +51,20 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
|||
trustedCAs, err = loadCAFromConfig(config)
|
||||
if err != nil {
|
||||
//The errors coming out of loadCA are already nicely formatted
|
||||
return NewContextualError("Failed to load ca from config", nil, err)
|
||||
return nil, NewContextualError("Failed to load ca from config", nil, err)
|
||||
}
|
||||
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
|
||||
|
||||
cs, err := NewCertStateFromConfig(config)
|
||||
if err != nil {
|
||||
//The errors coming out of NewCertStateFromConfig are already nicely formatted
|
||||
return NewContextualError("Failed to load certificate from config", nil, err)
|
||||
return nil, NewContextualError("Failed to load certificate from config", nil, err)
|
||||
}
|
||||
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
|
||||
|
||||
fw, err := NewFirewallFromConfig(cs.certificate, config)
|
||||
if err != nil {
|
||||
return NewContextualError("Error while loading firewall rules", nil, err)
|
||||
return nil, NewContextualError("Error while loading firewall rules", nil, err)
|
||||
}
|
||||
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
|
||||
|
||||
|
@ -80,11 +72,11 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
|||
tunCidr := cs.certificate.Details.Ips[0]
|
||||
routes, err := parseRoutes(config, tunCidr)
|
||||
if err != nil {
|
||||
return NewContextualError("Could not parse tun.routes", nil, err)
|
||||
return nil, NewContextualError("Could not parse tun.routes", nil, err)
|
||||
}
|
||||
unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
|
||||
if err != nil {
|
||||
return NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
||||
return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
||||
}
|
||||
|
||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
||||
|
@ -92,7 +84,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
|||
if config.GetBool("sshd.enabled", false) {
|
||||
err = configSSH(ssh, config)
|
||||
if err != nil {
|
||||
return NewContextualError("Error while configuring the sshd", nil, err)
|
||||
return nil, NewContextualError("Error while configuring the sshd", nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -129,7 +121,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
|||
}
|
||||
|
||||
if err != nil {
|
||||
return NewContextualError("Failed to get a tun/tap device", nil, err)
|
||||
return nil, NewContextualError("Failed to get a tun/tap device", nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -140,28 +132,11 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
|||
if !configTest {
|
||||
udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1)
|
||||
if err != nil {
|
||||
return NewContextualError("Failed to open udp listener", nil, err)
|
||||
return nil, NewContextualError("Failed to open udp listener", nil, err)
|
||||
}
|
||||
udpServer.reloadConfig(config)
|
||||
}
|
||||
|
||||
sigChan := make(chan os.Signal)
|
||||
killChan := make(chan CommandRequest)
|
||||
if commandChan != nil {
|
||||
go func() {
|
||||
cmd := CommandRequest{}
|
||||
for {
|
||||
cmd = <-commandChan
|
||||
switch cmd.Command {
|
||||
case "rebind":
|
||||
udpServer.Rebind()
|
||||
case "exit":
|
||||
killChan <- cmd
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Set up my internal host map
|
||||
var preferredRanges []*net.IPNet
|
||||
rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{})
|
||||
|
@ -170,7 +145,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
|||
for _, rawPreferredRange := range rawPreferredRanges {
|
||||
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
|
||||
if err != nil {
|
||||
return NewContextualError("Failed to parse preferred ranges", nil, err)
|
||||
return nil, NewContextualError("Failed to parse preferred ranges", nil, err)
|
||||
}
|
||||
preferredRanges = append(preferredRanges, preferredRange)
|
||||
}
|
||||
|
@ -183,7 +158,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
|||
if rawLocalRange != "" {
|
||||
_, localRange, err := net.ParseCIDR(rawLocalRange)
|
||||
if err != nil {
|
||||
return NewContextualError("Failed to parse local_range", nil, err)
|
||||
return nil, NewContextualError("Failed to parse local_range", nil, err)
|
||||
}
|
||||
|
||||
// Check if the entry for local_range was already specified in
|
||||
|
@ -223,7 +198,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
|||
if port == 0 && !configTest {
|
||||
uPort, err := udpServer.LocalAddr()
|
||||
if err != nil {
|
||||
return NewContextualError("Failed to get listening port", nil, err)
|
||||
return nil, NewContextualError("Failed to get listening port", nil, err)
|
||||
}
|
||||
port = int(uPort.Port)
|
||||
}
|
||||
|
@ -240,10 +215,10 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
|||
for i, host := range rawLighthouseHosts {
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
return NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
|
||||
return nil, NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
|
||||
}
|
||||
if !tunCidr.Contains(ip) {
|
||||
return NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
|
||||
return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
|
||||
}
|
||||
lighthouseHosts[i] = ip2int(ip)
|
||||
}
|
||||
|
@ -263,13 +238,13 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
|||
|
||||
remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false)
|
||||
if err != nil {
|
||||
return NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
|
||||
return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
|
||||
}
|
||||
lightHouse.SetRemoteAllowList(remoteAllowList)
|
||||
|
||||
localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true)
|
||||
if err != nil {
|
||||
return NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
|
||||
return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
|
||||
}
|
||||
lightHouse.SetLocalAllowList(localAllowList)
|
||||
|
||||
|
@ -277,7 +252,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
|||
for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
|
||||
vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
|
||||
if !tunCidr.Contains(vpnIp) {
|
||||
return NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
|
||||
return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
|
||||
}
|
||||
vals, ok := v.([]interface{})
|
||||
if ok {
|
||||
|
@ -288,7 +263,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
|||
ip := addr.IP
|
||||
port, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||
}
|
||||
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
|
||||
}
|
||||
|
@ -301,7 +276,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
|||
ip := addr.IP
|
||||
port, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||
}
|
||||
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
|
||||
}
|
||||
|
@ -354,7 +329,10 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
|||
DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false),
|
||||
DropMulticast: config.GetBool("tun.drop_multicast", false),
|
||||
UDPBatchSize: config.GetInt("listen.batch", 64),
|
||||
udpQueues: udpQueues,
|
||||
tunQueues: config.GetInt("tun.routines", 1),
|
||||
MessageMetrics: messageMetrics,
|
||||
version: buildVersion,
|
||||
}
|
||||
|
||||
switch ifConfig.Cipher {
|
||||
|
@ -363,14 +341,14 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
|||
case "chachapoly":
|
||||
noiseEndianness = binary.LittleEndian
|
||||
default:
|
||||
return fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
|
||||
return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
|
||||
}
|
||||
|
||||
var ifce *Interface
|
||||
if !configTest {
|
||||
ifce, err = NewInterface(ifConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize interface: %s", err)
|
||||
return nil, fmt.Errorf("failed to initialize interface: %s", err)
|
||||
}
|
||||
|
||||
ifce.RegisterConfigChangeCallbacks(config)
|
||||
|
@ -381,18 +359,17 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
|||
|
||||
err = startStats(config, configTest)
|
||||
if err != nil {
|
||||
return NewContextualError("Failed to start stats emitter", nil, err)
|
||||
return nil, NewContextualError("Failed to start stats emitter", nil, err)
|
||||
}
|
||||
|
||||
if configTest {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
//TODO: check if we _should_ be emitting stats
|
||||
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
|
||||
|
||||
attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
|
||||
ifce.Run(config.GetInt("tun.routines", 1), udpQueues, buildVersion)
|
||||
|
||||
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
|
||||
if amLighthouse && serveDns {
|
||||
|
@ -400,47 +377,5 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
|
|||
go dnsMain(hostMap, config)
|
||||
}
|
||||
|
||||
if block {
|
||||
// Just sit here and be friendly, main thread.
|
||||
shutdownBlock(ifce, sigChan, killChan)
|
||||
} else {
|
||||
// Even though we aren't blocking we still want to shutdown gracefully
|
||||
go shutdownBlock(ifce, sigChan, killChan)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func shutdownBlock(ifce *Interface, sigChan chan os.Signal, killChan chan CommandRequest) {
|
||||
var cmd CommandRequest
|
||||
var sig string
|
||||
|
||||
signal.Notify(sigChan, syscall.SIGTERM)
|
||||
signal.Notify(sigChan, syscall.SIGINT)
|
||||
|
||||
select {
|
||||
case rawSig := <-sigChan:
|
||||
sig = rawSig.String()
|
||||
case cmd = <-killChan:
|
||||
sig = "controlling app"
|
||||
}
|
||||
|
||||
l.WithField("signal", sig).Info("Caught signal, shutting down")
|
||||
|
||||
//TODO: stop tun and udp routines, the lock on hostMap effectively does that though
|
||||
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
|
||||
ifce.hostMap.Lock()
|
||||
for _, h := range ifce.hostMap.Hosts {
|
||||
if h.ConnectionState.ready {
|
||||
ifce.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
||||
l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
|
||||
Debug("Sending close tunnel message")
|
||||
}
|
||||
}
|
||||
ifce.hostMap.Unlock()
|
||||
|
||||
l.WithField("signal", sig).Info("Goodbye")
|
||||
select {
|
||||
case cmd.Callback <- nil:
|
||||
default:
|
||||
}
|
||||
return &Control{ifce, l}, nil
|
||||
}
|
||||
|
|
|
@ -31,6 +31,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
|
|||
}
|
||||
}
|
||||
|
||||
func (u *udpConn) Rebind() {
|
||||
return
|
||||
func (u *udpConn) Rebind() error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -33,6 +33,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
|
|||
}
|
||||
}
|
||||
|
||||
func (u *udpConn) Rebind() {
|
||||
return
|
||||
func (u *udpConn) Rebind() error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -65,6 +65,17 @@ func (ua *udpAddr) Equals(t *udpAddr) bool {
|
|||
return ua.IP.Equal(t.IP) && ua.Port == t.Port
|
||||
}
|
||||
|
||||
func (ua *udpAddr) Copy() udpAddr {
|
||||
nu := udpAddr{net.UDPAddr{
|
||||
Port: ua.Port,
|
||||
Zone: ua.Zone,
|
||||
IP: make(net.IP, len(ua.IP)),
|
||||
}}
|
||||
|
||||
copy(nu.IP, ua.IP)
|
||||
return nu
|
||||
}
|
||||
|
||||
func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error {
|
||||
_, err := uc.UDPConn.WriteToUDP(b, &addr.UDPAddr)
|
||||
return err
|
||||
|
|
15
udp_linux.go
15
udp_linux.go
|
@ -89,8 +89,12 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
|
|||
return &udpConn{sysFd: fd}, err
|
||||
}
|
||||
|
||||
func (u *udpConn) Rebind() {
|
||||
return
|
||||
func (u *udpConn) Rebind() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ua *udpAddr) Copy() udpAddr {
|
||||
return *ua
|
||||
}
|
||||
|
||||
func (u *udpConn) SetRecvBuffer(n int) error {
|
||||
|
@ -282,13 +286,6 @@ func (ua *udpAddr) Equals(t *udpAddr) bool {
|
|||
return ua.IP == t.IP && ua.Port == t.Port
|
||||
}
|
||||
|
||||
func (ua *udpAddr) Copy() *udpAddr {
|
||||
return &udpAddr{
|
||||
Port: ua.Port,
|
||||
IP: ua.IP,
|
||||
}
|
||||
}
|
||||
|
||||
func (ua *udpAddr) String() string {
|
||||
return fmt.Sprintf("%s:%v", int2ip(ua.IP), ua.Port)
|
||||
}
|
||||
|
|
|
@ -21,6 +21,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
|
|||
}
|
||||
}
|
||||
|
||||
func (u *udpConn) Rebind() {
|
||||
return
|
||||
func (u *udpConn) Rebind() error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,130 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory
|
||||
// There is currently a special case for `time.loc` (as this code traverses into unexported fields)
|
||||
func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) {
|
||||
v1 := reflect.ValueOf(a)
|
||||
v2 := reflect.ValueOf(b)
|
||||
|
||||
if !assert.Equal(t, v1.Type(), v2.Type()) {
|
||||
return
|
||||
}
|
||||
|
||||
traverseDeepCopy(t, v1, v2, v1.Type().String())
|
||||
}
|
||||
|
||||
func traverseDeepCopy(t *testing.T, v1 reflect.Value, v2 reflect.Value, name string) bool {
|
||||
switch v1.Kind() {
|
||||
case reflect.Array:
|
||||
for i := 0; i < v1.Len(); i++ {
|
||||
if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
|
||||
case reflect.Slice:
|
||||
if v1.IsNil() || v2.IsNil() {
|
||||
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil %+v, %+v", name, v1, v2)
|
||||
}
|
||||
|
||||
if !assert.Equal(t, v1.Len(), v2.Len(), "%s did not have the same length", name) {
|
||||
return false
|
||||
}
|
||||
|
||||
// A slice with cap 0
|
||||
if v1.Cap() != 0 && !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same slice %v == %v", name, v1.Pointer(), v2.Pointer()) {
|
||||
return false
|
||||
}
|
||||
|
||||
v1c := v1.Cap()
|
||||
v2c := v2.Cap()
|
||||
if v1c > 0 && v2c > 0 && v1.Slice(0, v1c).Slice(v1c-1, v1c-1).Pointer() == v2.Slice(0, v2c).Slice(v2c-1, v2c-1).Pointer() {
|
||||
return assert.Fail(t, "", "%s share some underlying memory", name)
|
||||
}
|
||||
|
||||
for i := 0; i < v1.Len(); i++ {
|
||||
if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
|
||||
case reflect.Interface:
|
||||
if v1.IsNil() || v2.IsNil() {
|
||||
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
|
||||
}
|
||||
return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
|
||||
|
||||
case reflect.Ptr:
|
||||
local := reflect.ValueOf(time.Local).Pointer()
|
||||
if local == v1.Pointer() && local == v2.Pointer() {
|
||||
return true
|
||||
}
|
||||
|
||||
if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s points to the same memory", name) {
|
||||
return false
|
||||
}
|
||||
|
||||
return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
|
||||
|
||||
case reflect.Struct:
|
||||
for i, n := 0, v1.NumField(); i < n; i++ {
|
||||
if !traverseDeepCopy(t, v1.Field(i), v2.Field(i), name+"."+v1.Type().Field(i).Name) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
|
||||
case reflect.Map:
|
||||
if v1.IsNil() || v2.IsNil() {
|
||||
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
|
||||
}
|
||||
|
||||
if !assert.Equal(t, v1.Len(), v2.Len(), "%s are not the same length", name) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same memory", name) {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, k := range v1.MapKeys() {
|
||||
val1 := v1.MapIndex(k)
|
||||
val2 := v2.MapIndex(k)
|
||||
if !assert.True(t, val1.IsValid(), "%s is an invalid key in %s", k, name) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !assert.True(t, val2.IsValid(), "%s is an invalid key in %s", k, name) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !traverseDeepCopy(t, val1, val2, name+fmt.Sprintf("%s[%s]", name, k)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
|
||||
default:
|
||||
if v1.CanInterface() && v2.CanInterface() {
|
||||
return assert.Equal(t, v1.Interface(), v2.Interface(), "%s was not equal", name)
|
||||
}
|
||||
|
||||
e1 := reflect.NewAt(v1.Type(), unsafe.Pointer(v1.UnsafeAddr())).Elem().Interface()
|
||||
e2 := reflect.NewAt(v2.Type(), unsafe.Pointer(v2.UnsafeAddr())).Elem().Interface()
|
||||
|
||||
return assert.Equal(t, e1, e2, "%s (unexported) was not equal", name)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue