diff --git a/lighthouse.go b/lighthouse.go index 3251ef1..c9583a2 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -1,6 +1,7 @@ package nebula import ( + "errors" "fmt" "net" "sync" @@ -11,6 +12,8 @@ import ( "github.com/slackhq/nebula/cert" ) +var ErrHostNotKnown = errors.New("host not known") + type LightHouse struct { sync.RWMutex //Because we concurrently read and write to our maps amLighthouse bool @@ -113,7 +116,7 @@ func (lh *LightHouse) Query(ip uint32, f EncWriter) ([]udpAddr, error) { return v, nil } lh.RUnlock() - return nil, fmt.Errorf("host %s not known, queries sent to lighthouses", IntIp(ip)) + return nil, ErrHostNotKnown } // This is asynchronous so no reply should be expected @@ -229,17 +232,8 @@ func NewLhWhoami() *NebulaMeta { // End Quick generators for protobuf -func NewIpAndPortFromUDPAddr(addr udpAddr) *IpAndPort { - return &IpAndPort{Ip: udp2ipInt(&addr), Port: uint32(addr.Port)} -} - -func NewIpAndPortsFromNetIps(ips []udpAddr) *[]*IpAndPort { - var iap []*IpAndPort - for _, e := range ips { - // Only add IPs that aren't my VPN/tun IP - iap = append(iap, NewIpAndPortFromUDPAddr(e)) - } - return &iap +func NewIpAndPortFromUDPAddr(addr udpAddr) IpAndPort { + return IpAndPort{Ip: udp2ipInt(&addr), Port: uint32(addr.Port)} } func (lh *LightHouse) LhUpdateWorker(f EncWriter) { @@ -281,9 +275,68 @@ func (lh *LightHouse) LhUpdateWorker(f EncWriter) { } } -func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *cert.NebulaCertificate, f EncWriter) { - n := &NebulaMeta{} - err := proto.Unmarshal(p, n) +type LightHouseHandler struct { + lh *LightHouse + nb []byte + out []byte + meta *NebulaMeta + iap []IpAndPort + iapp []*IpAndPort +} + +func (lh *LightHouse) NewRequestHandler() *LightHouseHandler { + lhh := &LightHouseHandler{ + lh: lh, + nb: make([]byte, 12, 12), + out: make([]byte, mtu), + + meta: &NebulaMeta{ + Details: &NebulaMetaDetails{}, + }, + } + + lhh.resizeIpAndPorts(10) + + return lhh +} + +// This method is similar to Reset(), but it re-uses the pointer structs +// so that we don't have to re-allocate them +func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { + details := lhh.meta.Details + + details.Reset() + lhh.meta.Reset() + lhh.meta.Details = details + + return lhh.meta +} + +func (lhh *LightHouseHandler) resizeIpAndPorts(n int) { + if cap(lhh.iap) < n { + lhh.iap = make([]IpAndPort, n) + lhh.iapp = make([]*IpAndPort, n) + + for i := range lhh.iap { + lhh.iapp[i] = &lhh.iap[i] + } + } + lhh.iap = lhh.iap[:n] + lhh.iapp = lhh.iapp[:n] +} + +func (lhh *LightHouseHandler) setIpAndPortsFromNetIps(ips []udpAddr) []*IpAndPort { + lhh.resizeIpAndPorts(len(ips)) + for i, e := range ips { + lhh.iap[i] = NewIpAndPortFromUDPAddr(e) + } + return lhh.iapp +} + +func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *cert.NebulaCertificate, f EncWriter) { + lh := lhh.lh + n := lhh.resetMeta() + err := proto.UnmarshalMerge(p, n) if err != nil { l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr). Error("Failed to unmarshal lighthouse packet") @@ -314,21 +367,18 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c //l.Debugf("Can't answer query %s from %s because error: %s", IntIp(n.Details.VpnIp), rAddr, err) return } else { - iap := NewIpAndPortsFromNetIps(ips) - answer := &NebulaMeta{ - Type: NebulaMeta_HostQueryReply, - Details: &NebulaMetaDetails{ - VpnIp: n.Details.VpnIp, - IpAndPorts: *iap, - }, - } - reply, err := proto.Marshal(answer) + reqVpnIP := n.Details.VpnIp + n = lhh.resetMeta() + n.Type = NebulaMeta_HostQueryReply + n.Details.VpnIp = reqVpnIP + n.Details.IpAndPorts = lhh.setIpAndPortsFromNetIps(ips) + reply, err := proto.Marshal(n) if err != nil { l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply") return } lh.metricTx(NebulaMeta_HostQueryReply, 1) - f.SendMessageToVpnIp(lightHouse, 0, vpnIp, reply, make([]byte, 12, 12), make([]byte, mtu)) + f.SendMessageToVpnIp(lightHouse, 0, vpnIp, reply, lhh.nb, lhh.out[:0]) // This signals the other side to punch some zero byte udp packets ips, err = lh.Query(vpnIp, f) @@ -337,17 +387,13 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c return } else { //l.Debugln("Notify host to punch", iap) - iap = NewIpAndPortsFromNetIps(ips) - answer = &NebulaMeta{ - Type: NebulaMeta_HostPunchNotification, - Details: &NebulaMetaDetails{ - VpnIp: vpnIp, - IpAndPorts: *iap, - }, - } - reply, _ := proto.Marshal(answer) + n = lhh.resetMeta() + n.Type = NebulaMeta_HostPunchNotification + n.Details.VpnIp = vpnIp + n.Details.IpAndPorts = lhh.setIpAndPortsFromNetIps(ips) + reply, _ := proto.Marshal(n) lh.metricTx(NebulaMeta_HostPunchNotification, 1) - f.SendMessageToVpnIp(lightHouse, 0, n.Details.VpnIp, reply, make([]byte, 12, 12), make([]byte, mtu)) + f.SendMessageToVpnIp(lightHouse, 0, reqVpnIP, reply, lhh.nb, lhh.out[:0]) } //fmt.Println(reply, remoteaddr) } @@ -401,7 +447,7 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c go func() { time.Sleep(time.Second * 5) l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp)) - f.SendMessageToVpnIp(test, testRequest, n.Details.VpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + f.SendMessageToVpnIp(test, testRequest, n.Details.VpnIp, []byte(""), lhh.nb, lhh.out[:0]) }() } } diff --git a/lighthouse_test.go b/lighthouse_test.go index 19b2891..93ac415 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -36,12 +36,19 @@ func TestNewipandportfromudpaddr(t *testing.T) { assert.Equal(t, uint32(12345), meh.Port) } -func TestNewipandportsfromudpaddrs(t *testing.T) { +func TestSetipandportsfromudpaddrs(t *testing.T) { blah := NewUDPAddrFromString("1.2.2.3:12345") blah2 := NewUDPAddrFromString("9.9.9.9:47828") group := []udpAddr{*blah, *blah2} - hah := NewIpAndPortsFromNetIps(group) - assert.IsType(t, &[]*IpAndPort{}, hah) + var lh *LightHouse + lhh := lh.NewRequestHandler() + result := lhh.setIpAndPortsFromNetIps(group) + assert.IsType(t, []*IpAndPort{}, result) + assert.Len(t, result, 2) + assert.Equal(t, uint32(0x01020203), result[0].Ip) + assert.Equal(t, uint32(12345), result[0].Port) + assert.Equal(t, uint32(0x09090909), result[1].Ip) + assert.Equal(t, uint32(47828), result[1].Port) //t.Error(reflect.TypeOf(hah)) } @@ -66,6 +73,57 @@ func Test_lhStaticMapping(t *testing.T) { assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry") } +func BenchmarkLighthouseHandleRequest(b *testing.B) { + lh1 := "10.128.0.2" + lh1IP := net.ParseIP(lh1) + + udpServer, _ := NewListener("0.0.0.0", 0, true) + + lh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) + + hAddr := NewUDPAddrFromString("4.5.6.7:12345") + hAddr2 := NewUDPAddrFromString("4.5.6.7:12346") + lh.addrMap[3] = []udpAddr{*hAddr, *hAddr2} + + rAddr := NewUDPAddrFromString("1.2.2.3:12345") + rAddr2 := NewUDPAddrFromString("1.2.2.3:12346") + lh.addrMap[2] = []udpAddr{*rAddr, *rAddr2} + + mw := &mockEncWriter{} + + b.Run("notfound", func(b *testing.B) { + lhh := lh.NewRequestHandler() + req := &NebulaMeta{ + Type: NebulaMeta_HostQuery, + Details: &NebulaMetaDetails{ + VpnIp: 4, + IpAndPorts: nil, + }, + } + p, err := proto.Marshal(req) + assert.NoError(b, err) + for n := 0; n < b.N; n++ { + lhh.HandleRequest(rAddr, 2, p, nil, mw) + } + }) + b.Run("found", func(b *testing.B) { + lhh := lh.NewRequestHandler() + req := &NebulaMeta{ + Type: NebulaMeta_HostQuery, + Details: &NebulaMetaDetails{ + VpnIp: 3, + IpAndPorts: nil, + }, + } + p, err := proto.Marshal(req) + assert.NoError(b, err) + + for n := 0; n < b.N; n++ { + lhh.HandleRequest(rAddr, 2, p, nil, mw) + } + }) +} + //func NewLightHouse(amLighthouse bool, myIp uint32, ips []string, interval int, nebulaPort int, pc *udpConn, punchBack bool) *LightHouse { /* diff --git a/outside.go b/outside.go index 166bd5d..064b0a1 100644 --- a/outside.go +++ b/outside.go @@ -17,7 +17,7 @@ const ( minFwPacketLen = 4 ) -func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, nb []byte) { +func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, lhh *LightHouseHandler, nb []byte) { err := header.Parse(packet) if err != nil { // TODO: best if we return this and let caller log @@ -66,7 +66,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, return } - f.lightHouse.HandleRequest(addr, hostinfo.hostId, d, hostinfo.GetCert(), f) + lhh.HandleRequest(addr, hostinfo.hostId, d, hostinfo.GetCert(), f) // Fallthrough to the bottom to record incoming traffic diff --git a/udp_generic.go b/udp_generic.go index 94d8cdf..0bafbb6 100644 --- a/udp_generic.go +++ b/udp_generic.go @@ -108,6 +108,8 @@ func (u *udpConn) ListenOut(f *Interface) { udpAddr := &udpAddr{} nb := make([]byte, 12, 12) + lhh := f.lightHouse.NewRequestHandler() + for { // Just read one packet at a time n, rua, err := u.ReadFromUDP(buffer) @@ -117,7 +119,7 @@ func (u *udpConn) ListenOut(f *Interface) { } udpAddr.UDPAddr = *rua - f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, nb) + f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb) } } diff --git a/udp_linux.go b/udp_linux.go index 2cde08d..92866e5 100644 --- a/udp_linux.go +++ b/udp_linux.go @@ -146,6 +146,8 @@ func (u *udpConn) ListenOut(f *Interface) { udpAddr := &udpAddr{} nb := make([]byte, 12, 12) + lhh := f.lightHouse.NewRequestHandler() + //TODO: should we track this? //metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015)) msgs, buffers, names := u.PrepareRawMessages(f.udpBatchSize) @@ -166,7 +168,7 @@ func (u *udpConn) ListenOut(f *Interface) { udpAddr.IP = binary.BigEndian.Uint32(names[i][4:8]) udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) - f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, nb) + f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb) } } }