Add locking around ssh conns to avoid concurrent map access on reload (#447)

This commit is contained in:
Nathan Brown 2021-04-23 14:43:16 -05:00 committed by GitHub
parent 1deb5d98e8
commit a0735dd7d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 31 additions and 12 deletions

View File

@ -1,8 +1,10 @@
package sshd package sshd
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"sync"
"github.com/armon/go-radix" "github.com/armon/go-radix"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -20,8 +22,11 @@ type SSHServer struct {
helpCommand *Command helpCommand *Command
commands *radix.Tree commands *radix.Tree
listener net.Listener listener net.Listener
conns map[int]*session
counter int // Locks the conns/counter to avoid concurrent map access
connsLock sync.Mutex
conns map[int]*session
counter int
} }
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
@ -97,11 +102,24 @@ func (s *SSHServer) Run(addr string) error {
} }
s.l.WithField("sshListener", addr).Info("SSH server is listening") s.l.WithField("sshListener", addr).Info("SSH server is listening")
// Run loops until there is an error
s.run()
s.closeSessions()
s.l.Info("SSH server stopped listening")
// We don't return an error because run logs for us
return nil
}
func (s *SSHServer) run() {
for { for {
c, err := s.listener.Accept() c, err := s.listener.Accept()
if err != nil { if err != nil {
s.l.WithError(err).Warn("Error in listener, shutting down") if !errors.Is(err, net.ErrClosed) {
return nil s.l.WithError(err).Warn("Error in listener, shutting down")
}
return
} }
conn, chans, reqs, err := ssh.NewServerConn(c, s.config) conn, chans, reqs, err := ssh.NewServerConn(c, s.config)
@ -127,37 +145,38 @@ func (s *SSHServer) Run(addr string) error {
l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in") l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in")
session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session")) session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session"))
s.connsLock.Lock()
s.counter++ s.counter++
counter := s.counter counter := s.counter
s.conns[counter] = session s.conns[counter] = session
s.connsLock.Unlock()
go ssh.DiscardRequests(reqs) go ssh.DiscardRequests(reqs)
go func() { go func() {
<-session.exitChan <-session.exitChan
s.l.WithField("id", counter).Debug("closing conn") s.l.WithField("id", counter).Debug("closing conn")
s.connsLock.Lock()
delete(s.conns, counter) delete(s.conns, counter)
s.connsLock.Unlock()
}() }()
} }
} }
func (s *SSHServer) Stop() { func (s *SSHServer) Stop() {
// Close the listener first, to prevent any new connections being accepted. // Close the listener, this will cause all session to terminate as well, see SSHServer.Run
if s.listener != nil { if s.listener != nil {
if err := s.listener.Close(); err != nil { if err := s.listener.Close(); err != nil {
s.l.WithError(err).Warn("Failed to close the sshd listener") s.l.WithError(err).Warn("Failed to close the sshd listener")
} else {
s.l.Info("SSH server stopped listening")
} }
} }
}
// Force close all existing connections. func (s *SSHServer) closeSessions() {
// TODO I believe this has a slight race if the listener has just accepted s.connsLock.Lock()
// a connection. Can fix by moving this to the goroutine that's accepting.
for _, c := range s.conns { for _, c := range s.conns {
c.Close() c.Close()
} }
s.connsLock.Unlock()
return
} }
func (s *SSHServer) matchPubKey(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { func (s *SSHServer) matchPubKey(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {