Merge pull request #16972 from hashicorp/jbardin/ssh-agent-identity

ssh connection `agent_identity`
This commit is contained in:
James Bardin 2018-01-05 16:57:30 -05:00 committed by GitHub
commit bf5944a92c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 1256 additions and 359 deletions

View File

@ -1,11 +1,15 @@
package ssh package ssh
import ( import (
"bytes"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil"
"log" "log"
"net" "net"
"os" "os"
"path/filepath"
"strings"
"time" "time"
"github.com/hashicorp/terraform/communicator/shared" "github.com/hashicorp/terraform/communicator/shared"
@ -50,6 +54,8 @@ type connectionInfo struct {
BastionPrivateKey string `mapstructure:"bastion_private_key"` BastionPrivateKey string `mapstructure:"bastion_private_key"`
BastionHost string `mapstructure:"bastion_host"` BastionHost string `mapstructure:"bastion_host"`
BastionPort int `mapstructure:"bastion_port"` BastionPort int `mapstructure:"bastion_port"`
AgentIdentity string `mapstructure:"agent_identity"`
} }
// parseConnectionInfo is used to convert the ConnInfo of the InstanceState into // parseConnectionInfo is used to convert the ConnInfo of the InstanceState into
@ -186,7 +192,8 @@ type sshClientConfigOpts struct {
func buildSSHClientConfig(opts sshClientConfigOpts) (*ssh.ClientConfig, error) { func buildSSHClientConfig(opts sshClientConfigOpts) (*ssh.ClientConfig, error) {
conf := &ssh.ClientConfig{ conf := &ssh.ClientConfig{
User: opts.user, HostKeyCallback: ssh.InsecureIgnoreHostKey(),
User: opts.user,
} }
if opts.privateKey != "" { if opts.privateKey != "" {
@ -246,6 +253,7 @@ func connectToAgent(connInfo *connectionInfo) (*sshAgent, error) {
return &sshAgent{ return &sshAgent{
agent: agent, agent: agent,
conn: conn, conn: conn,
id: connInfo.AgentIdentity,
}, nil }, nil
} }
@ -255,6 +263,7 @@ func connectToAgent(connInfo *connectionInfo) (*sshAgent, error) {
type sshAgent struct { type sshAgent struct {
agent agent.Agent agent agent.Agent
conn net.Conn conn net.Conn
id string
} }
func (a *sshAgent) Close() error { func (a *sshAgent) Close() error {
@ -265,8 +274,127 @@ func (a *sshAgent) Close() error {
return a.conn.Close() return a.conn.Close()
} }
// make an attempt to either read the identity file or find a corresponding
// public key file using the typical openssh naming convention.
// This returns the public key in wire format, or nil when a key is not found.
func findIDPublicKey(id string) []byte {
for _, d := range idKeyData(id) {
signer, err := ssh.ParsePrivateKey(d)
if err == nil {
log.Println("[DEBUG] parsed id private key")
pk := signer.PublicKey()
return pk.Marshal()
}
// try it as a publicKey
pk, err := ssh.ParsePublicKey(d)
if err == nil {
log.Println("[DEBUG] parsed id public key")
return pk.Marshal()
}
// finally try it as an authorized key
pk, _, _, _, err = ssh.ParseAuthorizedKey(d)
if err == nil {
log.Println("[DEBUG] parsed id authorized key")
return pk.Marshal()
}
}
return nil
}
// Try to read an id file using the id as the file path. Also read the .pub
// file if it exists, as the id file may be encrypted. Return only the file
// data read. We don't need to know what data came from which path, as we will
// try parsing each as a private key, a public key and an authorized key
// regardless.
func idKeyData(id string) [][]byte {
idPath, err := filepath.Abs(id)
if err != nil {
return nil
}
var fileData [][]byte
paths := []string{idPath}
if !strings.HasSuffix(idPath, ".pub") {
paths = append(paths, idPath+".pub")
}
for _, p := range paths {
d, err := ioutil.ReadFile(p)
if err != nil {
log.Printf("[DEBUG] error reading %q: %s", p, err)
continue
}
log.Printf("[DEBUG] found identity data at %q", p)
fileData = append(fileData, d)
}
return fileData
}
// sortSigners moves a signer with an agent comment field matching the
// agent_identity to the head of the list when attempting authentication. This
// helps when there are more keys loaded in an agent than the host will allow
// attempts.
func (s *sshAgent) sortSigners(signers []ssh.Signer) {
if s.id == "" || len(signers) < 2 {
return
}
// if we can locate the public key, either by extracting it from the id or
// locating the .pub file, then we can more easily determine an exact match
idPk := findIDPublicKey(s.id)
// if we have a signer with a connect field that matches the id, send that
// first, otherwise put close matches at the front of the list.
head := 0
for i := range signers {
pk := signers[i].PublicKey()
k, ok := pk.(*agent.Key)
if !ok {
continue
}
// check for an exact match first
if bytes.Equal(pk.Marshal(), idPk) || s.id == k.Comment {
signers[0], signers[i] = signers[i], signers[0]
break
}
// no exact match yet, move it to the front if it's close. The agent
// may have loaded as a full filepath, while the config refers to it by
// filename only.
if strings.HasSuffix(k.Comment, s.id) {
signers[head], signers[i] = signers[i], signers[head]
head++
continue
}
}
ss := []string{}
for _, signer := range signers {
pk := signer.PublicKey()
k := pk.(*agent.Key)
ss = append(ss, k.Comment)
}
}
func (s *sshAgent) Signers() ([]ssh.Signer, error) {
signers, err := s.agent.Signers()
if err != nil {
return nil, err
}
s.sortSigners(signers)
return signers, nil
}
func (a *sshAgent) Auth() ssh.AuthMethod { func (a *sshAgent) Auth() ssh.AuthMethod {
return ssh.PublicKeysCallback(a.agent.Signers) return ssh.PublicKeysCallback(a.Signers)
} }
func (a *sshAgent) ForwardToAgent(client *ssh.Client) error { func (a *sshAgent) ForwardToAgent(client *ssh.Client) error {

View File

@ -0,0 +1,105 @@
package ssh
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"io/ioutil"
"os"
"path/filepath"
"testing"
"golang.org/x/crypto/ssh"
)
// verify that we can locate public key data
func TestFindKeyData(t *testing.T) {
// setup a test directory
td, err := ioutil.TempDir("", "ssh")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(td)
cwd, err := os.Getwd()
if err != nil {
t.Fatal(err)
}
if err := os.Chdir(td); err != nil {
t.Fatal(err)
}
defer os.Chdir(cwd)
id := "provisioner_id"
pub := generateSSHKey(t, id)
pubData := pub.Marshal()
// backup the pub file, and replace it with a broken file to ensure we
// extract the public key from the private key.
if err := os.Rename(id+".pub", "saved.pub"); err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(id+".pub", []byte("not a public key"), 0600); err != nil {
t.Fatal(err)
}
foundData := findIDPublicKey(id)
if !bytes.Equal(foundData, pubData) {
t.Fatalf("public key %q does not match", foundData)
}
// move the pub file back, and break the private key file to simulate an
// encrypted private key
if err := os.Rename("saved.pub", id+".pub"); err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(id, []byte("encrypted private key"), 0600); err != nil {
t.Fatal(err)
}
foundData = findIDPublicKey(id)
if !bytes.Equal(foundData, pubData) {
t.Fatalf("public key %q does not match", foundData)
}
// check the file by path too
foundData = findIDPublicKey(filepath.Join(".", id))
if !bytes.Equal(foundData, pubData) {
t.Fatalf("public key %q does not match", foundData)
}
}
func generateSSHKey(t *testing.T, idFile string) ssh.PublicKey {
t.Helper()
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
privFile, err := os.OpenFile(idFile, os.O_RDWR|os.O_CREATE, 0600)
defer privFile.Close()
if err != nil {
t.Fatal(err)
}
privPEM := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}
if err := pem.Encode(privFile, privPEM); err != nil {
t.Fatal(err)
}
// generate and write public key
pub, err := ssh.NewPublicKey(&priv.PublicKey)
if err != nil {
t.Fatal(err)
}
err = ioutil.WriteFile(idFile+".pub", ssh.MarshalAuthorizedKey(pub), 0600)
if err != nil {
t.Fatal(err)
}
return pub
}

View File

@ -150,6 +150,7 @@ func (n *EvalValidateProvisioner) validateConnConfig(connConfig *ResourceConfig)
BastionUser interface{} `mapstructure:"bastion_user"` BastionUser interface{} `mapstructure:"bastion_user"`
BastionPassword interface{} `mapstructure:"bastion_password"` BastionPassword interface{} `mapstructure:"bastion_password"`
BastionPrivateKey interface{} `mapstructure:"bastion_private_key"` BastionPrivateKey interface{} `mapstructure:"bastion_private_key"`
AgentIdentity interface{} `mapstructure:"agent_identity"`
// For type=winrm only (enforced in winrm communicator) // For type=winrm only (enforced in winrm communicator)
HTTPS interface{} `mapstructure:"https"` HTTPS interface{} `mapstructure:"https"`

View File

@ -57,6 +57,17 @@ type Agent interface {
Signers() ([]ssh.Signer, error) Signers() ([]ssh.Signer, error)
} }
// ConstraintExtension describes an optional constraint defined by users.
type ConstraintExtension struct {
// ExtensionName consist of a UTF-8 string suffixed by the
// implementation domain following the naming scheme defined
// in Section 4.2 of [RFC4251], e.g. "foo@example.com".
ExtensionName string
// ExtensionDetails contains the actual content of the extended
// constraint.
ExtensionDetails []byte
}
// AddedKey describes an SSH key to be added to an Agent. // AddedKey describes an SSH key to be added to an Agent.
type AddedKey struct { type AddedKey struct {
// PrivateKey must be a *rsa.PrivateKey, *dsa.PrivateKey or // PrivateKey must be a *rsa.PrivateKey, *dsa.PrivateKey or
@ -73,6 +84,9 @@ type AddedKey struct {
// ConfirmBeforeUse, if true, requests that the agent confirm with the // ConfirmBeforeUse, if true, requests that the agent confirm with the
// user before each use of this key. // user before each use of this key.
ConfirmBeforeUse bool ConfirmBeforeUse bool
// ConstraintExtensions are the experimental or private-use constraints
// defined by users.
ConstraintExtensions []ConstraintExtension
} }
// See [PROTOCOL.agent], section 3. // See [PROTOCOL.agent], section 3.
@ -84,7 +98,7 @@ const (
agentAddIdentity = 17 agentAddIdentity = 17
agentRemoveIdentity = 18 agentRemoveIdentity = 18
agentRemoveAllIdentities = 19 agentRemoveAllIdentities = 19
agentAddIdConstrained = 25 agentAddIDConstrained = 25
// 3.3 Key-type independent requests from client to agent // 3.3 Key-type independent requests from client to agent
agentAddSmartcardKey = 20 agentAddSmartcardKey = 20
@ -94,8 +108,9 @@ const (
agentAddSmartcardKeyConstrained = 26 agentAddSmartcardKeyConstrained = 26
// 3.7 Key constraint identifiers // 3.7 Key constraint identifiers
agentConstrainLifetime = 1 agentConstrainLifetime = 1
agentConstrainConfirm = 2 agentConstrainConfirm = 2
agentConstrainExtension = 3
) )
// maxAgentResponseBytes is the maximum agent reply size that is accepted. This // maxAgentResponseBytes is the maximum agent reply size that is accepted. This
@ -151,6 +166,19 @@ type publicKey struct {
Rest []byte `ssh:"rest"` Rest []byte `ssh:"rest"`
} }
// 3.7 Key constraint identifiers
type constrainLifetimeAgentMsg struct {
LifetimeSecs uint32 `sshtype:"1"`
}
type constrainExtensionAgentMsg struct {
ExtensionName string `sshtype:"3"`
ExtensionDetails []byte
// Rest is a field used for parsing, not part of message
Rest []byte `ssh:"rest"`
}
// Key represents a protocol 2 public key as defined in // Key represents a protocol 2 public key as defined in
// [PROTOCOL.agent], section 2.5.2. // [PROTOCOL.agent], section 2.5.2.
type Key struct { type Key struct {
@ -487,7 +515,7 @@ func (c *client) insertKey(s interface{}, comment string, constraints []byte) er
// if constraints are present then the message type needs to be changed. // if constraints are present then the message type needs to be changed.
if len(constraints) != 0 { if len(constraints) != 0 {
req[0] = agentAddIdConstrained req[0] = agentAddIDConstrained
} }
resp, err := c.call(req) resp, err := c.call(req)
@ -542,22 +570,18 @@ func (c *client) Add(key AddedKey) error {
var constraints []byte var constraints []byte
if secs := key.LifetimeSecs; secs != 0 { if secs := key.LifetimeSecs; secs != 0 {
constraints = append(constraints, agentConstrainLifetime) constraints = append(constraints, ssh.Marshal(constrainLifetimeAgentMsg{secs})...)
var secsBytes [4]byte
binary.BigEndian.PutUint32(secsBytes[:], secs)
constraints = append(constraints, secsBytes[:]...)
} }
if key.ConfirmBeforeUse { if key.ConfirmBeforeUse {
constraints = append(constraints, agentConstrainConfirm) constraints = append(constraints, agentConstrainConfirm)
} }
if cert := key.Certificate; cert == nil { cert := key.Certificate
if cert == nil {
return c.insertKey(key.PrivateKey, key.Comment, constraints) return c.insertKey(key.PrivateKey, key.Comment, constraints)
} else {
return c.insertCert(key.PrivateKey, cert, key.Comment, constraints)
} }
return c.insertCert(key.PrivateKey, cert, key.Comment, constraints)
} }
func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string, constraints []byte) error { func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string, constraints []byte) error {
@ -609,7 +633,7 @@ func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string
// if constraints are present then the message type needs to be changed. // if constraints are present then the message type needs to be changed.
if len(constraints) != 0 { if len(constraints) != 0 {
req[0] = agentAddIdConstrained req[0] = agentAddIDConstrained
} }
signer, err := ssh.NewSignerFromKey(s) signer, err := ssh.NewSignerFromKey(s)

View File

@ -106,7 +106,7 @@ func (s *server) processRequest(data []byte) (interface{}, error) {
return nil, s.agent.Lock(req.Passphrase) return nil, s.agent.Lock(req.Passphrase)
case agentUnlock: case agentUnlock:
var req agentLockMsg var req agentUnlockMsg
if err := ssh.Unmarshal(data, &req); err != nil { if err := ssh.Unmarshal(data, &req); err != nil {
return nil, err return nil, err
} }
@ -148,13 +148,51 @@ func (s *server) processRequest(data []byte) (interface{}, error) {
} }
return rep, nil return rep, nil
case agentAddIdConstrained, agentAddIdentity: case agentAddIDConstrained, agentAddIdentity:
return nil, s.insertIdentity(data) return nil, s.insertIdentity(data)
} }
return nil, fmt.Errorf("unknown opcode %d", data[0]) return nil, fmt.Errorf("unknown opcode %d", data[0])
} }
func parseConstraints(constraints []byte) (lifetimeSecs uint32, confirmBeforeUse bool, extensions []ConstraintExtension, err error) {
for len(constraints) != 0 {
switch constraints[0] {
case agentConstrainLifetime:
lifetimeSecs = binary.BigEndian.Uint32(constraints[1:5])
constraints = constraints[5:]
case agentConstrainConfirm:
confirmBeforeUse = true
constraints = constraints[1:]
case agentConstrainExtension:
var msg constrainExtensionAgentMsg
if err = ssh.Unmarshal(constraints, &msg); err != nil {
return 0, false, nil, err
}
extensions = append(extensions, ConstraintExtension{
ExtensionName: msg.ExtensionName,
ExtensionDetails: msg.ExtensionDetails,
})
constraints = msg.Rest
default:
return 0, false, nil, fmt.Errorf("unknown constraint type: %d", constraints[0])
}
}
return
}
func setConstraints(key *AddedKey, constraintBytes []byte) error {
lifetimeSecs, confirmBeforeUse, constraintExtensions, err := parseConstraints(constraintBytes)
if err != nil {
return err
}
key.LifetimeSecs = lifetimeSecs
key.ConfirmBeforeUse = confirmBeforeUse
key.ConstraintExtensions = constraintExtensions
return nil
}
func parseRSAKey(req []byte) (*AddedKey, error) { func parseRSAKey(req []byte) (*AddedKey, error) {
var k rsaKeyMsg var k rsaKeyMsg
if err := ssh.Unmarshal(req, &k); err != nil { if err := ssh.Unmarshal(req, &k); err != nil {
@ -173,7 +211,11 @@ func parseRSAKey(req []byte) (*AddedKey, error) {
} }
priv.Precompute() priv.Precompute()
return &AddedKey{PrivateKey: priv, Comment: k.Comments}, nil addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
if err := setConstraints(addedKey, k.Constraints); err != nil {
return nil, err
}
return addedKey, nil
} }
func parseEd25519Key(req []byte) (*AddedKey, error) { func parseEd25519Key(req []byte) (*AddedKey, error) {
@ -182,7 +224,12 @@ func parseEd25519Key(req []byte) (*AddedKey, error) {
return nil, err return nil, err
} }
priv := ed25519.PrivateKey(k.Priv) priv := ed25519.PrivateKey(k.Priv)
return &AddedKey{PrivateKey: &priv, Comment: k.Comments}, nil
addedKey := &AddedKey{PrivateKey: &priv, Comment: k.Comments}
if err := setConstraints(addedKey, k.Constraints); err != nil {
return nil, err
}
return addedKey, nil
} }
func parseDSAKey(req []byte) (*AddedKey, error) { func parseDSAKey(req []byte) (*AddedKey, error) {
@ -202,7 +249,11 @@ func parseDSAKey(req []byte) (*AddedKey, error) {
X: k.X, X: k.X,
} }
return &AddedKey{PrivateKey: priv, Comment: k.Comments}, nil addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
if err := setConstraints(addedKey, k.Constraints); err != nil {
return nil, err
}
return addedKey, nil
} }
func unmarshalECDSA(curveName string, keyBytes []byte, privScalar *big.Int) (priv *ecdsa.PrivateKey, err error) { func unmarshalECDSA(curveName string, keyBytes []byte, privScalar *big.Int) (priv *ecdsa.PrivateKey, err error) {
@ -243,7 +294,12 @@ func parseEd25519Cert(req []byte) (*AddedKey, error) {
if !ok { if !ok {
return nil, errors.New("agent: bad ED25519 certificate") return nil, errors.New("agent: bad ED25519 certificate")
} }
return &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}, nil
addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
if err := setConstraints(addedKey, k.Constraints); err != nil {
return nil, err
}
return addedKey, nil
} }
func parseECDSAKey(req []byte) (*AddedKey, error) { func parseECDSAKey(req []byte) (*AddedKey, error) {
@ -257,7 +313,11 @@ func parseECDSAKey(req []byte) (*AddedKey, error) {
return nil, err return nil, err
} }
return &AddedKey{PrivateKey: priv, Comment: k.Comments}, nil addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
if err := setConstraints(addedKey, k.Constraints); err != nil {
return nil, err
}
return addedKey, nil
} }
func parseRSACert(req []byte) (*AddedKey, error) { func parseRSACert(req []byte) (*AddedKey, error) {
@ -300,7 +360,11 @@ func parseRSACert(req []byte) (*AddedKey, error) {
} }
priv.Precompute() priv.Precompute()
return &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}, nil addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
if err := setConstraints(addedKey, k.Constraints); err != nil {
return nil, err
}
return addedKey, nil
} }
func parseDSACert(req []byte) (*AddedKey, error) { func parseDSACert(req []byte) (*AddedKey, error) {
@ -338,7 +402,11 @@ func parseDSACert(req []byte) (*AddedKey, error) {
X: k.X, X: k.X,
} }
return &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}, nil addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
if err := setConstraints(addedKey, k.Constraints); err != nil {
return nil, err
}
return addedKey, nil
} }
func parseECDSACert(req []byte) (*AddedKey, error) { func parseECDSACert(req []byte) (*AddedKey, error) {
@ -371,7 +439,11 @@ func parseECDSACert(req []byte) (*AddedKey, error) {
return nil, err return nil, err
} }
return &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}, nil addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
if err := setConstraints(addedKey, k.Constraints); err != nil {
return nil, err
}
return addedKey, nil
} }
func (s *server) insertIdentity(req []byte) error { func (s *server) insertIdentity(req []byte) error {

View File

@ -51,13 +51,12 @@ func (b *buffer) write(buf []byte) {
} }
// eof closes the buffer. Reads from the buffer once all // eof closes the buffer. Reads from the buffer once all
// the data has been consumed will receive os.EOF. // the data has been consumed will receive io.EOF.
func (b *buffer) eof() error { func (b *buffer) eof() {
b.Cond.L.Lock() b.Cond.L.Lock()
b.closed = true b.closed = true
b.Cond.Signal() b.Cond.Signal()
b.Cond.L.Unlock() b.Cond.L.Unlock()
return nil
} }
// Read reads data from the internal buffer in buf. Reads will block // Read reads data from the internal buffer in buf. Reads will block

View File

@ -251,10 +251,18 @@ type CertChecker struct {
// for user certificates. // for user certificates.
SupportedCriticalOptions []string SupportedCriticalOptions []string
// IsAuthority should return true if the key is recognized as // IsUserAuthority should return true if the key is recognized as an
// an authority. This allows for certificates to be signed by other // authority for the given user certificate. This allows for
// certificates. // certificates to be signed by other certificates. This must be set
IsAuthority func(auth PublicKey) bool // if this CertChecker will be checking user certificates.
IsUserAuthority func(auth PublicKey) bool
// IsHostAuthority should report whether the key is recognized as
// an authority for this host. This allows for certificates to be
// signed by other keys, and for those other keys to only be valid
// signers for particular hostnames. This must be set if this
// CertChecker will be checking host certificates.
IsHostAuthority func(auth PublicKey, address string) bool
// Clock is used for verifying time stamps. If nil, time.Now // Clock is used for verifying time stamps. If nil, time.Now
// is used. // is used.
@ -268,7 +276,7 @@ type CertChecker struct {
// HostKeyFallback is called when CertChecker.CheckHostKey encounters a // HostKeyFallback is called when CertChecker.CheckHostKey encounters a
// public key that is not a certificate. It must implement host key // public key that is not a certificate. It must implement host key
// validation or else, if nil, all such keys are rejected. // validation or else, if nil, all such keys are rejected.
HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error HostKeyFallback HostKeyCallback
// IsRevoked is called for each certificate so that revocation checking // IsRevoked is called for each certificate so that revocation checking
// can be implemented. It should return true if the given certificate // can be implemented. It should return true if the given certificate
@ -290,8 +298,17 @@ func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey)
if cert.CertType != HostCert { if cert.CertType != HostCert {
return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType) return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType)
} }
if !c.IsHostAuthority(cert.SignatureKey, addr) {
return fmt.Errorf("ssh: no authorities for hostname: %v", addr)
}
return c.CheckCert(addr, cert) hostname, _, err := net.SplitHostPort(addr)
if err != nil {
return err
}
// Pass hostname only as principal for host certificates (consistent with OpenSSH)
return c.CheckCert(hostname, cert)
} }
// Authenticate checks a user certificate. Authenticate can be used as // Authenticate checks a user certificate. Authenticate can be used as
@ -308,6 +325,9 @@ func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permis
if cert.CertType != UserCert { if cert.CertType != UserCert {
return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType) return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType)
} }
if !c.IsUserAuthority(cert.SignatureKey) {
return nil, fmt.Errorf("ssh: certificate signed by unrecognized authority")
}
if err := c.CheckCert(conn.User(), cert); err != nil { if err := c.CheckCert(conn.User(), cert); err != nil {
return nil, err return nil, err
@ -323,7 +343,7 @@ func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial) return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial)
} }
for opt, _ := range cert.CriticalOptions { for opt := range cert.CriticalOptions {
// sourceAddressCriticalOption will be enforced by // sourceAddressCriticalOption will be enforced by
// serverAuthenticate // serverAuthenticate
if opt == sourceAddressCriticalOption { if opt == sourceAddressCriticalOption {
@ -356,10 +376,6 @@ func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
} }
} }
if !c.IsAuthority(cert.SignatureKey) {
return fmt.Errorf("ssh: certificate signed by unrecognized authority")
}
clock := c.Clock clock := c.Clock
if clock == nil { if clock == nil {
clock = time.Now clock = time.Now

View File

@ -205,32 +205,32 @@ type channel struct {
// writePacket sends a packet. If the packet is a channel close, it updates // writePacket sends a packet. If the packet is a channel close, it updates
// sentClose. This method takes the lock c.writeMu. // sentClose. This method takes the lock c.writeMu.
func (c *channel) writePacket(packet []byte) error { func (ch *channel) writePacket(packet []byte) error {
c.writeMu.Lock() ch.writeMu.Lock()
if c.sentClose { if ch.sentClose {
c.writeMu.Unlock() ch.writeMu.Unlock()
return io.EOF return io.EOF
} }
c.sentClose = (packet[0] == msgChannelClose) ch.sentClose = (packet[0] == msgChannelClose)
err := c.mux.conn.writePacket(packet) err := ch.mux.conn.writePacket(packet)
c.writeMu.Unlock() ch.writeMu.Unlock()
return err return err
} }
func (c *channel) sendMessage(msg interface{}) error { func (ch *channel) sendMessage(msg interface{}) error {
if debugMux { if debugMux {
log.Printf("send(%d): %#v", c.mux.chanList.offset, msg) log.Printf("send(%d): %#v", ch.mux.chanList.offset, msg)
} }
p := Marshal(msg) p := Marshal(msg)
binary.BigEndian.PutUint32(p[1:], c.remoteId) binary.BigEndian.PutUint32(p[1:], ch.remoteId)
return c.writePacket(p) return ch.writePacket(p)
} }
// WriteExtended writes data to a specific extended stream. These streams are // WriteExtended writes data to a specific extended stream. These streams are
// used, for example, for stderr. // used, for example, for stderr.
func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) { func (ch *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
if c.sentEOF { if ch.sentEOF {
return 0, io.EOF return 0, io.EOF
} }
// 1 byte message type, 4 bytes remoteId, 4 bytes data length // 1 byte message type, 4 bytes remoteId, 4 bytes data length
@ -241,16 +241,16 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
opCode = msgChannelExtendedData opCode = msgChannelExtendedData
} }
c.writeMu.Lock() ch.writeMu.Lock()
packet := c.packetPool[extendedCode] packet := ch.packetPool[extendedCode]
// We don't remove the buffer from packetPool, so // We don't remove the buffer from packetPool, so
// WriteExtended calls from different goroutines will be // WriteExtended calls from different goroutines will be
// flagged as errors by the race detector. // flagged as errors by the race detector.
c.writeMu.Unlock() ch.writeMu.Unlock()
for len(data) > 0 { for len(data) > 0 {
space := min(c.maxRemotePayload, len(data)) space := min(ch.maxRemotePayload, len(data))
if space, err = c.remoteWin.reserve(space); err != nil { if space, err = ch.remoteWin.reserve(space); err != nil {
return n, err return n, err
} }
if want := headerLength + space; uint32(cap(packet)) < want { if want := headerLength + space; uint32(cap(packet)) < want {
@ -262,13 +262,13 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
todo := data[:space] todo := data[:space]
packet[0] = opCode packet[0] = opCode
binary.BigEndian.PutUint32(packet[1:], c.remoteId) binary.BigEndian.PutUint32(packet[1:], ch.remoteId)
if extendedCode > 0 { if extendedCode > 0 {
binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode)) binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
} }
binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo))) binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
copy(packet[headerLength:], todo) copy(packet[headerLength:], todo)
if err = c.writePacket(packet); err != nil { if err = ch.writePacket(packet); err != nil {
return n, err return n, err
} }
@ -276,14 +276,14 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
data = data[len(todo):] data = data[len(todo):]
} }
c.writeMu.Lock() ch.writeMu.Lock()
c.packetPool[extendedCode] = packet ch.packetPool[extendedCode] = packet
c.writeMu.Unlock() ch.writeMu.Unlock()
return n, err return n, err
} }
func (c *channel) handleData(packet []byte) error { func (ch *channel) handleData(packet []byte) error {
headerLen := 9 headerLen := 9
isExtendedData := packet[0] == msgChannelExtendedData isExtendedData := packet[0] == msgChannelExtendedData
if isExtendedData { if isExtendedData {
@ -303,7 +303,7 @@ func (c *channel) handleData(packet []byte) error {
if length == 0 { if length == 0 {
return nil return nil
} }
if length > c.maxIncomingPayload { if length > ch.maxIncomingPayload {
// TODO(hanwen): should send Disconnect? // TODO(hanwen): should send Disconnect?
return errors.New("ssh: incoming packet exceeds maximum payload size") return errors.New("ssh: incoming packet exceeds maximum payload size")
} }
@ -313,21 +313,21 @@ func (c *channel) handleData(packet []byte) error {
return errors.New("ssh: wrong packet length") return errors.New("ssh: wrong packet length")
} }
c.windowMu.Lock() ch.windowMu.Lock()
if c.myWindow < length { if ch.myWindow < length {
c.windowMu.Unlock() ch.windowMu.Unlock()
// TODO(hanwen): should send Disconnect with reason? // TODO(hanwen): should send Disconnect with reason?
return errors.New("ssh: remote side wrote too much") return errors.New("ssh: remote side wrote too much")
} }
c.myWindow -= length ch.myWindow -= length
c.windowMu.Unlock() ch.windowMu.Unlock()
if extended == 1 { if extended == 1 {
c.extPending.write(data) ch.extPending.write(data)
} else if extended > 0 { } else if extended > 0 {
// discard other extended data. // discard other extended data.
} else { } else {
c.pending.write(data) ch.pending.write(data)
} }
return nil return nil
} }
@ -384,31 +384,31 @@ func (c *channel) close() {
// responseMessageReceived is called when a success or failure message is // responseMessageReceived is called when a success or failure message is
// received on a channel to check that such a message is reasonable for the // received on a channel to check that such a message is reasonable for the
// given channel. // given channel.
func (c *channel) responseMessageReceived() error { func (ch *channel) responseMessageReceived() error {
if c.direction == channelInbound { if ch.direction == channelInbound {
return errors.New("ssh: channel response message received on inbound channel") return errors.New("ssh: channel response message received on inbound channel")
} }
if c.decided { if ch.decided {
return errors.New("ssh: duplicate response received for channel") return errors.New("ssh: duplicate response received for channel")
} }
c.decided = true ch.decided = true
return nil return nil
} }
func (c *channel) handlePacket(packet []byte) error { func (ch *channel) handlePacket(packet []byte) error {
switch packet[0] { switch packet[0] {
case msgChannelData, msgChannelExtendedData: case msgChannelData, msgChannelExtendedData:
return c.handleData(packet) return ch.handleData(packet)
case msgChannelClose: case msgChannelClose:
c.sendMessage(channelCloseMsg{PeersId: c.remoteId}) ch.sendMessage(channelCloseMsg{PeersID: ch.remoteId})
c.mux.chanList.remove(c.localId) ch.mux.chanList.remove(ch.localId)
c.close() ch.close()
return nil return nil
case msgChannelEOF: case msgChannelEOF:
// RFC 4254 is mute on how EOF affects dataExt messages but // RFC 4254 is mute on how EOF affects dataExt messages but
// it is logical to signal EOF at the same time. // it is logical to signal EOF at the same time.
c.extPending.eof() ch.extPending.eof()
c.pending.eof() ch.pending.eof()
return nil return nil
} }
@ -419,24 +419,24 @@ func (c *channel) handlePacket(packet []byte) error {
switch msg := decoded.(type) { switch msg := decoded.(type) {
case *channelOpenFailureMsg: case *channelOpenFailureMsg:
if err := c.responseMessageReceived(); err != nil { if err := ch.responseMessageReceived(); err != nil {
return err return err
} }
c.mux.chanList.remove(msg.PeersId) ch.mux.chanList.remove(msg.PeersID)
c.msg <- msg ch.msg <- msg
case *channelOpenConfirmMsg: case *channelOpenConfirmMsg:
if err := c.responseMessageReceived(); err != nil { if err := ch.responseMessageReceived(); err != nil {
return err return err
} }
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize) return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
} }
c.remoteId = msg.MyId ch.remoteId = msg.MyID
c.maxRemotePayload = msg.MaxPacketSize ch.maxRemotePayload = msg.MaxPacketSize
c.remoteWin.add(msg.MyWindow) ch.remoteWin.add(msg.MyWindow)
c.msg <- msg ch.msg <- msg
case *windowAdjustMsg: case *windowAdjustMsg:
if !c.remoteWin.add(msg.AdditionalBytes) { if !ch.remoteWin.add(msg.AdditionalBytes) {
return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes) return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
} }
case *channelRequestMsg: case *channelRequestMsg:
@ -444,12 +444,12 @@ func (c *channel) handlePacket(packet []byte) error {
Type: msg.Request, Type: msg.Request,
WantReply: msg.WantReply, WantReply: msg.WantReply,
Payload: msg.RequestSpecificData, Payload: msg.RequestSpecificData,
ch: c, ch: ch,
} }
c.incomingRequests <- &req ch.incomingRequests <- &req
default: default:
c.msg <- msg ch.msg <- msg
} }
return nil return nil
} }
@ -488,23 +488,23 @@ func (e *extChannel) Read(data []byte) (n int, err error) {
return e.ch.ReadExtended(data, e.code) return e.ch.ReadExtended(data, e.code)
} }
func (c *channel) Accept() (Channel, <-chan *Request, error) { func (ch *channel) Accept() (Channel, <-chan *Request, error) {
if c.decided { if ch.decided {
return nil, nil, errDecidedAlready return nil, nil, errDecidedAlready
} }
c.maxIncomingPayload = channelMaxPacket ch.maxIncomingPayload = channelMaxPacket
confirm := channelOpenConfirmMsg{ confirm := channelOpenConfirmMsg{
PeersId: c.remoteId, PeersID: ch.remoteId,
MyId: c.localId, MyID: ch.localId,
MyWindow: c.myWindow, MyWindow: ch.myWindow,
MaxPacketSize: c.maxIncomingPayload, MaxPacketSize: ch.maxIncomingPayload,
} }
c.decided = true ch.decided = true
if err := c.sendMessage(confirm); err != nil { if err := ch.sendMessage(confirm); err != nil {
return nil, nil, err return nil, nil, err
} }
return c, c.incomingRequests, nil return ch, ch.incomingRequests, nil
} }
func (ch *channel) Reject(reason RejectionReason, message string) error { func (ch *channel) Reject(reason RejectionReason, message string) error {
@ -512,7 +512,7 @@ func (ch *channel) Reject(reason RejectionReason, message string) error {
return errDecidedAlready return errDecidedAlready
} }
reject := channelOpenFailureMsg{ reject := channelOpenFailureMsg{
PeersId: ch.remoteId, PeersID: ch.remoteId,
Reason: reason, Reason: reason,
Message: message, Message: message,
Language: "en", Language: "en",
@ -541,7 +541,7 @@ func (ch *channel) CloseWrite() error {
} }
ch.sentEOF = true ch.sentEOF = true
return ch.sendMessage(channelEOFMsg{ return ch.sendMessage(channelEOFMsg{
PeersId: ch.remoteId}) PeersID: ch.remoteId})
} }
func (ch *channel) Close() error { func (ch *channel) Close() error {
@ -550,7 +550,7 @@ func (ch *channel) Close() error {
} }
return ch.sendMessage(channelCloseMsg{ return ch.sendMessage(channelCloseMsg{
PeersId: ch.remoteId}) PeersID: ch.remoteId})
} }
// Extended returns an io.ReadWriter that sends and receives data on the given, // Extended returns an io.ReadWriter that sends and receives data on the given,
@ -577,7 +577,7 @@ func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (boo
} }
msg := channelRequestMsg{ msg := channelRequestMsg{
PeersId: ch.remoteId, PeersID: ch.remoteId,
Request: name, Request: name,
WantReply: wantReply, WantReply: wantReply,
RequestSpecificData: payload, RequestSpecificData: payload,
@ -614,11 +614,11 @@ func (ch *channel) ackRequest(ok bool) error {
var msg interface{} var msg interface{}
if !ok { if !ok {
msg = channelRequestFailureMsg{ msg = channelRequestFailureMsg{
PeersId: ch.remoteId, PeersID: ch.remoteId,
} }
} else { } else {
msg = channelRequestSuccessMsg{ msg = channelRequestSuccessMsg{
PeersId: ch.remoteId, PeersID: ch.remoteId,
} }
} }
return ch.sendMessage(msg) return ch.sendMessage(msg)

View File

@ -304,7 +304,7 @@ type gcmCipher struct {
buf []byte buf []byte
} }
func newGCMCipher(iv, key, macKey []byte) (packetCipher, error) { func newGCMCipher(iv, key []byte) (packetCipher, error) {
c, err := aes.NewCipher(key) c, err := aes.NewCipher(key)
if err != nil { if err != nil {
return nil, err return nil, err
@ -372,7 +372,7 @@ func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
} }
length := binary.BigEndian.Uint32(c.prefix[:]) length := binary.BigEndian.Uint32(c.prefix[:])
if length > maxPacket { if length > maxPacket {
return nil, errors.New("ssh: max packet length exceeded.") return nil, errors.New("ssh: max packet length exceeded")
} }
if cap(c.buf) < int(length+gcmTagSize) { if cap(c.buf) < int(length+gcmTagSize) {
@ -392,7 +392,9 @@ func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
c.incIV() c.incIV()
padding := plain[0] padding := plain[0]
if padding < 4 || padding >= 20 { if padding < 4 {
// padding is a byte, so it automatically satisfies
// the maximum size, which is 255.
return nil, fmt.Errorf("ssh: illegal padding %d", padding) return nil, fmt.Errorf("ssh: illegal padding %d", padding)
} }
@ -546,11 +548,11 @@ func (c *cbcCipher) readPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error)
c.packetData = c.packetData[:entirePacketSize] c.packetData = c.packetData[:entirePacketSize]
} }
if n, err := io.ReadFull(r, c.packetData[firstBlockLength:]); err != nil { n, err := io.ReadFull(r, c.packetData[firstBlockLength:])
if err != nil {
return nil, err return nil, err
} else {
c.oracleCamouflage -= uint32(n)
} }
c.oracleCamouflage -= uint32(n)
remainingCrypted := c.packetData[firstBlockLength:macStart] remainingCrypted := c.packetData[firstBlockLength:macStart]
c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted) c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted)

View File

@ -5,15 +5,17 @@
package ssh package ssh
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"os"
"sync" "sync"
"time" "time"
) )
// Client implements a traditional SSH client that supports shells, // Client implements a traditional SSH client that supports shells,
// subprocesses, port forwarding and tunneled dialing. // subprocesses, TCP port/streamlocal forwarding and tunneled dialing.
type Client struct { type Client struct {
Conn Conn
@ -59,6 +61,7 @@ func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
conn.forwards.closeAll() conn.forwards.closeAll()
}() }()
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip")) go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip"))
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
return conn return conn
} }
@ -68,6 +71,11 @@ func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) { func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) {
fullConf := *config fullConf := *config
fullConf.SetDefaults() fullConf.SetDefaults()
if fullConf.HostKeyCallback == nil {
c.Close()
return nil, nil, nil, errors.New("ssh: must specify HostKeyCallback")
}
conn := &connection{ conn := &connection{
sshConn: sshConn{conn: c}, sshConn: sshConn{conn: c},
} }
@ -173,6 +181,17 @@ func Dial(network, addr string, config *ClientConfig) (*Client, error) {
return NewClient(c, chans, reqs), nil return NewClient(c, chans, reqs), nil
} }
// HostKeyCallback is the function type used for verifying server
// keys. A HostKeyCallback must return nil if the host key is OK, or
// an error to reject it. It receives the hostname as passed to Dial
// or NewClientConn. The remote address is the RemoteAddr of the
// net.Conn underlying the the SSH connection.
type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
// BannerCallback is the function type used for treat the banner sent by
// the server. A BannerCallback receives the message sent by the remote server.
type BannerCallback func(message string) error
// A ClientConfig structure is used to configure a Client. It must not be // A ClientConfig structure is used to configure a Client. It must not be
// modified after having been passed to an SSH function. // modified after having been passed to an SSH function.
type ClientConfig struct { type ClientConfig struct {
@ -188,10 +207,18 @@ type ClientConfig struct {
// be used during authentication. // be used during authentication.
Auth []AuthMethod Auth []AuthMethod
// HostKeyCallback, if not nil, is called during the cryptographic // HostKeyCallback is called during the cryptographic
// handshake to validate the server's host key. A nil HostKeyCallback // handshake to validate the server's host key. The client
// implies that all host keys are accepted. // configuration must supply this callback for the connection
HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error // to succeed. The functions InsecureIgnoreHostKey or
// FixedHostKey can be used for simplistic host key checks.
HostKeyCallback HostKeyCallback
// BannerCallback is called during the SSH dance to display a custom
// server's message. The client configuration can supply this callback to
// handle it as wished. The function BannerDisplayStderr can be used for
// simplistic display on Stderr.
BannerCallback BannerCallback
// ClientVersion contains the version identification string that will // ClientVersion contains the version identification string that will
// be used for the connection. If empty, a reasonable default is used. // be used for the connection. If empty, a reasonable default is used.
@ -209,3 +236,43 @@ type ClientConfig struct {
// A Timeout of zero means no timeout. // A Timeout of zero means no timeout.
Timeout time.Duration Timeout time.Duration
} }
// InsecureIgnoreHostKey returns a function that can be used for
// ClientConfig.HostKeyCallback to accept any host key. It should
// not be used for production code.
func InsecureIgnoreHostKey() HostKeyCallback {
return func(hostname string, remote net.Addr, key PublicKey) error {
return nil
}
}
type fixedHostKey struct {
key PublicKey
}
func (f *fixedHostKey) check(hostname string, remote net.Addr, key PublicKey) error {
if f.key == nil {
return fmt.Errorf("ssh: required host key was nil")
}
if !bytes.Equal(key.Marshal(), f.key.Marshal()) {
return fmt.Errorf("ssh: host key mismatch")
}
return nil
}
// FixedHostKey returns a function for use in
// ClientConfig.HostKeyCallback to accept only a specific host key.
func FixedHostKey(key PublicKey) HostKeyCallback {
hk := &fixedHostKey{key}
return hk.check
}
// BannerDisplayStderr returns a function that can be used for
// ClientConfig.BannerCallback to display banners on os.Stderr.
func BannerDisplayStderr() BannerCallback {
return func(banner string) error {
_, err := os.Stderr.WriteString(banner)
return err
}
}

View File

@ -179,31 +179,26 @@ func (cb publicKeyCallback) method() string {
} }
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
// Authentication is performed in two stages. The first stage sends an // Authentication is performed by sending an enquiry to test if a key is
// enquiry to test if each key is acceptable to the remote. The second // acceptable to the remote. If the key is acceptable, the client will
// stage attempts to authenticate with the valid keys obtained in the // attempt to authenticate with the valid key. If not the client will repeat
// first stage. // the process with the remaining keys.
signers, err := cb() signers, err := cb()
if err != nil { if err != nil {
return false, nil, err return false, nil, err
} }
var validKeys []Signer
for _, signer := range signers {
if ok, err := validateKey(signer.PublicKey(), user, c); ok {
validKeys = append(validKeys, signer)
} else {
if err != nil {
return false, nil, err
}
}
}
// methods that may continue if this auth is not successful.
var methods []string var methods []string
for _, signer := range validKeys { for _, signer := range signers {
pub := signer.PublicKey() ok, err := validateKey(signer.PublicKey(), user, c)
if err != nil {
return false, nil, err
}
if !ok {
continue
}
pub := signer.PublicKey()
pubKey := pub.Marshal() pubKey := pub.Marshal()
sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{ sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{
User: user, User: user,
@ -236,13 +231,29 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
if err != nil { if err != nil {
return false, nil, err return false, nil, err
} }
if success {
// If authentication succeeds or the list of available methods does not
// contain the "publickey" method, do not attempt to authenticate with any
// other keys. According to RFC 4252 Section 7, the latter can occur when
// additional authentication methods are required.
if success || !containsMethod(methods, cb.method()) {
return success, methods, err return success, methods, err
} }
} }
return false, methods, nil return false, methods, nil
} }
func containsMethod(methods []string, method string) bool {
for _, m := range methods {
if m == method {
return true
}
}
return false
}
// validateKey validates the key provided is acceptable to the server. // validateKey validates the key provided is acceptable to the server.
func validateKey(key PublicKey, user string, c packetConn) (bool, error) { func validateKey(key PublicKey, user string, c packetConn) (bool, error) {
pubKey := key.Marshal() pubKey := key.Marshal()
@ -272,7 +283,9 @@ func confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
} }
switch packet[0] { switch packet[0] {
case msgUserAuthBanner: case msgUserAuthBanner:
// TODO(gpaul): add callback to present the banner to the user if err := handleBannerResponse(c, packet); err != nil {
return false, err
}
case msgUserAuthPubKeyOk: case msgUserAuthPubKeyOk:
var msg userAuthPubKeyOkMsg var msg userAuthPubKeyOkMsg
if err := Unmarshal(packet, &msg); err != nil { if err := Unmarshal(packet, &msg); err != nil {
@ -314,7 +327,9 @@ func handleAuthResponse(c packetConn) (bool, []string, error) {
switch packet[0] { switch packet[0] {
case msgUserAuthBanner: case msgUserAuthBanner:
// TODO: add callback to present the banner to the user if err := handleBannerResponse(c, packet); err != nil {
return false, nil, err
}
case msgUserAuthFailure: case msgUserAuthFailure:
var msg userAuthFailureMsg var msg userAuthFailureMsg
if err := Unmarshal(packet, &msg); err != nil { if err := Unmarshal(packet, &msg); err != nil {
@ -329,6 +344,24 @@ func handleAuthResponse(c packetConn) (bool, []string, error) {
} }
} }
func handleBannerResponse(c packetConn, packet []byte) error {
var msg userAuthBannerMsg
if err := Unmarshal(packet, &msg); err != nil {
return err
}
transport, ok := c.(*handshakeTransport)
if !ok {
return nil
}
if transport.bannerCallback != nil {
return transport.bannerCallback(msg.Message)
}
return nil
}
// KeyboardInteractiveChallenge should print questions, optionally // KeyboardInteractiveChallenge should print questions, optionally
// disabling echoing (e.g. for passwords), and return all the answers. // disabling echoing (e.g. for passwords), and return all the answers.
// Challenge may be called multiple times in a single session. After // Challenge may be called multiple times in a single session. After
@ -338,7 +371,7 @@ func handleAuthResponse(c packetConn) (bool, []string, error) {
// both CLI and GUI environments. // both CLI and GUI environments.
type KeyboardInteractiveChallenge func(user, instruction string, questions []string, echos []bool) (answers []string, err error) type KeyboardInteractiveChallenge func(user, instruction string, questions []string, echos []bool) (answers []string, err error)
// KeyboardInteractive returns a AuthMethod using a prompt/response // KeyboardInteractive returns an AuthMethod using a prompt/response
// sequence controlled by the server. // sequence controlled by the server.
func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod { func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod {
return challenge return challenge
@ -374,7 +407,9 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
// like handleAuthResponse, but with less options. // like handleAuthResponse, but with less options.
switch packet[0] { switch packet[0] {
case msgUserAuthBanner: case msgUserAuthBanner:
// TODO: Print banners during userauth. if err := handleBannerResponse(c, packet); err != nil {
return false, nil, err
}
continue continue
case msgUserAuthInfoRequest: case msgUserAuthInfoRequest:
// OK // OK

View File

@ -9,6 +9,7 @@ import (
"crypto/rand" "crypto/rand"
"fmt" "fmt"
"io" "io"
"math"
"sync" "sync"
_ "crypto/sha1" _ "crypto/sha1"
@ -40,7 +41,7 @@ var supportedKexAlgos = []string{
kexAlgoDH14SHA1, kexAlgoDH1SHA1, kexAlgoDH14SHA1, kexAlgoDH1SHA1,
} }
// supportedKexAlgos specifies the supported host-key algorithms (i.e. methods // supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods
// of authenticating servers) in preference order. // of authenticating servers) in preference order.
var supportedHostKeyAlgos = []string{ var supportedHostKeyAlgos = []string{
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01,
@ -186,7 +187,7 @@ type Config struct {
// The maximum number of bytes sent or received after which a // The maximum number of bytes sent or received after which a
// new key is negotiated. It must be at least 256. If // new key is negotiated. It must be at least 256. If
// unspecified, 1 gigabyte is used. // unspecified, a size suitable for the chosen cipher is used.
RekeyThreshold uint64 RekeyThreshold uint64
// The allowed key exchanges algorithms. If unspecified then a // The allowed key exchanges algorithms. If unspecified then a
@ -230,17 +231,18 @@ func (c *Config) SetDefaults() {
} }
if c.RekeyThreshold == 0 { if c.RekeyThreshold == 0 {
// RFC 4253, section 9 suggests rekeying after 1G. // cipher specific default
c.RekeyThreshold = 1 << 30 } else if c.RekeyThreshold < minRekeyThreshold {
}
if c.RekeyThreshold < minRekeyThreshold {
c.RekeyThreshold = minRekeyThreshold c.RekeyThreshold = minRekeyThreshold
} else if c.RekeyThreshold >= math.MaxInt64 {
// Avoid weirdness if somebody uses -1 as a threshold.
c.RekeyThreshold = math.MaxInt64
} }
} }
// buildDataSignedForAuth returns the data that is signed in order to prove // buildDataSignedForAuth returns the data that is signed in order to prove
// possession of a private key. See RFC 4252, section 7. // possession of a private key. See RFC 4252, section 7.
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte { func buildDataSignedForAuth(sessionID []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
data := struct { data := struct {
Session []byte Session []byte
Type byte Type byte
@ -251,7 +253,7 @@ func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubK
Algo []byte Algo []byte
PubKey []byte PubKey []byte
}{ }{
sessionId, sessionID,
msgUserAuthRequest, msgUserAuthRequest,
req.User, req.User,
req.Service, req.Service,

View File

@ -25,7 +25,7 @@ type ConnMetadata interface {
// User returns the user ID for this connection. // User returns the user ID for this connection.
User() string User() string
// SessionID returns the sesson hash, also denoted by H. // SessionID returns the session hash, also denoted by H.
SessionID() []byte SessionID() []byte
// ClientVersion returns the client's version string as hashed // ClientVersion returns the client's version string as hashed

View File

@ -14,5 +14,8 @@ others.
References: References:
[PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD [PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1 [SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
This package does not fall under the stability promise of the Go language itself,
so its API may be changed when pressing needs arise.
*/ */
package ssh // import "golang.org/x/crypto/ssh" package ssh // import "golang.org/x/crypto/ssh"

View File

@ -74,10 +74,15 @@ type handshakeTransport struct {
startKex chan *pendingKex startKex chan *pendingKex
// data for host key checking // data for host key checking
hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error hostKeyCallback HostKeyCallback
dialAddress string dialAddress string
remoteAddr net.Addr remoteAddr net.Addr
// bannerCallback is non-empty if we are the client and it has been set in
// ClientConfig. In that case it is called during the user authentication
// dance to handle a custom server's message.
bannerCallback BannerCallback
// Algorithms agreed in the last key exchange. // Algorithms agreed in the last key exchange.
algorithms *algorithms algorithms *algorithms
@ -107,6 +112,8 @@ func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion,
config: config, config: config,
} }
t.resetReadThresholds()
t.resetWriteThresholds()
// We always start with a mandatory key exchange. // We always start with a mandatory key exchange.
t.requestKex <- struct{}{} t.requestKex <- struct{}{}
@ -118,6 +125,7 @@ func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byt
t.dialAddress = dialAddr t.dialAddress = dialAddr
t.remoteAddr = addr t.remoteAddr = addr
t.hostKeyCallback = config.HostKeyCallback t.hostKeyCallback = config.HostKeyCallback
t.bannerCallback = config.BannerCallback
if config.HostKeyAlgorithms != nil { if config.HostKeyAlgorithms != nil {
t.hostKeyAlgorithms = config.HostKeyAlgorithms t.hostKeyAlgorithms = config.HostKeyAlgorithms
} else { } else {
@ -237,6 +245,17 @@ func (t *handshakeTransport) requestKeyExchange() {
} }
} }
func (t *handshakeTransport) resetWriteThresholds() {
t.writePacketsLeft = packetRekeyThreshold
if t.config.RekeyThreshold > 0 {
t.writeBytesLeft = int64(t.config.RekeyThreshold)
} else if t.algorithms != nil {
t.writeBytesLeft = t.algorithms.w.rekeyBytes()
} else {
t.writeBytesLeft = 1 << 30
}
}
func (t *handshakeTransport) kexLoop() { func (t *handshakeTransport) kexLoop() {
write: write:
@ -285,12 +304,8 @@ write:
t.writeError = err t.writeError = err
t.sentInitPacket = nil t.sentInitPacket = nil
t.sentInitMsg = nil t.sentInitMsg = nil
t.writePacketsLeft = packetRekeyThreshold
if t.config.RekeyThreshold > 0 { t.resetWriteThresholds()
t.writeBytesLeft = int64(t.config.RekeyThreshold)
} else if t.algorithms != nil {
t.writeBytesLeft = t.algorithms.w.rekeyBytes()
}
// we have completed the key exchange. Since the // we have completed the key exchange. Since the
// reader is still blocked, it is safe to clear out // reader is still blocked, it is safe to clear out
@ -344,6 +359,17 @@ write:
// key exchange itself. // key exchange itself.
const packetRekeyThreshold = (1 << 31) const packetRekeyThreshold = (1 << 31)
func (t *handshakeTransport) resetReadThresholds() {
t.readPacketsLeft = packetRekeyThreshold
if t.config.RekeyThreshold > 0 {
t.readBytesLeft = int64(t.config.RekeyThreshold)
} else if t.algorithms != nil {
t.readBytesLeft = t.algorithms.r.rekeyBytes()
} else {
t.readBytesLeft = 1 << 30
}
}
func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) { func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
p, err := t.conn.readPacket() p, err := t.conn.readPacket()
if err != nil { if err != nil {
@ -391,12 +417,7 @@ func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
return nil, err return nil, err
} }
t.readPacketsLeft = packetRekeyThreshold t.resetReadThresholds()
if t.config.RekeyThreshold > 0 {
t.readBytesLeft = int64(t.config.RekeyThreshold)
} else {
t.readBytesLeft = t.algorithms.r.rekeyBytes()
}
// By default, a key exchange is hidden from higher layers by // By default, a key exchange is hidden from higher layers by
// translating it into msgIgnore. // translating it into msgIgnore.
@ -574,7 +595,9 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
} }
result.SessionID = t.sessionID result.SessionID = t.sessionID
t.conn.prepareKeyChange(t.algorithms, result) if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil {
return err
}
if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
return err return err
} }
@ -614,11 +637,9 @@ func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *
return nil, err return nil, err
} }
if t.hostKeyCallback != nil { err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) if err != nil {
if err != nil { return nil, err
return nil, err
}
} }
return result, nil return result, nil

View File

@ -119,7 +119,7 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha
return nil, err return nil, err
} }
kInt, err := group.diffieHellman(kexDHReply.Y, x) ki, err := group.diffieHellman(kexDHReply.Y, x)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -129,8 +129,8 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha
writeString(h, kexDHReply.HostKey) writeString(h, kexDHReply.HostKey)
writeInt(h, X) writeInt(h, X)
writeInt(h, kexDHReply.Y) writeInt(h, kexDHReply.Y)
K := make([]byte, intLength(kInt)) K := make([]byte, intLength(ki))
marshalInt(K, kInt) marshalInt(K, ki)
h.Write(K) h.Write(K)
return &kexResult{ return &kexResult{
@ -164,7 +164,7 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha
} }
Y := new(big.Int).Exp(group.g, y, group.p) Y := new(big.Int).Exp(group.g, y, group.p)
kInt, err := group.diffieHellman(kexDHInit.X, y) ki, err := group.diffieHellman(kexDHInit.X, y)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -177,8 +177,8 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha
writeInt(h, kexDHInit.X) writeInt(h, kexDHInit.X)
writeInt(h, Y) writeInt(h, Y)
K := make([]byte, intLength(kInt)) K := make([]byte, intLength(ki))
marshalInt(K, kInt) marshalInt(K, ki)
h.Write(K) h.Write(K)
H := h.Sum(nil) H := h.Sum(nil)
@ -383,8 +383,8 @@ func init() {
// 4253 and Oakley Group 2 in RFC 2409. // 4253 and Oakley Group 2 in RFC 2409.
p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16) p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16)
kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{ kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{
g: new(big.Int).SetInt64(2), g: new(big.Int).SetInt64(2),
p: p, p: p,
pMinus1: new(big.Int).Sub(p, bigOne), pMinus1: new(big.Int).Sub(p, bigOne),
} }
@ -393,8 +393,8 @@ func init() {
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{ kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{
g: new(big.Int).SetInt64(2), g: new(big.Int).SetInt64(2),
p: p, p: p,
pMinus1: new(big.Int).Sub(p, bigOne), pMinus1: new(big.Int).Sub(p, bigOne),
} }
@ -462,9 +462,9 @@ func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handsh
writeString(h, kp.pub[:]) writeString(h, kp.pub[:])
writeString(h, reply.EphemeralPubKey) writeString(h, reply.EphemeralPubKey)
kInt := new(big.Int).SetBytes(secret[:]) ki := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(kInt)) K := make([]byte, intLength(ki))
marshalInt(K, kInt) marshalInt(K, ki)
h.Write(K) h.Write(K)
return &kexResult{ return &kexResult{
@ -510,9 +510,9 @@ func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handsh
writeString(h, kexInit.ClientPubKey) writeString(h, kexInit.ClientPubKey)
writeString(h, kp.pub[:]) writeString(h, kp.pub[:])
kInt := new(big.Int).SetBytes(secret[:]) ki := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(kInt)) K := make([]byte, intLength(ki))
marshalInt(K, kInt) marshalInt(K, ki)
h.Write(K) h.Write(K)
H := h.Sum(nil) H := h.Sum(nil)

View File

@ -363,10 +363,21 @@ func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey {
type dsaPublicKey dsa.PublicKey type dsaPublicKey dsa.PublicKey
func (r *dsaPublicKey) Type() string { func (k *dsaPublicKey) Type() string {
return "ssh-dss" return "ssh-dss"
} }
func checkDSAParams(param *dsa.Parameters) error {
// SSH specifies FIPS 186-2, which only provided a single size
// (1024 bits) DSA key. FIPS 186-3 allows for larger key
// sizes, which would confuse SSH.
if l := param.P.BitLen(); l != 1024 {
return fmt.Errorf("ssh: unsupported DSA key size %d", l)
}
return nil
}
// parseDSA parses an DSA key according to RFC 4253, section 6.6. // parseDSA parses an DSA key according to RFC 4253, section 6.6.
func parseDSA(in []byte) (out PublicKey, rest []byte, err error) { func parseDSA(in []byte) (out PublicKey, rest []byte, err error) {
var w struct { var w struct {
@ -377,13 +388,18 @@ func parseDSA(in []byte) (out PublicKey, rest []byte, err error) {
return nil, nil, err return nil, nil, err
} }
param := dsa.Parameters{
P: w.P,
Q: w.Q,
G: w.G,
}
if err := checkDSAParams(&param); err != nil {
return nil, nil, err
}
key := &dsaPublicKey{ key := &dsaPublicKey{
Parameters: dsa.Parameters{ Parameters: param,
P: w.P, Y: w.Y,
Q: w.Q,
G: w.G,
},
Y: w.Y,
} }
return key, w.Rest, nil return key, w.Rest, nil
} }
@ -465,12 +481,12 @@ func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
type ecdsaPublicKey ecdsa.PublicKey type ecdsaPublicKey ecdsa.PublicKey
func (key *ecdsaPublicKey) Type() string { func (k *ecdsaPublicKey) Type() string {
return "ecdsa-sha2-" + key.nistID() return "ecdsa-sha2-" + k.nistID()
} }
func (key *ecdsaPublicKey) nistID() string { func (k *ecdsaPublicKey) nistID() string {
switch key.Params().BitSize { switch k.Params().BitSize {
case 256: case 256:
return "nistp256" return "nistp256"
case 384: case 384:
@ -483,7 +499,7 @@ func (key *ecdsaPublicKey) nistID() string {
type ed25519PublicKey ed25519.PublicKey type ed25519PublicKey ed25519.PublicKey
func (key ed25519PublicKey) Type() string { func (k ed25519PublicKey) Type() string {
return KeyAlgoED25519 return KeyAlgoED25519
} }
@ -502,23 +518,23 @@ func parseED25519(in []byte) (out PublicKey, rest []byte, err error) {
return (ed25519PublicKey)(key), w.Rest, nil return (ed25519PublicKey)(key), w.Rest, nil
} }
func (key ed25519PublicKey) Marshal() []byte { func (k ed25519PublicKey) Marshal() []byte {
w := struct { w := struct {
Name string Name string
KeyBytes []byte KeyBytes []byte
}{ }{
KeyAlgoED25519, KeyAlgoED25519,
[]byte(key), []byte(k),
} }
return Marshal(&w) return Marshal(&w)
} }
func (key ed25519PublicKey) Verify(b []byte, sig *Signature) error { func (k ed25519PublicKey) Verify(b []byte, sig *Signature) error {
if sig.Format != key.Type() { if sig.Format != k.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
} }
edKey := (ed25519.PublicKey)(key) edKey := (ed25519.PublicKey)(k)
if ok := ed25519.Verify(edKey, b, sig.Blob); !ok { if ok := ed25519.Verify(edKey, b, sig.Blob); !ok {
return errors.New("ssh: signature did not verify") return errors.New("ssh: signature did not verify")
} }
@ -579,9 +595,9 @@ func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) {
return (*ecdsaPublicKey)(key), w.Rest, nil return (*ecdsaPublicKey)(key), w.Rest, nil
} }
func (key *ecdsaPublicKey) Marshal() []byte { func (k *ecdsaPublicKey) Marshal() []byte {
// See RFC 5656, section 3.1. // See RFC 5656, section 3.1.
keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y) keyBytes := elliptic.Marshal(k.Curve, k.X, k.Y)
// ECDSA publickey struct layout should match the struct used by // ECDSA publickey struct layout should match the struct used by
// parseECDSACert in the x/crypto/ssh/agent package. // parseECDSACert in the x/crypto/ssh/agent package.
w := struct { w := struct {
@ -589,20 +605,20 @@ func (key *ecdsaPublicKey) Marshal() []byte {
ID string ID string
Key []byte Key []byte
}{ }{
key.Type(), k.Type(),
key.nistID(), k.nistID(),
keyBytes, keyBytes,
} }
return Marshal(&w) return Marshal(&w)
} }
func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error { func (k *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
if sig.Format != key.Type() { if sig.Format != k.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
} }
h := ecHash(key.Curve).New() h := ecHash(k.Curve).New()
h.Write(data) h.Write(data)
digest := h.Sum(nil) digest := h.Sum(nil)
@ -619,7 +635,7 @@ func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
return err return err
} }
if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) { if ecdsa.Verify((*ecdsa.PublicKey)(k), digest, ecSig.R, ecSig.S) {
return nil return nil
} }
return errors.New("ssh: signature did not verify") return errors.New("ssh: signature did not verify")
@ -630,19 +646,28 @@ func (k *ecdsaPublicKey) CryptoPublicKey() crypto.PublicKey {
} }
// NewSignerFromKey takes an *rsa.PrivateKey, *dsa.PrivateKey, // NewSignerFromKey takes an *rsa.PrivateKey, *dsa.PrivateKey,
// *ecdsa.PrivateKey or any other crypto.Signer and returns a corresponding // *ecdsa.PrivateKey or any other crypto.Signer and returns a
// Signer instance. ECDSA keys must use P-256, P-384 or P-521. // corresponding Signer instance. ECDSA keys must use P-256, P-384 or
// P-521. DSA keys must use parameter size L1024N160.
func NewSignerFromKey(key interface{}) (Signer, error) { func NewSignerFromKey(key interface{}) (Signer, error) {
switch key := key.(type) { switch key := key.(type) {
case crypto.Signer: case crypto.Signer:
return NewSignerFromSigner(key) return NewSignerFromSigner(key)
case *dsa.PrivateKey: case *dsa.PrivateKey:
return &dsaPrivateKey{key}, nil return newDSAPrivateKey(key)
default: default:
return nil, fmt.Errorf("ssh: unsupported key type %T", key) return nil, fmt.Errorf("ssh: unsupported key type %T", key)
} }
} }
func newDSAPrivateKey(key *dsa.PrivateKey) (Signer, error) {
if err := checkDSAParams(&key.PublicKey.Parameters); err != nil {
return nil, err
}
return &dsaPrivateKey{key}, nil
}
type wrappedSigner struct { type wrappedSigner struct {
signer crypto.Signer signer crypto.Signer
pubKey PublicKey pubKey PublicKey
@ -733,7 +758,7 @@ func NewPublicKey(key interface{}) (PublicKey, error) {
return (*rsaPublicKey)(key), nil return (*rsaPublicKey)(key), nil
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
if !supportedEllipticCurve(key.Curve) { if !supportedEllipticCurve(key.Curve) {
return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported.") return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported")
} }
return (*ecdsaPublicKey)(key), nil return (*ecdsaPublicKey)(key), nil
case *dsa.PublicKey: case *dsa.PublicKey:
@ -756,6 +781,18 @@ func ParsePrivateKey(pemBytes []byte) (Signer, error) {
return NewSignerFromKey(key) return NewSignerFromKey(key)
} }
// ParsePrivateKeyWithPassphrase returns a Signer from a PEM encoded private
// key and passphrase. It supports the same keys as
// ParseRawPrivateKeyWithPassphrase.
func ParsePrivateKeyWithPassphrase(pemBytes, passPhrase []byte) (Signer, error) {
key, err := ParseRawPrivateKeyWithPassphrase(pemBytes, passPhrase)
if err != nil {
return nil, err
}
return NewSignerFromKey(key)
}
// encryptedBlock tells whether a private key is // encryptedBlock tells whether a private key is
// encrypted by examining its Proc-Type header // encrypted by examining its Proc-Type header
// for a mention of ENCRYPTED // for a mention of ENCRYPTED
@ -790,6 +827,43 @@ func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) {
} }
} }
// ParseRawPrivateKeyWithPassphrase returns a private key decrypted with
// passphrase from a PEM encoded private key. If wrong passphrase, return
// x509.IncorrectPasswordError.
func ParseRawPrivateKeyWithPassphrase(pemBytes, passPhrase []byte) (interface{}, error) {
block, _ := pem.Decode(pemBytes)
if block == nil {
return nil, errors.New("ssh: no key found")
}
buf := block.Bytes
if encryptedBlock(block) {
if x509.IsEncryptedPEMBlock(block) {
var err error
buf, err = x509.DecryptPEMBlock(block, passPhrase)
if err != nil {
if err == x509.IncorrectPasswordError {
return nil, err
}
return nil, fmt.Errorf("ssh: cannot decode encrypted private keys: %v", err)
}
}
}
switch block.Type {
case "RSA PRIVATE KEY":
return x509.ParsePKCS1PrivateKey(buf)
case "EC PRIVATE KEY":
return x509.ParseECPrivateKey(buf)
case "DSA PRIVATE KEY":
return ParseDSAPrivateKey(buf)
case "OPENSSH PRIVATE KEY":
return parseOpenSSHPrivateKey(buf)
default:
return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type)
}
}
// ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as // ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as
// specified by the OpenSSL DSA man page. // specified by the OpenSSL DSA man page.
func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) { func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) {
@ -824,7 +898,7 @@ func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) {
// Implemented based on the documentation at // Implemented based on the documentation at
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key // https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key
func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) { func parseOpenSSHPrivateKey(key []byte) (crypto.PrivateKey, error) {
magic := append([]byte("openssh-key-v1"), 0) magic := append([]byte("openssh-key-v1"), 0)
if !bytes.Equal(magic, key[0:len(magic)]) { if !bytes.Equal(magic, key[0:len(magic)]) {
return nil, errors.New("ssh: invalid openssh private key format") return nil, errors.New("ssh: invalid openssh private key format")
@ -844,14 +918,15 @@ func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) {
return nil, err return nil, err
} }
if w.KdfName != "none" || w.CipherName != "none" {
return nil, errors.New("ssh: cannot decode encrypted private keys")
}
pk1 := struct { pk1 := struct {
Check1 uint32 Check1 uint32
Check2 uint32 Check2 uint32
Keytype string Keytype string
Pub []byte Rest []byte `ssh:"rest"`
Priv []byte
Comment string
Pad []byte `ssh:"rest"`
}{} }{}
if err := Unmarshal(w.PrivKeyBlock, &pk1); err != nil { if err := Unmarshal(w.PrivKeyBlock, &pk1); err != nil {
@ -862,24 +937,75 @@ func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) {
return nil, errors.New("ssh: checkint mismatch") return nil, errors.New("ssh: checkint mismatch")
} }
// we only handle ed25519 keys currently // we only handle ed25519 and rsa keys currently
if pk1.Keytype != KeyAlgoED25519 { switch pk1.Keytype {
case KeyAlgoRSA:
// https://github.com/openssh/openssh-portable/blob/master/sshkey.c#L2760-L2773
key := struct {
N *big.Int
E *big.Int
D *big.Int
Iqmp *big.Int
P *big.Int
Q *big.Int
Comment string
Pad []byte `ssh:"rest"`
}{}
if err := Unmarshal(pk1.Rest, &key); err != nil {
return nil, err
}
for i, b := range key.Pad {
if int(b) != i+1 {
return nil, errors.New("ssh: padding not as expected")
}
}
pk := &rsa.PrivateKey{
PublicKey: rsa.PublicKey{
N: key.N,
E: int(key.E.Int64()),
},
D: key.D,
Primes: []*big.Int{key.P, key.Q},
}
if err := pk.Validate(); err != nil {
return nil, err
}
pk.Precompute()
return pk, nil
case KeyAlgoED25519:
key := struct {
Pub []byte
Priv []byte
Comment string
Pad []byte `ssh:"rest"`
}{}
if err := Unmarshal(pk1.Rest, &key); err != nil {
return nil, err
}
if len(key.Priv) != ed25519.PrivateKeySize {
return nil, errors.New("ssh: private key unexpected length")
}
for i, b := range key.Pad {
if int(b) != i+1 {
return nil, errors.New("ssh: padding not as expected")
}
}
pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize))
copy(pk, key.Priv)
return &pk, nil
default:
return nil, errors.New("ssh: unhandled key type") return nil, errors.New("ssh: unhandled key type")
} }
for i, b := range pk1.Pad {
if int(b) != i+1 {
return nil, errors.New("ssh: padding not as expected")
}
}
if len(pk1.Priv) != ed25519.PrivateKeySize {
return nil, errors.New("ssh: private key unexpected length")
}
pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize))
copy(pk, pk1.Priv)
return &pk, nil
} }
// FingerprintLegacyMD5 returns the user presentation of the key's // FingerprintLegacyMD5 returns the user presentation of the key's

View File

@ -23,10 +23,6 @@ const (
msgUnimplemented = 3 msgUnimplemented = 3
msgDebug = 4 msgDebug = 4
msgNewKeys = 21 msgNewKeys = 21
// Standard authentication messages
msgUserAuthSuccess = 52
msgUserAuthBanner = 53
) )
// SSH messages: // SSH messages:
@ -137,6 +133,18 @@ type userAuthFailureMsg struct {
PartialSuccess bool PartialSuccess bool
} }
// See RFC 4252, section 5.1
const msgUserAuthSuccess = 52
// See RFC 4252, section 5.4
const msgUserAuthBanner = 53
type userAuthBannerMsg struct {
Message string `sshtype:"53"`
// unused, but required to allow message parsing
Language string
}
// See RFC 4256, section 3.2 // See RFC 4256, section 3.2
const msgUserAuthInfoRequest = 60 const msgUserAuthInfoRequest = 60
const msgUserAuthInfoResponse = 61 const msgUserAuthInfoResponse = 61
@ -154,7 +162,7 @@ const msgChannelOpen = 90
type channelOpenMsg struct { type channelOpenMsg struct {
ChanType string `sshtype:"90"` ChanType string `sshtype:"90"`
PeersId uint32 PeersID uint32
PeersWindow uint32 PeersWindow uint32
MaxPacketSize uint32 MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"` TypeSpecificData []byte `ssh:"rest"`
@ -165,7 +173,7 @@ const msgChannelData = 94
// Used for debug print outs of packets. // Used for debug print outs of packets.
type channelDataMsg struct { type channelDataMsg struct {
PeersId uint32 `sshtype:"94"` PeersID uint32 `sshtype:"94"`
Length uint32 Length uint32
Rest []byte `ssh:"rest"` Rest []byte `ssh:"rest"`
} }
@ -174,8 +182,8 @@ type channelDataMsg struct {
const msgChannelOpenConfirm = 91 const msgChannelOpenConfirm = 91
type channelOpenConfirmMsg struct { type channelOpenConfirmMsg struct {
PeersId uint32 `sshtype:"91"` PeersID uint32 `sshtype:"91"`
MyId uint32 MyID uint32
MyWindow uint32 MyWindow uint32
MaxPacketSize uint32 MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"` TypeSpecificData []byte `ssh:"rest"`
@ -185,7 +193,7 @@ type channelOpenConfirmMsg struct {
const msgChannelOpenFailure = 92 const msgChannelOpenFailure = 92
type channelOpenFailureMsg struct { type channelOpenFailureMsg struct {
PeersId uint32 `sshtype:"92"` PeersID uint32 `sshtype:"92"`
Reason RejectionReason Reason RejectionReason
Message string Message string
Language string Language string
@ -194,7 +202,7 @@ type channelOpenFailureMsg struct {
const msgChannelRequest = 98 const msgChannelRequest = 98
type channelRequestMsg struct { type channelRequestMsg struct {
PeersId uint32 `sshtype:"98"` PeersID uint32 `sshtype:"98"`
Request string Request string
WantReply bool WantReply bool
RequestSpecificData []byte `ssh:"rest"` RequestSpecificData []byte `ssh:"rest"`
@ -204,28 +212,28 @@ type channelRequestMsg struct {
const msgChannelSuccess = 99 const msgChannelSuccess = 99
type channelRequestSuccessMsg struct { type channelRequestSuccessMsg struct {
PeersId uint32 `sshtype:"99"` PeersID uint32 `sshtype:"99"`
} }
// See RFC 4254, section 5.4. // See RFC 4254, section 5.4.
const msgChannelFailure = 100 const msgChannelFailure = 100
type channelRequestFailureMsg struct { type channelRequestFailureMsg struct {
PeersId uint32 `sshtype:"100"` PeersID uint32 `sshtype:"100"`
} }
// See RFC 4254, section 5.3 // See RFC 4254, section 5.3
const msgChannelClose = 97 const msgChannelClose = 97
type channelCloseMsg struct { type channelCloseMsg struct {
PeersId uint32 `sshtype:"97"` PeersID uint32 `sshtype:"97"`
} }
// See RFC 4254, section 5.3 // See RFC 4254, section 5.3
const msgChannelEOF = 96 const msgChannelEOF = 96
type channelEOFMsg struct { type channelEOFMsg struct {
PeersId uint32 `sshtype:"96"` PeersID uint32 `sshtype:"96"`
} }
// See RFC 4254, section 4 // See RFC 4254, section 4
@ -255,7 +263,7 @@ type globalRequestFailureMsg struct {
const msgChannelWindowAdjust = 93 const msgChannelWindowAdjust = 93
type windowAdjustMsg struct { type windowAdjustMsg struct {
PeersId uint32 `sshtype:"93"` PeersID uint32 `sshtype:"93"`
AdditionalBytes uint32 AdditionalBytes uint32
} }

View File

@ -278,7 +278,7 @@ func (m *mux) handleChannelOpen(packet []byte) error {
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
failMsg := channelOpenFailureMsg{ failMsg := channelOpenFailureMsg{
PeersId: msg.PeersId, PeersID: msg.PeersID,
Reason: ConnectionFailed, Reason: ConnectionFailed,
Message: "invalid request", Message: "invalid request",
Language: "en_US.UTF-8", Language: "en_US.UTF-8",
@ -287,7 +287,7 @@ func (m *mux) handleChannelOpen(packet []byte) error {
} }
c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
c.remoteId = msg.PeersId c.remoteId = msg.PeersID
c.maxRemotePayload = msg.MaxPacketSize c.maxRemotePayload = msg.MaxPacketSize
c.remoteWin.add(msg.PeersWindow) c.remoteWin.add(msg.PeersWindow)
m.incomingChannels <- c m.incomingChannels <- c
@ -313,7 +313,7 @@ func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
PeersWindow: ch.myWindow, PeersWindow: ch.myWindow,
MaxPacketSize: ch.maxIncomingPayload, MaxPacketSize: ch.maxIncomingPayload,
TypeSpecificData: extra, TypeSpecificData: extra,
PeersId: ch.localId, PeersID: ch.localId,
} }
if err := m.sendMessage(open); err != nil { if err := m.sendMessage(open); err != nil {
return nil, err return nil, err

View File

@ -14,23 +14,34 @@ import (
) )
// The Permissions type holds fine-grained permissions that are // The Permissions type holds fine-grained permissions that are
// specific to a user or a specific authentication method for a // specific to a user or a specific authentication method for a user.
// user. Permissions, except for "source-address", must be enforced in // The Permissions value for a successful authentication attempt is
// the server application layer, after successful authentication. The // available in ServerConn, so it can be used to pass information from
// Permissions are passed on in ServerConn so a server implementation // the user-authentication phase to the application layer.
// can honor them.
type Permissions struct { type Permissions struct {
// Critical options restrict default permissions. Common // CriticalOptions indicate restrictions to the default
// restrictions are "source-address" and "force-command". If // permissions, and are typically used in conjunction with
// the server cannot enforce the restriction, or does not // user certificates. The standard for SSH certificates
// recognize it, the user should not authenticate. // defines "force-command" (only allow the given command to
// execute) and "source-address" (only allow connections from
// the given address). The SSH package currently only enforces
// the "source-address" critical option. It is up to server
// implementations to enforce other critical options, such as
// "force-command", by checking them after the SSH handshake
// is successful. In general, SSH servers should reject
// connections that specify critical options that are unknown
// or not supported.
CriticalOptions map[string]string CriticalOptions map[string]string
// Extensions are extra functionality that the server may // Extensions are extra functionality that the server may
// offer on authenticated connections. Common extensions are // offer on authenticated connections. Lack of support for an
// "permit-agent-forwarding", "permit-X11-forwarding". Lack of // extension does not preclude authenticating a user. Common
// support for an extension does not preclude authenticating a // extensions are "permit-agent-forwarding",
// user. // "permit-X11-forwarding". The Go SSH library currently does
// not act on any extension, and it is up to server
// implementations to honor them. Extensions can be used to
// pass data from the authentication callbacks to the server
// application layer.
Extensions map[string]string Extensions map[string]string
} }
@ -45,13 +56,24 @@ type ServerConfig struct {
// authenticating. // authenticating.
NoClientAuth bool NoClientAuth bool
// MaxAuthTries specifies the maximum number of authentication attempts
// permitted per connection. If set to a negative number, the number of
// attempts are unlimited. If set to zero, the number of attempts are limited
// to 6.
MaxAuthTries int
// PasswordCallback, if non-nil, is called when a user // PasswordCallback, if non-nil, is called when a user
// attempts to authenticate using a password. // attempts to authenticate using a password.
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error) PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
// PublicKeyCallback, if non-nil, is called when a client attempts public // PublicKeyCallback, if non-nil, is called when a client
// key authentication. It must return true if the given public key is // offers a public key for authentication. It must return a nil error
// valid for the given user. For example, see CertChecker.Authenticate. // if the given public key can be used to authenticate the
// given user. For example, see CertChecker.Authenticate. A
// call to this function does not guarantee that the key
// offered is in fact used to authenticate. To record any data
// depending on the public key, store it inside a
// Permissions.Extensions entry.
PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
// KeyboardInteractiveCallback, if non-nil, is called when // KeyboardInteractiveCallback, if non-nil, is called when
@ -73,6 +95,10 @@ type ServerConfig struct {
// Note that RFC 4253 section 4.2 requires that this string start with // Note that RFC 4253 section 4.2 requires that this string start with
// "SSH-2.0-". // "SSH-2.0-".
ServerVersion string ServerVersion string
// BannerCallback, if present, is called and the return string is sent to
// the client after key exchange completed but before authentication.
BannerCallback func(conn ConnMetadata) string
} }
// AddHostKey adds a private key as a host key. If an existing host // AddHostKey adds a private key as a host key. If an existing host
@ -143,6 +169,10 @@ type ServerConn struct {
func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) { func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) {
fullConf := *config fullConf := *config
fullConf.SetDefaults() fullConf.SetDefaults()
if fullConf.MaxAuthTries == 0 {
fullConf.MaxAuthTries = 6
}
s := &connection{ s := &connection{
sshConn: sshConn{conn: c}, sshConn: sshConn{conn: c},
} }
@ -226,7 +256,7 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error)
func isAcceptableAlgo(algo string) bool { func isAcceptableAlgo(algo string) bool {
switch algo { switch algo {
case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoED25519, case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoED25519,
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01: CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01:
return true return true
} }
return false return false
@ -262,15 +292,52 @@ func checkSourceAddress(addr net.Addr, sourceAddrs string) error {
return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr) return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr)
} }
// ServerAuthError implements the error interface. It appends any authentication
// errors that may occur, and is returned if all of the authentication methods
// provided by the user failed to authenticate.
type ServerAuthError struct {
// Errors contains authentication errors returned by the authentication
// callback methods.
Errors []error
}
func (l ServerAuthError) Error() string {
var errs []string
for _, err := range l.Errors {
errs = append(errs, err.Error())
}
return "[" + strings.Join(errs, ", ") + "]"
}
func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) { func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
sessionID := s.transport.getSessionID() sessionID := s.transport.getSessionID()
var cache pubKeyCache var cache pubKeyCache
var perms *Permissions var perms *Permissions
authFailures := 0
var authErrs []error
var displayedBanner bool
userAuthLoop: userAuthLoop:
for { for {
if authFailures >= config.MaxAuthTries && config.MaxAuthTries > 0 {
discMsg := &disconnectMsg{
Reason: 2,
Message: "too many authentication failures",
}
if err := s.transport.writePacket(Marshal(discMsg)); err != nil {
return nil, err
}
return nil, discMsg
}
var userAuthReq userAuthRequestMsg var userAuthReq userAuthRequestMsg
if packet, err := s.transport.readPacket(); err != nil { if packet, err := s.transport.readPacket(); err != nil {
if err == io.EOF {
return nil, &ServerAuthError{Errors: authErrs}
}
return nil, err return nil, err
} else if err = Unmarshal(packet, &userAuthReq); err != nil { } else if err = Unmarshal(packet, &userAuthReq); err != nil {
return nil, err return nil, err
@ -281,6 +348,20 @@ userAuthLoop:
} }
s.user = userAuthReq.User s.user = userAuthReq.User
if !displayedBanner && config.BannerCallback != nil {
displayedBanner = true
msg := config.BannerCallback(s)
if msg != "" {
bannerMsg := &userAuthBannerMsg{
Message: msg,
}
if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
return nil, err
}
}
}
perms = nil perms = nil
authErr := errors.New("no auth passed yet") authErr := errors.New("no auth passed yet")
@ -289,6 +370,11 @@ userAuthLoop:
if config.NoClientAuth { if config.NoClientAuth {
authErr = nil authErr = nil
} }
// allow initial attempt of 'none' without penalty
if authFailures == 0 {
authFailures--
}
case "password": case "password":
if config.PasswordCallback == nil { if config.PasswordCallback == nil {
authErr = errors.New("ssh: password auth not configured") authErr = errors.New("ssh: password auth not configured")
@ -360,6 +446,7 @@ userAuthLoop:
if isQuery { if isQuery {
// The client can query if the given public key // The client can query if the given public key
// would be okay. // would be okay.
if len(payload) > 0 { if len(payload) > 0 {
return nil, parseError(msgUserAuthRequest) return nil, parseError(msgUserAuthRequest)
} }
@ -401,6 +488,8 @@ userAuthLoop:
authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method) authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method)
} }
authErrs = append(authErrs, authErr)
if config.AuthLogCallback != nil { if config.AuthLogCallback != nil {
config.AuthLogCallback(s, userAuthReq.Method, authErr) config.AuthLogCallback(s, userAuthReq.Method, authErr)
} }
@ -409,6 +498,8 @@ userAuthLoop:
break userAuthLoop break userAuthLoop
} }
authFailures++
var failureMsg userAuthFailureMsg var failureMsg userAuthFailureMsg
if config.PasswordCallback != nil { if config.PasswordCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "password") failureMsg.Methods = append(failureMsg.Methods, "password")

View File

@ -231,6 +231,26 @@ func (s *Session) RequestSubsystem(subsystem string) error {
return err return err
} }
// RFC 4254 Section 6.7.
type ptyWindowChangeMsg struct {
Columns uint32
Rows uint32
Width uint32
Height uint32
}
// WindowChange informs the remote host about a terminal window dimension change to h rows and w columns.
func (s *Session) WindowChange(h, w int) error {
req := ptyWindowChangeMsg{
Columns: uint32(w),
Rows: uint32(h),
Width: uint32(w * 8),
Height: uint32(h * 8),
}
_, err := s.ch.SendRequest("window-change", false, Marshal(&req))
return err
}
// RFC 4254 Section 6.9. // RFC 4254 Section 6.9.
type signalMsg struct { type signalMsg struct {
Signal string Signal string
@ -386,7 +406,7 @@ func (s *Session) Wait() error {
s.stdinPipeWriter.Close() s.stdinPipeWriter.Close()
} }
var copyError error var copyError error
for _ = range s.copyFuncs { for range s.copyFuncs {
if err := <-s.errors; err != nil && copyError == nil { if err := <-s.errors; err != nil && copyError == nil {
copyError = err copyError = err
} }

115
vendor/golang.org/x/crypto/ssh/streamlocal.go generated vendored Normal file
View File

@ -0,0 +1,115 @@
package ssh
import (
"errors"
"io"
"net"
)
// streamLocalChannelOpenDirectMsg is a struct used for SSH_MSG_CHANNEL_OPEN message
// with "direct-streamlocal@openssh.com" string.
//
// See openssh-portable/PROTOCOL, section 2.4. connection: Unix domain socket forwarding
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L235
type streamLocalChannelOpenDirectMsg struct {
socketPath string
reserved0 string
reserved1 uint32
}
// forwardedStreamLocalPayload is a struct used for SSH_MSG_CHANNEL_OPEN message
// with "forwarded-streamlocal@openssh.com" string.
type forwardedStreamLocalPayload struct {
SocketPath string
Reserved0 string
}
// streamLocalChannelForwardMsg is a struct used for SSH2_MSG_GLOBAL_REQUEST message
// with "streamlocal-forward@openssh.com"/"cancel-streamlocal-forward@openssh.com" string.
type streamLocalChannelForwardMsg struct {
socketPath string
}
// ListenUnix is similar to ListenTCP but uses a Unix domain socket.
func (c *Client) ListenUnix(socketPath string) (net.Listener, error) {
m := streamLocalChannelForwardMsg{
socketPath,
}
// send message
ok, _, err := c.SendRequest("streamlocal-forward@openssh.com", true, Marshal(&m))
if err != nil {
return nil, err
}
if !ok {
return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer")
}
ch := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"})
return &unixListener{socketPath, c, ch}, nil
}
func (c *Client) dialStreamLocal(socketPath string) (Channel, error) {
msg := streamLocalChannelOpenDirectMsg{
socketPath: socketPath,
}
ch, in, err := c.OpenChannel("direct-streamlocal@openssh.com", Marshal(&msg))
if err != nil {
return nil, err
}
go DiscardRequests(in)
return ch, err
}
type unixListener struct {
socketPath string
conn *Client
in <-chan forward
}
// Accept waits for and returns the next connection to the listener.
func (l *unixListener) Accept() (net.Conn, error) {
s, ok := <-l.in
if !ok {
return nil, io.EOF
}
ch, incoming, err := s.newCh.Accept()
if err != nil {
return nil, err
}
go DiscardRequests(incoming)
return &chanConn{
Channel: ch,
laddr: &net.UnixAddr{
Name: l.socketPath,
Net: "unix",
},
raddr: &net.UnixAddr{
Name: "@",
Net: "unix",
},
}, nil
}
// Close closes the listener.
func (l *unixListener) Close() error {
// this also closes the listener.
l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"})
m := streamLocalChannelForwardMsg{
l.socketPath,
}
ok, _, err := l.conn.SendRequest("cancel-streamlocal-forward@openssh.com", true, Marshal(&m))
if err == nil && !ok {
err = errors.New("ssh: cancel-streamlocal-forward@openssh.com failed")
}
return err
}
// Addr returns the listener's network address.
func (l *unixListener) Addr() net.Addr {
return &net.UnixAddr{
Name: l.socketPath,
Net: "unix",
}
}

View File

@ -20,12 +20,20 @@ import (
// addr. Incoming connections will be available by calling Accept on // addr. Incoming connections will be available by calling Accept on
// the returned net.Listener. The listener must be serviced, or the // the returned net.Listener. The listener must be serviced, or the
// SSH connection may hang. // SSH connection may hang.
// N must be "tcp", "tcp4", "tcp6", or "unix".
func (c *Client) Listen(n, addr string) (net.Listener, error) { func (c *Client) Listen(n, addr string) (net.Listener, error) {
laddr, err := net.ResolveTCPAddr(n, addr) switch n {
if err != nil { case "tcp", "tcp4", "tcp6":
return nil, err laddr, err := net.ResolveTCPAddr(n, addr)
if err != nil {
return nil, err
}
return c.ListenTCP(laddr)
case "unix":
return c.ListenUnix(addr)
default:
return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
} }
return c.ListenTCP(laddr)
} }
// Automatic port allocation is broken with OpenSSH before 6.0. See // Automatic port allocation is broken with OpenSSH before 6.0. See
@ -116,7 +124,7 @@ func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
} }
// Register this forward, using the port number we obtained. // Register this forward, using the port number we obtained.
ch := c.forwards.add(*laddr) ch := c.forwards.add(laddr)
return &tcpListener{laddr, c, ch}, nil return &tcpListener{laddr, c, ch}, nil
} }
@ -131,7 +139,7 @@ type forwardList struct {
// forwardEntry represents an established mapping of a laddr on a // forwardEntry represents an established mapping of a laddr on a
// remote ssh server to a channel connected to a tcpListener. // remote ssh server to a channel connected to a tcpListener.
type forwardEntry struct { type forwardEntry struct {
laddr net.TCPAddr laddr net.Addr
c chan forward c chan forward
} }
@ -139,16 +147,16 @@ type forwardEntry struct {
// arguments to add/remove/lookup should be address as specified in // arguments to add/remove/lookup should be address as specified in
// the original forward-request. // the original forward-request.
type forward struct { type forward struct {
newCh NewChannel // the ssh client channel underlying this forward newCh NewChannel // the ssh client channel underlying this forward
raddr *net.TCPAddr // the raddr of the incoming connection raddr net.Addr // the raddr of the incoming connection
} }
func (l *forwardList) add(addr net.TCPAddr) chan forward { func (l *forwardList) add(addr net.Addr) chan forward {
l.Lock() l.Lock()
defer l.Unlock() defer l.Unlock()
f := forwardEntry{ f := forwardEntry{
addr, laddr: addr,
make(chan forward, 1), c: make(chan forward, 1),
} }
l.entries = append(l.entries, f) l.entries = append(l.entries, f)
return f.c return f.c
@ -176,44 +184,69 @@ func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
func (l *forwardList) handleChannels(in <-chan NewChannel) { func (l *forwardList) handleChannels(in <-chan NewChannel) {
for ch := range in { for ch := range in {
var payload forwardedTCPPayload var (
if err := Unmarshal(ch.ExtraData(), &payload); err != nil { laddr net.Addr
ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error()) raddr net.Addr
continue err error
} )
switch channelType := ch.ChannelType(); channelType {
case "forwarded-tcpip":
var payload forwardedTCPPayload
if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
continue
}
// RFC 4254 section 7.2 specifies that incoming // RFC 4254 section 7.2 specifies that incoming
// addresses should list the address, in string // addresses should list the address, in string
// format. It is implied that this should be an IP // format. It is implied that this should be an IP
// address, as it would be impossible to connect to it // address, as it would be impossible to connect to it
// otherwise. // otherwise.
laddr, err := parseTCPAddr(payload.Addr, payload.Port) laddr, err = parseTCPAddr(payload.Addr, payload.Port)
if err != nil { if err != nil {
ch.Reject(ConnectionFailed, err.Error()) ch.Reject(ConnectionFailed, err.Error())
continue continue
} }
raddr, err := parseTCPAddr(payload.OriginAddr, payload.OriginPort) raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
if err != nil { if err != nil {
ch.Reject(ConnectionFailed, err.Error()) ch.Reject(ConnectionFailed, err.Error())
continue continue
} }
if ok := l.forward(*laddr, *raddr, ch); !ok { case "forwarded-streamlocal@openssh.com":
var payload forwardedStreamLocalPayload
if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
continue
}
laddr = &net.UnixAddr{
Name: payload.SocketPath,
Net: "unix",
}
raddr = &net.UnixAddr{
Name: "@",
Net: "unix",
}
default:
panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
}
if ok := l.forward(laddr, raddr, ch); !ok {
// Section 7.2, implementations MUST reject spurious incoming // Section 7.2, implementations MUST reject spurious incoming
// connections. // connections.
ch.Reject(Prohibited, "no forward for address") ch.Reject(Prohibited, "no forward for address")
continue continue
} }
} }
} }
// remove removes the forward entry, and the channel feeding its // remove removes the forward entry, and the channel feeding its
// listener. // listener.
func (l *forwardList) remove(addr net.TCPAddr) { func (l *forwardList) remove(addr net.Addr) {
l.Lock() l.Lock()
defer l.Unlock() defer l.Unlock()
for i, f := range l.entries { for i, f := range l.entries {
if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port { if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() {
l.entries = append(l.entries[:i], l.entries[i+1:]...) l.entries = append(l.entries[:i], l.entries[i+1:]...)
close(f.c) close(f.c)
return return
@ -231,12 +264,12 @@ func (l *forwardList) closeAll() {
l.entries = nil l.entries = nil
} }
func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool { func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool {
l.Lock() l.Lock()
defer l.Unlock() defer l.Unlock()
for _, f := range l.entries { for _, f := range l.entries {
if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port { if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() {
f.c <- forward{ch, &raddr} f.c <- forward{newCh: ch, raddr: raddr}
return true return true
} }
} }
@ -262,7 +295,7 @@ func (l *tcpListener) Accept() (net.Conn, error) {
} }
go DiscardRequests(incoming) go DiscardRequests(incoming)
return &tcpChanConn{ return &chanConn{
Channel: ch, Channel: ch,
laddr: l.laddr, laddr: l.laddr,
raddr: s.raddr, raddr: s.raddr,
@ -277,7 +310,7 @@ func (l *tcpListener) Close() error {
} }
// this also closes the listener. // this also closes the listener.
l.conn.forwards.remove(*l.laddr) l.conn.forwards.remove(l.laddr)
ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m)) ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
if err == nil && !ok { if err == nil && !ok {
err = errors.New("ssh: cancel-tcpip-forward failed") err = errors.New("ssh: cancel-tcpip-forward failed")
@ -293,29 +326,52 @@ func (l *tcpListener) Addr() net.Addr {
// Dial initiates a connection to the addr from the remote host. // Dial initiates a connection to the addr from the remote host.
// The resulting connection has a zero LocalAddr() and RemoteAddr(). // The resulting connection has a zero LocalAddr() and RemoteAddr().
func (c *Client) Dial(n, addr string) (net.Conn, error) { func (c *Client) Dial(n, addr string) (net.Conn, error) {
// Parse the address into host and numeric port. var ch Channel
host, portString, err := net.SplitHostPort(addr) switch n {
if err != nil { case "tcp", "tcp4", "tcp6":
return nil, err // Parse the address into host and numeric port.
host, portString, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return nil, err
}
ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port))
if err != nil {
return nil, err
}
// Use a zero address for local and remote address.
zeroAddr := &net.TCPAddr{
IP: net.IPv4zero,
Port: 0,
}
return &chanConn{
Channel: ch,
laddr: zeroAddr,
raddr: zeroAddr,
}, nil
case "unix":
var err error
ch, err = c.dialStreamLocal(addr)
if err != nil {
return nil, err
}
return &chanConn{
Channel: ch,
laddr: &net.UnixAddr{
Name: "@",
Net: "unix",
},
raddr: &net.UnixAddr{
Name: addr,
Net: "unix",
},
}, nil
default:
return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
} }
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return nil, err
}
// Use a zero address for local and remote address.
zeroAddr := &net.TCPAddr{
IP: net.IPv4zero,
Port: 0,
}
ch, err := c.dial(net.IPv4zero.String(), 0, host, int(port))
if err != nil {
return nil, err
}
return &tcpChanConn{
Channel: ch,
laddr: zeroAddr,
raddr: zeroAddr,
}, nil
} }
// DialTCP connects to the remote address raddr on the network net, // DialTCP connects to the remote address raddr on the network net,
@ -332,7 +388,7 @@ func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &tcpChanConn{ return &chanConn{
Channel: ch, Channel: ch,
laddr: laddr, laddr: laddr,
raddr: raddr, raddr: raddr,
@ -366,26 +422,26 @@ type tcpChan struct {
Channel // the backing channel Channel // the backing channel
} }
// tcpChanConn fulfills the net.Conn interface without // chanConn fulfills the net.Conn interface without
// the tcpChan having to hold laddr or raddr directly. // the tcpChan having to hold laddr or raddr directly.
type tcpChanConn struct { type chanConn struct {
Channel Channel
laddr, raddr net.Addr laddr, raddr net.Addr
} }
// LocalAddr returns the local network address. // LocalAddr returns the local network address.
func (t *tcpChanConn) LocalAddr() net.Addr { func (t *chanConn) LocalAddr() net.Addr {
return t.laddr return t.laddr
} }
// RemoteAddr returns the remote network address. // RemoteAddr returns the remote network address.
func (t *tcpChanConn) RemoteAddr() net.Addr { func (t *chanConn) RemoteAddr() net.Addr {
return t.raddr return t.raddr
} }
// SetDeadline sets the read and write deadlines associated // SetDeadline sets the read and write deadlines associated
// with the connection. // with the connection.
func (t *tcpChanConn) SetDeadline(deadline time.Time) error { func (t *chanConn) SetDeadline(deadline time.Time) error {
if err := t.SetReadDeadline(deadline); err != nil { if err := t.SetReadDeadline(deadline); err != nil {
return err return err
} }
@ -396,12 +452,14 @@ func (t *tcpChanConn) SetDeadline(deadline time.Time) error {
// A zero value for t means Read will not time out. // A zero value for t means Read will not time out.
// After the deadline, the error from Read will implement net.Error // After the deadline, the error from Read will implement net.Error
// with Timeout() == true. // with Timeout() == true.
func (t *tcpChanConn) SetReadDeadline(deadline time.Time) error { func (t *chanConn) SetReadDeadline(deadline time.Time) error {
// for compatibility with previous version,
// the error message contains "tcpChan"
return errors.New("ssh: tcpChan: deadline not supported") return errors.New("ssh: tcpChan: deadline not supported")
} }
// SetWriteDeadline exists to satisfy the net.Conn interface // SetWriteDeadline exists to satisfy the net.Conn interface
// but is not implemented by this type. It always returns an error. // but is not implemented by this type. It always returns an error.
func (t *tcpChanConn) SetWriteDeadline(deadline time.Time) error { func (t *chanConn) SetWriteDeadline(deadline time.Time) error {
return errors.New("ssh: tcpChan: deadline not supported") return errors.New("ssh: tcpChan: deadline not supported")
} }

View File

@ -76,17 +76,17 @@ type connectionState struct {
// both directions are triggered by reading and writing a msgNewKey packet // both directions are triggered by reading and writing a msgNewKey packet
// respectively. // respectively.
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error { func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil { ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult)
if err != nil {
return err return err
} else {
t.reader.pendingKeyChange <- ciph
} }
t.reader.pendingKeyChange <- ciph
if ciph, err := newPacketCipher(t.writer.dir, algs.w, kexResult); err != nil { ciph, err = newPacketCipher(t.writer.dir, algs.w, kexResult)
if err != nil {
return err return err
} else {
t.writer.pendingKeyChange <- ciph
} }
t.writer.pendingKeyChange <- ciph
return nil return nil
} }
@ -139,7 +139,7 @@ func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
case cipher := <-s.pendingKeyChange: case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher s.packetCipher = cipher
default: default:
return nil, errors.New("ssh: got bogus newkeys message.") return nil, errors.New("ssh: got bogus newkeys message")
} }
case msgDisconnect: case msgDisconnect:
@ -254,7 +254,7 @@ func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (pac
iv, key, macKey := generateKeys(d, algs, kex) iv, key, macKey := generateKeys(d, algs, kex)
if algs.Cipher == gcmCipherID { if algs.Cipher == gcmCipherID {
return newGCMCipher(iv, key, macKey) return newGCMCipher(iv, key)
} }
if algs.Cipher == aes128cbcID { if algs.Cipher == aes128cbcID {

16
vendor/vendor.json vendored
View File

@ -2252,16 +2252,20 @@
"revisionTime": "2017-02-08T20:51:15Z" "revisionTime": "2017-02-08T20:51:15Z"
}, },
{ {
"checksumSHA1": "fsrFs762jlaILyqqQImS1GfvIvw=", "checksumSHA1": "pySTR3iSeU7FhjbBnQPNxgyIa4I=",
"path": "golang.org/x/crypto/ssh", "path": "golang.org/x/crypto/ssh",
"revision": "453249f01cfeb54c3d549ddb75ff152ca243f9d8", "revision": "d585fd2cc9195196078f516b69daff6744ef5e84",
"revisionTime": "2017-02-08T20:51:15Z" "revisionTime": "2017-12-16T04:08:15Z",
"version": "master",
"versionExact": "master"
}, },
{ {
"checksumSHA1": "SJ3Ma3Ozavxpbh1usZWBCnzMKIc=", "checksumSHA1": "NMRX0onGReaL9IfLr0XQ3kl5Id0=",
"path": "golang.org/x/crypto/ssh/agent", "path": "golang.org/x/crypto/ssh/agent",
"revision": "453249f01cfeb54c3d549ddb75ff152ca243f9d8", "revision": "d585fd2cc9195196078f516b69daff6744ef5e84",
"revisionTime": "2017-02-08T20:51:15Z" "revisionTime": "2017-12-16T04:08:15Z",
"version": "master",
"versionExact": "master"
}, },
{ {
"checksumSHA1": "9jjO5GjLa0XF/nfWihF02RoH4qc=", "checksumSHA1": "9jjO5GjLa0XF/nfWihF02RoH4qc=",