Public Release
This commit is contained in:
161
sshd/command.go
Normal file
161
sshd/command.go
Normal file
@ -0,0 +1,161 @@
|
||||
package sshd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/armon/go-radix"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CommandFlags is a function called before help or command execution to parse command line flags
|
||||
// It should return a flag.FlagSet instance and a pointer to the struct that will contain parsed flags
|
||||
type CommandFlags func() (*flag.FlagSet, interface{})
|
||||
|
||||
// CommandCallback is the function called when your command should execute.
|
||||
// fs will be a a pointer to the struct provided by Command.Flags callback, if there was one. -h and -help are reserved
|
||||
// and handled automatically for you.
|
||||
// a will be any unconsumed arguments, if no Command.Flags was available this will be all the flags passed in.
|
||||
// w is the writer to use when sending messages back to the client.
|
||||
// If an error is returned by the callback it is logged locally, the callback should handle messaging errors to the user
|
||||
// where appropriate
|
||||
type CommandCallback func(fs interface{}, a []string, w StringWriter) error
|
||||
|
||||
type Command struct {
|
||||
Name string
|
||||
ShortDescription string
|
||||
Help string
|
||||
Flags CommandFlags
|
||||
Callback CommandCallback
|
||||
}
|
||||
|
||||
func execCommand(c *Command, args []string, w StringWriter) error {
|
||||
var (
|
||||
fl *flag.FlagSet
|
||||
fs interface{}
|
||||
)
|
||||
|
||||
if c.Flags != nil {
|
||||
fl, fs = c.Flags()
|
||||
if fl != nil {
|
||||
//TODO: handle the error
|
||||
fl.Parse(args)
|
||||
args = fl.Args()
|
||||
}
|
||||
}
|
||||
|
||||
return c.Callback(fs, args, w)
|
||||
}
|
||||
|
||||
func dumpCommands(c *radix.Tree, w StringWriter) {
|
||||
err := w.WriteLine("Available commands:")
|
||||
if err != nil {
|
||||
//TODO: log
|
||||
return
|
||||
}
|
||||
|
||||
cmds := make([]string, 0)
|
||||
for _, l := range allCommands(c) {
|
||||
cmds = append(cmds, fmt.Sprintf("%s - %s", l.Name, l.ShortDescription))
|
||||
}
|
||||
|
||||
sort.Strings(cmds)
|
||||
err = w.Write(strings.Join(cmds, "\n") + "\n\n")
|
||||
if err != nil {
|
||||
//TODO: log
|
||||
}
|
||||
}
|
||||
|
||||
func lookupCommand(c *radix.Tree, sCmd string) (*Command, error) {
|
||||
cmd, ok := c.Get(sCmd)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
command, ok := cmd.(*Command)
|
||||
if !ok {
|
||||
return nil, errors.New("failed to cast command")
|
||||
}
|
||||
|
||||
return command, nil
|
||||
}
|
||||
|
||||
func matchCommand(c *radix.Tree, cmd string) []string {
|
||||
cmds := make([]string, 0)
|
||||
c.WalkPrefix(cmd, func(found string, v interface{}) bool {
|
||||
cmds = append(cmds, found)
|
||||
return false
|
||||
})
|
||||
sort.Strings(cmds)
|
||||
return cmds
|
||||
}
|
||||
|
||||
func allCommands(c *radix.Tree) []*Command {
|
||||
cmds := make([]*Command, 0)
|
||||
c.WalkPrefix("", func(found string, v interface{}) bool {
|
||||
cmd, ok := v.(*Command)
|
||||
if ok {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
return false
|
||||
})
|
||||
return cmds
|
||||
}
|
||||
|
||||
func helpCallback(commands *radix.Tree, a []string, w StringWriter) (err error) {
|
||||
// Just typed help
|
||||
if len(a) == 0 {
|
||||
dumpCommands(commands, w)
|
||||
return nil
|
||||
}
|
||||
|
||||
// We are printing a specific commands help text
|
||||
cmd, err := lookupCommand(commands, a[0])
|
||||
if err != nil {
|
||||
//TODO: handle error
|
||||
//TODO: message the user
|
||||
return
|
||||
}
|
||||
|
||||
if cmd != nil {
|
||||
err = w.WriteLine(fmt.Sprintf("%s - %s", cmd.Name, cmd.ShortDescription))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cmd.Help != "" {
|
||||
err = w.WriteLine(fmt.Sprintf(" %s", cmd.Help))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.Flags != nil {
|
||||
fs, _ := cmd.Flags()
|
||||
if fs != nil {
|
||||
fs.SetOutput(w.GetWriter())
|
||||
fs.PrintDefaults()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err = w.WriteLine("Command not available " + a[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkHelpArgs(args []string) bool {
|
||||
for _, a := range args {
|
||||
if a == "-h" || a == "-help" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
182
sshd/server.go
Normal file
182
sshd/server.go
Normal file
@ -0,0 +1,182 @@
|
||||
package sshd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/armon/go-radix"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"net"
|
||||
)
|
||||
|
||||
type SSHServer struct {
|
||||
config *ssh.ServerConfig
|
||||
l *logrus.Entry
|
||||
|
||||
// Map of user -> authorized keys
|
||||
trustedKeys map[string]map[string]bool
|
||||
|
||||
// List of available commands
|
||||
helpCommand *Command
|
||||
commands *radix.Tree
|
||||
listener net.Listener
|
||||
conns map[int]*session
|
||||
counter int
|
||||
}
|
||||
|
||||
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
|
||||
func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
|
||||
s := &SSHServer{
|
||||
trustedKeys: make(map[string]map[string]bool),
|
||||
l: l,
|
||||
commands: radix.New(),
|
||||
conns: make(map[int]*session),
|
||||
}
|
||||
|
||||
s.config = &ssh.ServerConfig{
|
||||
PublicKeyCallback: s.matchPubKey,
|
||||
//TODO: AuthLogCallback: s.authAttempt,
|
||||
//TODO: version string
|
||||
ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"),
|
||||
}
|
||||
|
||||
s.RegisterCommand(&Command{
|
||||
Name: "help",
|
||||
ShortDescription: "prints available commands or help <command> for specific usage info",
|
||||
Callback: func(a interface{}, args []string, w StringWriter) error {
|
||||
return helpCallback(s.commands, args, w)
|
||||
},
|
||||
})
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *SSHServer) SetHostKey(hostPrivateKey []byte) error {
|
||||
private, err := ssh.ParsePrivateKey(hostPrivateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse private key: %s", err)
|
||||
}
|
||||
|
||||
s.config.AddHostKey(private)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SSHServer) ClearAuthorizedKeys() {
|
||||
s.trustedKeys = make(map[string]map[string]bool)
|
||||
}
|
||||
|
||||
// AddAuthorizedKey adds an ssh public key for a user
|
||||
func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error {
|
||||
pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tk, ok := s.trustedKeys[user]
|
||||
if !ok {
|
||||
tk = make(map[string]bool)
|
||||
s.trustedKeys[user] = tk
|
||||
}
|
||||
|
||||
tk[string(pk.Marshal())] = true
|
||||
s.l.WithField("sshKey", pubKey).WithField("sshUser", user).Info("Authorized ssh key")
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterCommand adds a command that can be run by a user, by default only `help` is available
|
||||
func (s *SSHServer) RegisterCommand(c *Command) {
|
||||
s.commands.Insert(c.Name, c)
|
||||
}
|
||||
|
||||
// Run begins listening and accepting connections
|
||||
func (s *SSHServer) Run(addr string) error {
|
||||
var err error
|
||||
s.listener, err = net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.l.WithField("sshListener", addr).Info("SSH server is listening")
|
||||
for {
|
||||
c, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
s.l.WithError(err).Warn("Error in listener, shutting down")
|
||||
return nil
|
||||
}
|
||||
|
||||
conn, chans, reqs, err := ssh.NewServerConn(c, s.config)
|
||||
fp := ""
|
||||
if conn != nil {
|
||||
fp = conn.Permissions.Extensions["fp"]
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr())
|
||||
if conn != nil {
|
||||
l = l.WithField("sshUser", conn.User())
|
||||
conn.Close()
|
||||
}
|
||||
if fp != "" {
|
||||
l = l.WithField("sshFingerprint", fp)
|
||||
}
|
||||
l.Warn("failed to handshake")
|
||||
continue
|
||||
}
|
||||
|
||||
l := s.l.WithField("sshUser", conn.User())
|
||||
l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in")
|
||||
|
||||
session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session"))
|
||||
s.counter++
|
||||
counter := s.counter
|
||||
s.conns[counter] = session
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
go func() {
|
||||
<-session.exitChan
|
||||
s.l.WithField("id", counter).Debug("closing conn")
|
||||
delete(s.conns, counter)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSHServer) Stop() {
|
||||
for _, c := range s.conns {
|
||||
c.Close()
|
||||
}
|
||||
|
||||
if s.listener == nil {
|
||||
return
|
||||
}
|
||||
|
||||
err := s.listener.Close()
|
||||
if err != nil {
|
||||
s.l.WithError(err).Warn("Failed to close the sshd listener")
|
||||
return
|
||||
}
|
||||
|
||||
s.l.Info("SSH server stopped listening")
|
||||
return
|
||||
}
|
||||
|
||||
func (s *SSHServer) matchPubKey(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
pk := string(pubKey.Marshal())
|
||||
fp := ssh.FingerprintSHA256(pubKey)
|
||||
|
||||
tk, ok := s.trustedKeys[c.User()]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown user %s", c.User())
|
||||
}
|
||||
|
||||
_, ok = tk[pk]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown public key for %s (%s)", c.User(), fp)
|
||||
}
|
||||
|
||||
return &ssh.Permissions{
|
||||
// Record the public key used for authentication.
|
||||
Extensions: map[string]string{
|
||||
"fp": fp,
|
||||
"user": c.User(),
|
||||
},
|
||||
}, nil
|
||||
}
|
182
sshd/session.go
Normal file
182
sshd/session.go
Normal file
@ -0,0 +1,182 @@
|
||||
package sshd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/anmitsu/go-shlex"
|
||||
"github.com/armon/go-radix"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type session struct {
|
||||
l *logrus.Entry
|
||||
c *ssh.ServerConn
|
||||
term *terminal.Terminal
|
||||
commands *radix.Tree
|
||||
exitChan chan bool
|
||||
}
|
||||
|
||||
func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, l *logrus.Entry) *session {
|
||||
s := &session{
|
||||
commands: radix.NewFromMap(commands.ToMap()),
|
||||
l: l,
|
||||
c: conn,
|
||||
exitChan: make(chan bool),
|
||||
}
|
||||
|
||||
s.commands.Insert("logout", &Command{
|
||||
Name: "logout",
|
||||
ShortDescription: "Ends the current session",
|
||||
Callback: func(a interface{}, args []string, w StringWriter) error {
|
||||
s.Close()
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
go s.handleChannels(chans)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *session) handleChannels(chans <-chan ssh.NewChannel) {
|
||||
for newChannel := range chans {
|
||||
if newChannel.ChannelType() != "session" {
|
||||
s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type")
|
||||
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
|
||||
continue
|
||||
}
|
||||
|
||||
channel, requests, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
s.l.WithError(err).Warn("could not accept channel")
|
||||
continue
|
||||
}
|
||||
|
||||
go s.handleRequests(requests, channel)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
|
||||
for req := range in {
|
||||
var err error
|
||||
//TODO: maybe support window sizing?
|
||||
switch req.Type {
|
||||
case "shell":
|
||||
if s.term == nil {
|
||||
s.term = s.createTerm(channel)
|
||||
err = req.Reply(true, nil)
|
||||
} else {
|
||||
err = req.Reply(false, nil)
|
||||
}
|
||||
|
||||
case "pty-req":
|
||||
err = req.Reply(true, nil)
|
||||
|
||||
case "window-change":
|
||||
err = req.Reply(true, nil)
|
||||
|
||||
case "exec":
|
||||
var payload = struct{ Value string }{}
|
||||
cErr := ssh.Unmarshal(req.Payload, &payload)
|
||||
if cErr == nil {
|
||||
s.dispatchCommand(payload.Value, &stringWriter{channel})
|
||||
} else {
|
||||
//TODO: log it
|
||||
}
|
||||
channel.Close()
|
||||
return
|
||||
|
||||
default:
|
||||
s.l.WithField("sshRequest", req.Type).Debug("Rejected unknown request")
|
||||
err = req.Reply(false, nil)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
s.l.WithError(err).Info("Error handling ssh session requests")
|
||||
s.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) createTerm(channel ssh.Channel) *terminal.Terminal {
|
||||
//TODO: PS1 with nebula cert name
|
||||
term := terminal.NewTerminal(channel, s.c.User()+"@nebula > ")
|
||||
term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
|
||||
// key 9 is tab
|
||||
if key == 9 {
|
||||
cmds := matchCommand(s.commands, line)
|
||||
if len(cmds) == 1 {
|
||||
return cmds[0] + " ", len(cmds[0]) + 1, true
|
||||
}
|
||||
|
||||
sort.Strings(cmds)
|
||||
term.Write([]byte(strings.Join(cmds, "\n") + "\n\n"))
|
||||
}
|
||||
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
go s.handleInput(channel)
|
||||
return term
|
||||
}
|
||||
|
||||
func (s *session) handleInput(channel ssh.Channel) {
|
||||
defer s.Close()
|
||||
w := &stringWriter{w: s.term}
|
||||
for {
|
||||
line, err := s.term.ReadLine()
|
||||
if err != nil {
|
||||
//TODO: log
|
||||
break
|
||||
}
|
||||
|
||||
s.dispatchCommand(line, w)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) dispatchCommand(line string, w StringWriter) {
|
||||
args, err := shlex.Split(line, true)
|
||||
if err != nil {
|
||||
//todo: LOG IT
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) == 0 {
|
||||
dumpCommands(s.commands, w)
|
||||
return
|
||||
}
|
||||
|
||||
c, err := lookupCommand(s.commands, args[0])
|
||||
if err != nil {
|
||||
//TODO: handle the error
|
||||
return
|
||||
}
|
||||
|
||||
if c == nil {
|
||||
err := w.WriteLine(fmt.Sprintf("did not understand: %s", line))
|
||||
//TODO: log error
|
||||
_ = err
|
||||
|
||||
dumpCommands(s.commands, w)
|
||||
return
|
||||
}
|
||||
|
||||
if checkHelpArgs(args) {
|
||||
s.dispatchCommand(fmt.Sprintf("%s %s", "help", c.Name), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = execCommand(c, args[1:], w)
|
||||
if err != nil {
|
||||
//TODO: log the error
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *session) Close() {
|
||||
s.c.Close()
|
||||
s.exitChan <- true
|
||||
}
|
32
sshd/writer.go
Normal file
32
sshd/writer.go
Normal file
@ -0,0 +1,32 @@
|
||||
package sshd
|
||||
|
||||
import "io"
|
||||
|
||||
type StringWriter interface {
|
||||
WriteLine(string) error
|
||||
Write(string) error
|
||||
WriteBytes([]byte) error
|
||||
GetWriter() io.Writer
|
||||
}
|
||||
|
||||
type stringWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (w *stringWriter) WriteLine(s string) error {
|
||||
return w.Write(s + "\n")
|
||||
}
|
||||
|
||||
func (w *stringWriter) Write(s string) error {
|
||||
_, err := w.w.Write([]byte(s))
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *stringWriter) WriteBytes(b []byte) error {
|
||||
_, err := w.w.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *stringWriter) GetWriter() io.Writer {
|
||||
return w.w
|
||||
}
|
Reference in New Issue
Block a user