package ssh

import (
	"bytes"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"io/ioutil"
	"os"
	"path/filepath"
	"testing"

	"golang.org/x/crypto/ssh"
)

// TestFindKeyData verifies that findIDPublicKey can locate public key data,
// both when only the private key file is usable and when only the ".pub" file
// is usable, and when the identity is given as a path rather than a bare name.
func TestFindKeyData(t *testing.T) {
	// Set up a scratch directory and run the test from inside it, so the key
	// files can be referred to by relative name.
	td, err := ioutil.TempDir("", "ssh")
	if err != nil {
		t.Fatal(err)
	}
	// Best-effort cleanup; a failure to remove the temp dir is not a test failure.
	defer os.RemoveAll(td)

	cwd, err := os.Getwd()
	if err != nil {
		t.Fatal(err)
	}

	if err := os.Chdir(td); err != nil {
		t.Fatal(err)
	}
	// Deferred calls run last-in-first-out, so the working directory is
	// restored before the temp dir above is removed. Report a failed restore
	// instead of dropping it: other tests in the package would otherwise run
	// from a directory that is about to be deleted.
	defer func() {
		if err := os.Chdir(cwd); err != nil {
			t.Errorf("restoring working directory %q: %v", cwd, err)
		}
	}()

	id := "provisioner_id"

	pub := generateSSHKey(t, id)
	pubData := pub.Marshal()

	// backup the pub file, and replace it with a broken file to ensure we
	// extract the public key from the private key.
	if err := os.Rename(id+".pub", "saved.pub"); err != nil {
		t.Fatal(err)
	}
	if err := ioutil.WriteFile(id+".pub", []byte("not a public key"), 0600); err != nil {
		t.Fatal(err)
	}

	foundData := findIDPublicKey(id)
	if !bytes.Equal(foundData, pubData) {
		t.Fatalf("public key %q does not match", foundData)
	}

	// move the pub file back, and break the private key file to simulate an
	// encrypted private key
	if err := os.Rename("saved.pub", id+".pub"); err != nil {
		t.Fatal(err)
	}
	if err := ioutil.WriteFile(id, []byte("encrypted private key"), 0600); err != nil {
		t.Fatal(err)
	}

	foundData = findIDPublicKey(id)
	if !bytes.Equal(foundData, pubData) {
		t.Fatalf("public key %q does not match", foundData)
	}

	// check the file by path too
	foundData = findIDPublicKey(filepath.Join(".", id))
	if !bytes.Equal(foundData, pubData) {
		t.Fatalf("public key %q does not match", foundData)
	}
}

// generateSSHKey creates a fresh 2048-bit RSA key pair for a test. The private
// key is written PEM-encoded (PKCS#1, "RSA PRIVATE KEY") to idFile, and the
// public key is written in authorized_keys format to idFile+".pub"; both files
// are created with mode 0600. It returns the public key, and fails the test
// immediately via t.Fatal on any error.
func generateSSHKey(t *testing.T, idFile string) ssh.PublicKey {
	t.Helper()

	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatal(err)
	}

	privPEM := &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(priv),
	}
	privData := pem.EncodeToMemory(privPEM)
	if privData == nil {
		t.Fatal("could not PEM-encode private key")
	}

	// Write the private key in one shot: the file is truncated if it already
	// exists, and it is closed (with write/close errors reported) before the
	// public key file is written.
	if err := ioutil.WriteFile(idFile, privData, 0600); err != nil {
		t.Fatal(err)
	}

	// generate and write public key
	pub, err := ssh.NewPublicKey(&priv.PublicKey)
	if err != nil {
		t.Fatal(err)
	}
	if err := ioutil.WriteFile(idFile+".pub", ssh.MarshalAuthorizedKey(pub), 0600); err != nil {
		t.Fatal(err)
	}

	return pub
}