Add locking around ssh conns to avoid concurrent map access on reload (#447)
This commit is contained in:
parent
1deb5d98e8
commit
a0735dd7d5
|
@ -1,8 +1,10 @@
|
||||||
package sshd
|
package sshd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/armon/go-radix"
|
"github.com/armon/go-radix"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
@ -20,6 +22,9 @@ type SSHServer struct {
|
||||||
helpCommand *Command
|
helpCommand *Command
|
||||||
commands *radix.Tree
|
commands *radix.Tree
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
|
|
||||||
|
// Locks the conns/counter to avoid concurrent map access
|
||||||
|
connsLock sync.Mutex
|
||||||
conns map[int]*session
|
conns map[int]*session
|
||||||
counter int
|
counter int
|
||||||
}
|
}
|
||||||
|
@ -97,11 +102,24 @@ func (s *SSHServer) Run(addr string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
s.l.WithField("sshListener", addr).Info("SSH server is listening")
|
s.l.WithField("sshListener", addr).Info("SSH server is listening")
|
||||||
|
|
||||||
|
// Run loops until there is an error
|
||||||
|
s.run()
|
||||||
|
s.closeSessions()
|
||||||
|
|
||||||
|
s.l.Info("SSH server stopped listening")
|
||||||
|
// We don't return an error because run logs for us
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SSHServer) run() {
|
||||||
for {
|
for {
|
||||||
c, err := s.listener.Accept()
|
c, err := s.listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if !errors.Is(err, net.ErrClosed) {
|
||||||
s.l.WithError(err).Warn("Error in listener, shutting down")
|
s.l.WithError(err).Warn("Error in listener, shutting down")
|
||||||
return nil
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, chans, reqs, err := ssh.NewServerConn(c, s.config)
|
conn, chans, reqs, err := ssh.NewServerConn(c, s.config)
|
||||||
|
@ -127,37 +145,38 @@ func (s *SSHServer) Run(addr string) error {
|
||||||
l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in")
|
l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in")
|
||||||
|
|
||||||
session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session"))
|
session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session"))
|
||||||
|
s.connsLock.Lock()
|
||||||
s.counter++
|
s.counter++
|
||||||
counter := s.counter
|
counter := s.counter
|
||||||
s.conns[counter] = session
|
s.conns[counter] = session
|
||||||
|
s.connsLock.Unlock()
|
||||||
|
|
||||||
go ssh.DiscardRequests(reqs)
|
go ssh.DiscardRequests(reqs)
|
||||||
go func() {
|
go func() {
|
||||||
<-session.exitChan
|
<-session.exitChan
|
||||||
s.l.WithField("id", counter).Debug("closing conn")
|
s.l.WithField("id", counter).Debug("closing conn")
|
||||||
|
s.connsLock.Lock()
|
||||||
delete(s.conns, counter)
|
delete(s.conns, counter)
|
||||||
|
s.connsLock.Unlock()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SSHServer) Stop() {
|
func (s *SSHServer) Stop() {
|
||||||
// Close the listener first, to prevent any new connections being accepted.
|
// Close the listener, this will cause all session to terminate as well, see SSHServer.Run
|
||||||
if s.listener != nil {
|
if s.listener != nil {
|
||||||
if err := s.listener.Close(); err != nil {
|
if err := s.listener.Close(); err != nil {
|
||||||
s.l.WithError(err).Warn("Failed to close the sshd listener")
|
s.l.WithError(err).Warn("Failed to close the sshd listener")
|
||||||
} else {
|
|
||||||
s.l.Info("SSH server stopped listening")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Force close all existing connections.
|
func (s *SSHServer) closeSessions() {
|
||||||
// TODO I believe this has a slight race if the listener has just accepted
|
s.connsLock.Lock()
|
||||||
// a connection. Can fix by moving this to the goroutine that's accepting.
|
|
||||||
for _, c := range s.conns {
|
for _, c := range s.conns {
|
||||||
c.Close()
|
c.Close()
|
||||||
}
|
}
|
||||||
|
s.connsLock.Unlock()
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SSHServer) matchPubKey(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
|
func (s *SSHServer) matchPubKey(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
|
||||||
|
|
Loading…
Reference in New Issue