diff --git a/udp_linux.go b/udp_linux.go index e79dac1..8166838 100644 --- a/udp_linux.go +++ b/udp_linux.go @@ -71,8 +71,10 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) { var lip [4]byte copy(lip[:], net.ParseIP(ip).To4()) - if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { - return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err) + if multi { + if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { + return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err) + } } if err = unix.Bind(fd, &unix.SockaddrInet4{Addr: lip, Port: port}); err != nil { @@ -143,9 +145,13 @@ func (u *udpConn) ListenOut(f *Interface) { //TODO: should we track this? //metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015)) msgs, buffers, names := u.PrepareRawMessages(f.udpBatchSize) + read := u.ReadMulti + if f.udpBatchSize == 1 { + read = u.ReadSingle + } for { - n, err := u.ReadMulti(msgs) + n, err := read(msgs) if err != nil { l.WithError(err).Error("Failed to read packets") continue @@ -161,34 +167,24 @@ func (u *udpConn) ListenOut(f *Interface) { } } -func (u *udpConn) Read(addr *udpAddr, b []byte) ([]byte, error) { - var rsa rawSockaddrAny - var rLen = unix.SizeofSockaddrAny - +func (u *udpConn) ReadSingle(msgs []rawMessage) (int, error) { for { n, _, err := unix.Syscall6( - unix.SYS_RECVFROM, + unix.SYS_RECVMSG, uintptr(u.sysFd), - uintptr(unsafe.Pointer(&b[0])), - uintptr(len(b)), - uintptr(0), - uintptr(unsafe.Pointer(&rsa)), - uintptr(unsafe.Pointer(&rLen)), + uintptr(unsafe.Pointer(&(msgs[0].Hdr))), + 0, + 0, + 0, + 0, ) if err != 0 { - return nil, &net.OpError{Op: "read", Err: err} + return 0, &net.OpError{Op: "recvmsg", Err: err} } - if rsa.Addr.Family == unix.AF_INET { - addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1]) - addr.IP = uint32(rsa.Addr.Data[2])<<24 + uint32(rsa.Addr.Data[3])<<16 + uint32(rsa.Addr.Data[4])<<8 + uint32(rsa.Addr.Data[5]) - } else { - addr.Port = 0 - addr.IP = 0 - } - - return b[:n], nil + msgs[0].Len = uint32(n) + return 1, nil } }