Remove svchost package
This commit is contained in:
parent
32f9722d9d
commit
cd21a3859d
|
@ -1,61 +0,0 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/terraform/svchost"
|
||||
)
|
||||
|
||||
// CachingCredentialsSource creates a new credentials source that wraps another
|
||||
// and caches its results in memory, on a per-hostname basis.
|
||||
//
|
||||
// No means is provided for expiration of cached credentials, so a caching
|
||||
// credentials source should have a limited lifetime (one Terraform operation,
|
||||
// for example) to ensure that time-limited credentials don't expire before
|
||||
// their cache entries do.
|
||||
func CachingCredentialsSource(source CredentialsSource) CredentialsSource {
|
||||
return &cachingCredentialsSource{
|
||||
source: source,
|
||||
cache: map[svchost.Hostname]HostCredentials{},
|
||||
}
|
||||
}
|
||||
|
||||
type cachingCredentialsSource struct {
|
||||
source CredentialsSource
|
||||
cache map[svchost.Hostname]HostCredentials
|
||||
}
|
||||
|
||||
// ForHost passes the given hostname on to the wrapped credentials source and
|
||||
// caches the result to return for future requests with the same hostname.
|
||||
//
|
||||
// Both credentials and non-credentials (nil) responses are cached.
|
||||
//
|
||||
// No cache entry is created if the wrapped source returns an error, to allow
|
||||
// the caller to retry the failing operation.
|
||||
func (s *cachingCredentialsSource) ForHost(host svchost.Hostname) (HostCredentials, error) {
|
||||
if cache, cached := s.cache[host]; cached {
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
result, err := s.source.ForHost(host)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
s.cache[host] = result
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *cachingCredentialsSource) StoreForHost(host svchost.Hostname, credentials HostCredentialsWritable) error {
|
||||
// We'll delete the cache entry even if the store fails, since that just
|
||||
// means that the next read will go to the real store and get a chance to
|
||||
// see which object (old or new) is actually present.
|
||||
delete(s.cache, host)
|
||||
return s.source.StoreForHost(host, credentials)
|
||||
}
|
||||
|
||||
func (s *cachingCredentialsSource) ForgetForHost(host svchost.Hostname) error {
|
||||
// We'll delete the cache entry even if the store fails, since that just
|
||||
// means that the next read will go to the real store and get a chance to
|
||||
// see if the object is still present.
|
||||
delete(s.cache, host)
|
||||
return s.source.ForgetForHost(host)
|
||||
}
|
|
@ -1,118 +0,0 @@
|
|||
// Package auth contains types and functions to manage authentication
|
||||
// credentials for service hosts.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/zclconf/go-cty/cty"
|
||||
|
||||
"github.com/hashicorp/terraform/svchost"
|
||||
)
|
||||
|
||||
// Credentials is a list of CredentialsSource objects that can be tried in
|
||||
// turn until one returns credentials for a host, or one returns an error.
|
||||
//
|
||||
// A Credentials is itself a CredentialsSource, wrapping its members.
|
||||
// In principle one CredentialsSource can be nested inside another, though
|
||||
// there is no good reason to do so.
|
||||
//
|
||||
// The write operations on a Credentials are tried only on the first object,
|
||||
// under the assumption that it is the primary store.
|
||||
type Credentials []CredentialsSource
|
||||
|
||||
// NoCredentials is an empty CredentialsSource that always returns nil
|
||||
// when asked for credentials.
|
||||
var NoCredentials CredentialsSource = Credentials{}
|
||||
|
||||
// A CredentialsSource is an object that may be able to provide credentials
|
||||
// for a given host.
|
||||
//
|
||||
// Credentials lookups are not guaranteed to be concurrency-safe. Callers
|
||||
// using these facilities in concurrent code must use external concurrency
|
||||
// primitives to prevent race conditions.
|
||||
type CredentialsSource interface {
|
||||
// ForHost returns a non-nil HostCredentials if the source has credentials
|
||||
// available for the host, and a nil HostCredentials if it does not.
|
||||
//
|
||||
// If an error is returned, progress through a list of CredentialsSources
|
||||
// is halted and the error is returned to the user.
|
||||
ForHost(host svchost.Hostname) (HostCredentials, error)
|
||||
|
||||
// StoreForHost takes a HostCredentialsWritable and saves it as the
|
||||
// credentials for the given host.
|
||||
//
|
||||
// If credentials are already stored for the given host, it will try to
|
||||
// replace those credentials but may produce an error if such replacement
|
||||
// is not possible.
|
||||
StoreForHost(host svchost.Hostname, credentials HostCredentialsWritable) error
|
||||
|
||||
// ForgetForHost discards any stored credentials for the given host. It
|
||||
// does nothing and returns successfully if no credentials are saved
|
||||
// for that host.
|
||||
ForgetForHost(host svchost.Hostname) error
|
||||
}
|
||||
|
||||
// HostCredentials represents a single set of credentials for a particular
|
||||
// host.
|
||||
type HostCredentials interface {
|
||||
// PrepareRequest modifies the given request in-place to apply the
|
||||
// receiving credentials. The usual behavior of this method is to
|
||||
// add some sort of Authorization header to the request.
|
||||
PrepareRequest(req *http.Request)
|
||||
|
||||
// Token returns the authentication token.
|
||||
Token() string
|
||||
}
|
||||
|
||||
// HostCredentialsWritable is an extension of HostCredentials for credentials
|
||||
// objects that can be serialized as a JSON-compatible object value for
|
||||
// storage.
|
||||
type HostCredentialsWritable interface {
|
||||
HostCredentials
|
||||
|
||||
// ToStore returns a cty.Value, always of an object type,
|
||||
// representing data that can be serialized to represent this object
|
||||
// in persistent storage.
|
||||
//
|
||||
// The resulting value may uses only cty values that can be accepted
|
||||
// by the cty JSON encoder, though the caller may elect to instead store
|
||||
// it in some other format that has a JSON-compatible type system.
|
||||
ToStore() cty.Value
|
||||
}
|
||||
|
||||
// ForHost iterates over the contained CredentialsSource objects and
|
||||
// tries to obtain credentials for the given host from each one in turn.
|
||||
//
|
||||
// If any source returns either a non-nil HostCredentials or a non-nil error
|
||||
// then this result is returned. Otherwise, the result is nil, nil.
|
||||
func (c Credentials) ForHost(host svchost.Hostname) (HostCredentials, error) {
|
||||
for _, source := range c {
|
||||
creds, err := source.ForHost(host)
|
||||
if creds != nil || err != nil {
|
||||
return creds, err
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// StoreForHost passes the given arguments to the same operation on the
|
||||
// first CredentialsSource in the receiver.
|
||||
func (c Credentials) StoreForHost(host svchost.Hostname, credentials HostCredentialsWritable) error {
|
||||
if len(c) == 0 {
|
||||
return fmt.Errorf("no credentials store is available")
|
||||
}
|
||||
|
||||
return c[0].StoreForHost(host, credentials)
|
||||
}
|
||||
|
||||
// ForgetForHost passes the given arguments to the same operation on the
|
||||
// first CredentialsSource in the receiver.
|
||||
func (c Credentials) ForgetForHost(host svchost.Hostname) error {
|
||||
if len(c) == 0 {
|
||||
return fmt.Errorf("no credentials store is available")
|
||||
}
|
||||
|
||||
return c[0].ForgetForHost(host)
|
||||
}
|
|
@ -1,48 +0,0 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"github.com/zclconf/go-cty/cty"
|
||||
)
|
||||
|
||||
// HostCredentialsFromMap converts a map of key-value pairs from a credentials
|
||||
// definition provided by the user (e.g. in a config file, or via a credentials
|
||||
// helper) into a HostCredentials object if possible, or returns nil if
|
||||
// no credentials could be extracted from the map.
|
||||
//
|
||||
// This function ignores map keys it is unfamiliar with, to allow for future
|
||||
// expansion of the credentials map format for new credential types.
|
||||
func HostCredentialsFromMap(m map[string]interface{}) HostCredentials {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
if token, ok := m["token"].(string); ok {
|
||||
return HostCredentialsToken(token)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HostCredentialsFromObject converts a cty.Value of an object type into a
|
||||
// HostCredentials object if possible, or returns nil if no credentials could
|
||||
// be extracted from the map.
|
||||
//
|
||||
// This function ignores object attributes it is unfamiliar with, to allow for
|
||||
// future expansion of the credentials object structure for new credential types.
|
||||
//
|
||||
// If the given value is not of an object type, this function will panic.
|
||||
func HostCredentialsFromObject(obj cty.Value) HostCredentials {
|
||||
if !obj.Type().HasAttribute("token") {
|
||||
return nil
|
||||
}
|
||||
|
||||
tokenV := obj.GetAttr("token")
|
||||
if tokenV.IsNull() || !tokenV.IsKnown() {
|
||||
return nil
|
||||
}
|
||||
if !cty.String.Equals(tokenV.Type()) {
|
||||
// Weird, but maybe some future Terraform version accepts an object
|
||||
// here for some reason, so we'll be resilient.
|
||||
return nil
|
||||
}
|
||||
|
||||
return HostCredentialsToken(tokenV.AsString())
|
||||
}
|
|
@ -1,149 +0,0 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
ctyjson "github.com/zclconf/go-cty/cty/json"
|
||||
|
||||
"github.com/hashicorp/terraform/svchost"
|
||||
)
|
||||
|
||||
type helperProgramCredentialsSource struct {
|
||||
executable string
|
||||
args []string
|
||||
}
|
||||
|
||||
// HelperProgramCredentialsSource returns a CredentialsSource that runs the
|
||||
// given program with the given arguments in order to obtain credentials.
|
||||
//
|
||||
// The given executable path must be an absolute path; it is the caller's
|
||||
// responsibility to validate and process a relative path or other input
|
||||
// provided by an end-user. If the given path is not absolute, this
|
||||
// function will panic.
|
||||
//
|
||||
// When credentials are requested, the program will be run in a child process
|
||||
// with the given arguments along with two additional arguments added to the
|
||||
// end of the list: the literal string "get", followed by the requested
|
||||
// hostname in ASCII compatibility form (punycode form).
|
||||
func HelperProgramCredentialsSource(executable string, args ...string) CredentialsSource {
|
||||
if !filepath.IsAbs(executable) {
|
||||
panic("NewCredentialsSourceHelperProgram requires absolute path to executable")
|
||||
}
|
||||
|
||||
fullArgs := make([]string, len(args)+1)
|
||||
fullArgs[0] = executable
|
||||
copy(fullArgs[1:], args)
|
||||
|
||||
return &helperProgramCredentialsSource{
|
||||
executable: executable,
|
||||
args: fullArgs,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *helperProgramCredentialsSource) ForHost(host svchost.Hostname) (HostCredentials, error) {
|
||||
args := make([]string, len(s.args), len(s.args)+2)
|
||||
copy(args, s.args)
|
||||
args = append(args, "get")
|
||||
args = append(args, string(host))
|
||||
|
||||
outBuf := bytes.Buffer{}
|
||||
errBuf := bytes.Buffer{}
|
||||
|
||||
cmd := exec.Cmd{
|
||||
Path: s.executable,
|
||||
Args: args,
|
||||
Stdin: nil,
|
||||
Stdout: &outBuf,
|
||||
Stderr: &errBuf,
|
||||
}
|
||||
err := cmd.Run()
|
||||
if _, isExitErr := err.(*exec.ExitError); isExitErr {
|
||||
errText := errBuf.String()
|
||||
if errText == "" {
|
||||
// Shouldn't happen for a well-behaved helper program
|
||||
return nil, fmt.Errorf("error in %s, but it produced no error message", s.executable)
|
||||
}
|
||||
return nil, fmt.Errorf("error in %s: %s", s.executable, errText)
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("failed to run %s: %s", s.executable, err)
|
||||
}
|
||||
|
||||
var m map[string]interface{}
|
||||
err = json.Unmarshal(outBuf.Bytes(), &m)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("malformed output from %s: %s", s.executable, err)
|
||||
}
|
||||
|
||||
return HostCredentialsFromMap(m), nil
|
||||
}
|
||||
|
||||
func (s *helperProgramCredentialsSource) StoreForHost(host svchost.Hostname, credentials HostCredentialsWritable) error {
|
||||
args := make([]string, len(s.args), len(s.args)+2)
|
||||
copy(args, s.args)
|
||||
args = append(args, "store")
|
||||
args = append(args, string(host))
|
||||
|
||||
toStore := credentials.ToStore()
|
||||
toStoreRaw, err := ctyjson.Marshal(toStore, toStore.Type())
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't serialize credentials to store: %s", err)
|
||||
}
|
||||
|
||||
inReader := bytes.NewReader(toStoreRaw)
|
||||
errBuf := bytes.Buffer{}
|
||||
|
||||
cmd := exec.Cmd{
|
||||
Path: s.executable,
|
||||
Args: args,
|
||||
Stdin: inReader,
|
||||
Stderr: &errBuf,
|
||||
Stdout: nil,
|
||||
}
|
||||
err = cmd.Run()
|
||||
if _, isExitErr := err.(*exec.ExitError); isExitErr {
|
||||
errText := errBuf.String()
|
||||
if errText == "" {
|
||||
// Shouldn't happen for a well-behaved helper program
|
||||
return fmt.Errorf("error in %s, but it produced no error message", s.executable)
|
||||
}
|
||||
return fmt.Errorf("error in %s: %s", s.executable, errText)
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to run %s: %s", s.executable, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *helperProgramCredentialsSource) ForgetForHost(host svchost.Hostname) error {
|
||||
args := make([]string, len(s.args), len(s.args)+2)
|
||||
copy(args, s.args)
|
||||
args = append(args, "forget")
|
||||
args = append(args, string(host))
|
||||
|
||||
errBuf := bytes.Buffer{}
|
||||
|
||||
cmd := exec.Cmd{
|
||||
Path: s.executable,
|
||||
Args: args,
|
||||
Stdin: nil,
|
||||
Stderr: &errBuf,
|
||||
Stdout: nil,
|
||||
}
|
||||
err := cmd.Run()
|
||||
if _, isExitErr := err.(*exec.ExitError); isExitErr {
|
||||
errText := errBuf.String()
|
||||
if errText == "" {
|
||||
// Shouldn't happen for a well-behaved helper program
|
||||
return fmt.Errorf("error in %s, but it produced no error message", s.executable)
|
||||
}
|
||||
return fmt.Errorf("error in %s: %s", s.executable, errText)
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to run %s: %s", s.executable, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,83 +0,0 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/terraform/svchost"
|
||||
)
|
||||
|
||||
func TestHelperProgramCredentialsSource(t *testing.T) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
program := filepath.Join(wd, "testdata/test-helper")
|
||||
t.Logf("testing with helper at %s", program)
|
||||
|
||||
src := HelperProgramCredentialsSource(program)
|
||||
|
||||
t.Run("happy path", func(t *testing.T) {
|
||||
creds, err := src.ForHost(svchost.Hostname("example.com"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if tokCreds, isTok := creds.(HostCredentialsToken); isTok {
|
||||
if got, want := string(tokCreds), "example-token"; got != want {
|
||||
t.Errorf("wrong token %q; want %q", got, want)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("wrong type of credentials %T", creds)
|
||||
}
|
||||
})
|
||||
t.Run("no credentials", func(t *testing.T) {
|
||||
creds, err := src.ForHost(svchost.Hostname("nothing.example.com"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if creds != nil {
|
||||
t.Errorf("got credentials; want nil")
|
||||
}
|
||||
})
|
||||
t.Run("unsupported credentials type", func(t *testing.T) {
|
||||
creds, err := src.ForHost(svchost.Hostname("other-cred-type.example.com"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if creds != nil {
|
||||
t.Errorf("got credentials; want nil")
|
||||
}
|
||||
})
|
||||
t.Run("lookup error", func(t *testing.T) {
|
||||
_, err := src.ForHost(svchost.Hostname("fail.example.com"))
|
||||
if err == nil {
|
||||
t.Error("completed successfully; want error")
|
||||
}
|
||||
})
|
||||
t.Run("store happy path", func(t *testing.T) {
|
||||
err := src.StoreForHost(svchost.Hostname("example.com"), HostCredentialsToken("example-token"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
t.Run("store error", func(t *testing.T) {
|
||||
err := src.StoreForHost(svchost.Hostname("fail.example.com"), HostCredentialsToken("example-token"))
|
||||
if err == nil {
|
||||
t.Error("completed successfully; want error")
|
||||
}
|
||||
})
|
||||
t.Run("forget happy path", func(t *testing.T) {
|
||||
err := src.ForgetForHost(svchost.Hostname("example.com"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
t.Run("forget error", func(t *testing.T) {
|
||||
err := src.ForgetForHost(svchost.Hostname("fail.example.com"))
|
||||
if err == nil {
|
||||
t.Error("completed successfully; want error")
|
||||
}
|
||||
})
|
||||
}
|
|
@ -1,38 +0,0 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/terraform/svchost"
|
||||
)
|
||||
|
||||
// StaticCredentialsSource is a credentials source that retrieves credentials
|
||||
// from the provided map. It returns nil if a requested hostname is not
|
||||
// present in the map.
|
||||
//
|
||||
// The caller should not modify the given map after passing it to this function.
|
||||
func StaticCredentialsSource(creds map[svchost.Hostname]map[string]interface{}) CredentialsSource {
|
||||
return staticCredentialsSource(creds)
|
||||
}
|
||||
|
||||
type staticCredentialsSource map[svchost.Hostname]map[string]interface{}
|
||||
|
||||
func (s staticCredentialsSource) ForHost(host svchost.Hostname) (HostCredentials, error) {
|
||||
if s == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if m, exists := s[host]; exists {
|
||||
return HostCredentialsFromMap(m), nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s staticCredentialsSource) StoreForHost(host svchost.Hostname, credentials HostCredentialsWritable) error {
|
||||
return fmt.Errorf("can't store new credentials in a static credentials source")
|
||||
}
|
||||
|
||||
func (s staticCredentialsSource) ForgetForHost(host svchost.Hostname) error {
|
||||
return fmt.Errorf("can't discard credentials from a static credentials source")
|
||||
}
|
|
@ -1,38 +0,0 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/terraform/svchost"
|
||||
)
|
||||
|
||||
func TestStaticCredentialsSource(t *testing.T) {
|
||||
src := StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{
|
||||
svchost.Hostname("example.com"): map[string]interface{}{
|
||||
"token": "abc123",
|
||||
},
|
||||
})
|
||||
|
||||
t.Run("exists", func(t *testing.T) {
|
||||
creds, err := src.ForHost(svchost.Hostname("example.com"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if tokCreds, isToken := creds.(HostCredentialsToken); isToken {
|
||||
if got, want := string(tokCreds), "abc123"; got != want {
|
||||
t.Errorf("wrong token %q; want %q", got, want)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("creds is %#v; want HostCredentialsToken", creds)
|
||||
}
|
||||
})
|
||||
t.Run("does not exist", func(t *testing.T) {
|
||||
creds, err := src.ForHost(svchost.Hostname("example.net"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if creds != nil {
|
||||
t.Errorf("creds is %#v; want nil", creds)
|
||||
}
|
||||
})
|
||||
}
|
|
@ -1 +0,0 @@
|
|||
main
|
|
@ -1,64 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
)
|
||||
|
||||
// This is a simple program that implements the "helper program" protocol
|
||||
// for the svchost/auth package for unit testing purposes.
|
||||
|
||||
func main() {
|
||||
args := os.Args
|
||||
|
||||
if len(args) < 3 {
|
||||
die("not enough arguments\n")
|
||||
}
|
||||
|
||||
host := args[2]
|
||||
switch args[1] {
|
||||
case "get":
|
||||
switch host {
|
||||
case "example.com":
|
||||
fmt.Print(`{"token":"example-token"}`)
|
||||
case "other-cred-type.example.com":
|
||||
fmt.Print(`{"username":"alfred"}`) // unrecognized by main program
|
||||
case "fail.example.com":
|
||||
die("failing because you told me to fail\n")
|
||||
default:
|
||||
fmt.Print("{}") // no credentials available
|
||||
}
|
||||
case "store":
|
||||
dataSrc, err := ioutil.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
die("invalid input: %s", err)
|
||||
}
|
||||
var data map[string]interface{}
|
||||
err = json.Unmarshal(dataSrc, &data)
|
||||
|
||||
switch host {
|
||||
case "example.com":
|
||||
if data["token"] != "example-token" {
|
||||
die("incorrect token value to store")
|
||||
}
|
||||
default:
|
||||
die("can't store credentials for %s", host)
|
||||
}
|
||||
case "forget":
|
||||
switch host {
|
||||
case "example.com":
|
||||
// okay!
|
||||
default:
|
||||
die("can't forget credentials for %s", host)
|
||||
}
|
||||
default:
|
||||
die("unknown subcommand %q\n", args[1])
|
||||
}
|
||||
}
|
||||
|
||||
func die(f string, args ...interface{}) {
|
||||
fmt.Fprintf(os.Stderr, fmt.Sprintf(f, args...))
|
||||
os.Exit(1)
|
||||
}
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set -eu
|
||||
|
||||
cd "$( dirname "${BASH_SOURCE[0]}" )"
|
||||
[ -x main ] || go build -o main .
|
||||
exec ./main "$@"
|
|
@ -1,43 +0,0 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/zclconf/go-cty/cty"
|
||||
)
|
||||
|
||||
// HostCredentialsToken is a HostCredentials implementation that represents a
|
||||
// single "bearer token", to be sent to the server via an Authorization header
|
||||
// with the auth type set to "Bearer".
|
||||
//
|
||||
// To save a token as the credentials for a host, convert the token string to
|
||||
// this type and use the result as a HostCredentialsWritable implementation.
|
||||
type HostCredentialsToken string
|
||||
|
||||
// Interface implementation assertions. Compilation will fail here if
|
||||
// HostCredentialsToken does not fully implement these interfaces.
|
||||
var _ HostCredentials = HostCredentialsToken("")
|
||||
var _ HostCredentialsWritable = HostCredentialsToken("")
|
||||
|
||||
// PrepareRequest alters the given HTTP request by setting its Authorization
|
||||
// header to the string "Bearer " followed by the encapsulated authentication
|
||||
// token.
|
||||
func (tc HostCredentialsToken) PrepareRequest(req *http.Request) {
|
||||
if req.Header == nil {
|
||||
req.Header = http.Header{}
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+string(tc))
|
||||
}
|
||||
|
||||
// Token returns the authentication token.
|
||||
func (tc HostCredentialsToken) Token() string {
|
||||
return string(tc)
|
||||
}
|
||||
|
||||
// ToStore returns a credentials object with a single attribute "token" whose
|
||||
// value is the token string.
|
||||
func (tc HostCredentialsToken) ToStore() cty.Value {
|
||||
return cty.ObjectVal(map[string]cty.Value{
|
||||
"token": cty.StringVal(string(tc)),
|
||||
})
|
||||
}
|
|
@ -1,31 +0,0 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/zclconf/go-cty/cty"
|
||||
)
|
||||
|
||||
func TestHostCredentialsToken(t *testing.T) {
|
||||
creds := HostCredentialsToken("foo-bar")
|
||||
|
||||
{
|
||||
req := &http.Request{}
|
||||
creds.PrepareRequest(req)
|
||||
authStr := req.Header.Get("authorization")
|
||||
if got, want := authStr, "Bearer foo-bar"; got != want {
|
||||
t.Errorf("wrong Authorization header value %q; want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
got := creds.ToStore()
|
||||
want := cty.ObjectVal(map[string]cty.Value{
|
||||
"token": cty.StringVal("foo-bar"),
|
||||
})
|
||||
if !want.RawEquals(got) {
|
||||
t.Errorf("wrong storable object value\ngot: %#v\nwant: %#v", got, want)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,271 +0,0 @@
|
|||
// Package disco handles Terraform's remote service discovery protocol.
|
||||
//
|
||||
// This protocol allows mapping from a service hostname, as produced by the
|
||||
// svchost package, to a set of services supported by that host and the
|
||||
// endpoint information for each supported service.
|
||||
package disco
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
cleanhttp "github.com/hashicorp/go-cleanhttp"
|
||||
"github.com/hashicorp/terraform/httpclient"
|
||||
"github.com/hashicorp/terraform/svchost"
|
||||
"github.com/hashicorp/terraform/svchost/auth"
|
||||
)
|
||||
|
||||
const (
|
||||
// Fixed path to the discovery manifest.
|
||||
discoPath = "/.well-known/terraform.json"
|
||||
|
||||
// Arbitrary-but-small number to prevent runaway redirect loops.
|
||||
maxRedirects = 3
|
||||
|
||||
// Arbitrary-but-small time limit to prevent UI "hangs" during discovery.
|
||||
discoTimeout = 11 * time.Second
|
||||
|
||||
// 1MB - to prevent abusive services from using loads of our memory.
|
||||
maxDiscoDocBytes = 1 * 1024 * 1024
|
||||
)
|
||||
|
||||
// httpTransport is overridden during tests, to skip TLS verification.
|
||||
var httpTransport = cleanhttp.DefaultPooledTransport()
|
||||
|
||||
// Disco is the main type in this package, which allows discovery on given
|
||||
// hostnames and caches the results by hostname to avoid repeated requests
|
||||
// for the same information.
|
||||
type Disco struct {
|
||||
hostCache map[svchost.Hostname]*Host
|
||||
credsSrc auth.CredentialsSource
|
||||
|
||||
// Transport is a custom http.RoundTripper to use.
|
||||
Transport http.RoundTripper
|
||||
}
|
||||
|
||||
// New returns a new initialized discovery object.
|
||||
func New() *Disco {
|
||||
return NewWithCredentialsSource(nil)
|
||||
}
|
||||
|
||||
// NewWithCredentialsSource returns a new discovery object initialized with
|
||||
// the given credentials source.
|
||||
func NewWithCredentialsSource(credsSrc auth.CredentialsSource) *Disco {
|
||||
return &Disco{
|
||||
hostCache: make(map[svchost.Hostname]*Host),
|
||||
credsSrc: credsSrc,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
}
|
||||
|
||||
// SetCredentialsSource provides a credentials source that will be used to
|
||||
// add credentials to outgoing discovery requests, where available.
|
||||
//
|
||||
// If this method is never called, no outgoing discovery requests will have
|
||||
// credentials.
|
||||
func (d *Disco) SetCredentialsSource(src auth.CredentialsSource) {
|
||||
d.credsSrc = src
|
||||
}
|
||||
|
||||
// CredentialsSource returns the credentials source associated with the receiver,
|
||||
// or an empty credentials source if none is associated.
|
||||
func (d *Disco) CredentialsSource() auth.CredentialsSource {
|
||||
if d.credsSrc == nil {
|
||||
// We'll return an empty one just to save the caller from having to
|
||||
// protect against the nil case, since this interface already allows
|
||||
// for the possibility of there being no credentials at all.
|
||||
return auth.StaticCredentialsSource(nil)
|
||||
}
|
||||
return d.credsSrc
|
||||
}
|
||||
|
||||
// CredentialsForHost returns a non-nil HostCredentials if the embedded source has
|
||||
// credentials available for the host, and a nil HostCredentials if it does not.
|
||||
func (d *Disco) CredentialsForHost(hostname svchost.Hostname) (auth.HostCredentials, error) {
|
||||
if d.credsSrc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return d.credsSrc.ForHost(hostname)
|
||||
}
|
||||
|
||||
// ForceHostServices provides a pre-defined set of services for a given
|
||||
// host, which prevents the receiver from attempting network-based discovery
|
||||
// for the given host. Instead, the given services map will be returned
|
||||
// verbatim.
|
||||
//
|
||||
// When providing "forced" services, any relative URLs are resolved against
|
||||
// the initial discovery URL that would have been used for network-based
|
||||
// discovery, yielding the same results as if the given map were published
|
||||
// at the host's default discovery URL, though using absolute URLs is strongly
|
||||
// recommended to make the configured behavior more explicit.
|
||||
func (d *Disco) ForceHostServices(hostname svchost.Hostname, services map[string]interface{}) {
|
||||
if services == nil {
|
||||
services = map[string]interface{}{}
|
||||
}
|
||||
|
||||
d.hostCache[hostname] = &Host{
|
||||
discoURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: string(hostname),
|
||||
Path: discoPath,
|
||||
},
|
||||
hostname: hostname.ForDisplay(),
|
||||
services: services,
|
||||
transport: d.Transport,
|
||||
}
|
||||
}
|
||||
|
||||
// Discover runs the discovery protocol against the given hostname (which must
|
||||
// already have been validated and prepared with svchost.ForComparison) and
|
||||
// returns an object describing the services available at that host.
|
||||
//
|
||||
// If a given hostname supports no Terraform services at all, a non-nil but
|
||||
// empty Host object is returned. When giving feedback to the end user about
|
||||
// such situations, we say "host <name> does not provide a <service> service",
|
||||
// regardless of whether that is due to that service specifically being absent
|
||||
// or due to the host not providing Terraform services at all, since we don't
|
||||
// wish to expose the detail of whole-host discovery to an end-user.
|
||||
func (d *Disco) Discover(hostname svchost.Hostname) (*Host, error) {
|
||||
if host, cached := d.hostCache[hostname]; cached {
|
||||
return host, nil
|
||||
}
|
||||
|
||||
host, err := d.discover(hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.hostCache[hostname] = host
|
||||
|
||||
return host, nil
|
||||
}
|
||||
|
||||
// DiscoverServiceURL is a convenience wrapper for discovery on a given
|
||||
// hostname and then looking up a particular service in the result.
|
||||
func (d *Disco) DiscoverServiceURL(hostname svchost.Hostname, serviceID string) (*url.URL, error) {
|
||||
host, err := d.Discover(hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return host.ServiceURL(serviceID)
|
||||
}
|
||||
|
||||
// discover implements the actual discovery process, with its result cached
|
||||
// by the public-facing Discover method.
|
||||
func (d *Disco) discover(hostname svchost.Hostname) (*Host, error) {
|
||||
discoURL := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: hostname.String(),
|
||||
Path: discoPath,
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: d.Transport,
|
||||
Timeout: discoTimeout,
|
||||
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
log.Printf("[DEBUG] Service discovery redirected to %s", req.URL)
|
||||
if len(via) > maxRedirects {
|
||||
return errors.New("too many redirects") // this error will never actually be seen
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
req := &http.Request{
|
||||
Header: make(http.Header),
|
||||
Method: "GET",
|
||||
URL: discoURL,
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", httpclient.UserAgentString())
|
||||
|
||||
creds, err := d.CredentialsForHost(hostname)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", hostname, err)
|
||||
}
|
||||
if creds != nil {
|
||||
// Update the request to include credentials.
|
||||
creds.PrepareRequest(req)
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG] Service discovery for %s at %s", hostname, discoURL)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to request discovery document: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
host := &Host{
|
||||
// Use the discovery URL from resp.Request in
|
||||
// case the client followed any redirects.
|
||||
discoURL: resp.Request.URL,
|
||||
hostname: hostname.ForDisplay(),
|
||||
transport: d.Transport,
|
||||
}
|
||||
|
||||
// Return the host without any services.
|
||||
if resp.StatusCode == 404 {
|
||||
return host, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("Failed to request discovery document: %s", resp.Status)
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Discovery URL has a malformed Content-Type %q", contentType)
|
||||
}
|
||||
if mediaType != "application/json" {
|
||||
return nil, fmt.Errorf("Discovery URL returned an unsupported Content-Type %q", mediaType)
|
||||
}
|
||||
|
||||
// This doesn't catch chunked encoding, because ContentLength is -1 in that case.
|
||||
if resp.ContentLength > maxDiscoDocBytes {
|
||||
// Size limit here is not a contractual requirement and so we may
|
||||
// adjust it over time if we find a different limit is warranted.
|
||||
return nil, fmt.Errorf(
|
||||
"Discovery doc response is too large (got %d bytes; limit %d)",
|
||||
resp.ContentLength, maxDiscoDocBytes,
|
||||
)
|
||||
}
|
||||
|
||||
// If the response is using chunked encoding then we can't predict its
|
||||
// size, but we'll at least prevent reading the entire thing into memory.
|
||||
lr := io.LimitReader(resp.Body, maxDiscoDocBytes)
|
||||
|
||||
servicesBytes, err := ioutil.ReadAll(lr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error reading discovery document body: %v", err)
|
||||
}
|
||||
|
||||
var services map[string]interface{}
|
||||
err = json.Unmarshal(servicesBytes, &services)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to decode discovery document as a JSON object: %v", err)
|
||||
}
|
||||
host.services = services
|
||||
|
||||
return host, nil
|
||||
}
|
||||
|
||||
// Forget invalidates any cached record of the given hostname. If the host
|
||||
// has no cache entry then this is a no-op.
|
||||
func (d *Disco) Forget(hostname svchost.Hostname) {
|
||||
delete(d.hostCache, hostname)
|
||||
}
|
||||
|
||||
// ForgetAll is like Forget, but for all of the hostnames that have cache entries.
|
||||
func (d *Disco) ForgetAll() {
|
||||
d.hostCache = make(map[svchost.Hostname]*Host)
|
||||
}
|
|
@ -1,357 +0,0 @@
|
|||
package disco
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/terraform/svchost"
|
||||
"github.com/hashicorp/terraform/svchost/auth"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// During all tests we override the HTTP transport we use for discovery
|
||||
// so it'll tolerate the locally-generated TLS certificates we use
|
||||
// for test URLs.
|
||||
httpTransport = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestDiscover(t *testing.T) {
|
||||
t.Run("happy path", func(t *testing.T) {
|
||||
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := []byte(`
|
||||
{
|
||||
"thingy.v1": "http://example.com/foo",
|
||||
"wotsit.v2": "http://example.net/bar"
|
||||
}
|
||||
`)
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
|
||||
w.Write(resp)
|
||||
})
|
||||
defer close()
|
||||
|
||||
givenHost := "localhost" + portStr
|
||||
host, err := svchost.ForComparison(givenHost)
|
||||
if err != nil {
|
||||
t.Fatalf("test server hostname is invalid: %s", err)
|
||||
}
|
||||
|
||||
d := New()
|
||||
discovered, err := d.Discover(host)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected discovery error: %s", err)
|
||||
}
|
||||
|
||||
gotURL, err := discovered.ServiceURL("thingy.v1")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected service URL error: %s", err)
|
||||
}
|
||||
if gotURL == nil {
|
||||
t.Fatalf("found no URL for thingy.v1")
|
||||
}
|
||||
if got, want := gotURL.String(), "http://example.com/foo"; got != want {
|
||||
t.Fatalf("wrong result %q; want %q", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("chunked encoding", func(t *testing.T) {
|
||||
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := []byte(`
|
||||
{
|
||||
"thingy.v1": "http://example.com/foo",
|
||||
"wotsit.v2": "http://example.net/bar"
|
||||
}
|
||||
`)
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
// We're going to force chunked encoding here -- and thus prevent
|
||||
// the server from predicting the length -- so we can make sure
|
||||
// our client is tolerant of servers using this encoding.
|
||||
w.Write(resp[:5])
|
||||
w.(http.Flusher).Flush()
|
||||
w.Write(resp[5:])
|
||||
w.(http.Flusher).Flush()
|
||||
})
|
||||
defer close()
|
||||
|
||||
givenHost := "localhost" + portStr
|
||||
host, err := svchost.ForComparison(givenHost)
|
||||
if err != nil {
|
||||
t.Fatalf("test server hostname is invalid: %s", err)
|
||||
}
|
||||
|
||||
d := New()
|
||||
discovered, err := d.Discover(host)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected discovery error: %s", err)
|
||||
}
|
||||
|
||||
gotURL, err := discovered.ServiceURL("wotsit.v2")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected service URL error: %s", err)
|
||||
}
|
||||
if gotURL == nil {
|
||||
t.Fatalf("found no URL for wotsit.v2")
|
||||
}
|
||||
if got, want := gotURL.String(), "http://example.net/bar"; got != want {
|
||||
t.Fatalf("wrong result %q; want %q", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("with credentials", func(t *testing.T) {
|
||||
var authHeaderText string
|
||||
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := []byte(`{}`)
|
||||
authHeaderText = r.Header.Get("Authorization")
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
|
||||
w.Write(resp)
|
||||
})
|
||||
defer close()
|
||||
|
||||
givenHost := "localhost" + portStr
|
||||
host, err := svchost.ForComparison(givenHost)
|
||||
if err != nil {
|
||||
t.Fatalf("test server hostname is invalid: %s", err)
|
||||
}
|
||||
|
||||
d := New()
|
||||
d.SetCredentialsSource(auth.StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{
|
||||
host: map[string]interface{}{
|
||||
"token": "abc123",
|
||||
},
|
||||
}))
|
||||
d.Discover(host)
|
||||
if got, want := authHeaderText, "Bearer abc123"; got != want {
|
||||
t.Fatalf("wrong Authorization header\ngot: %s\nwant: %s", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("forced services override", func(t *testing.T) {
|
||||
forced := map[string]interface{}{
|
||||
"thingy.v1": "http://example.net/foo",
|
||||
"wotsit.v2": "/foo",
|
||||
}
|
||||
|
||||
d := New()
|
||||
d.ForceHostServices(svchost.Hostname("example.com"), forced)
|
||||
|
||||
givenHost := "example.com"
|
||||
host, err := svchost.ForComparison(givenHost)
|
||||
if err != nil {
|
||||
t.Fatalf("test server hostname is invalid: %s", err)
|
||||
}
|
||||
|
||||
discovered, err := d.Discover(host)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected discovery error: %s", err)
|
||||
}
|
||||
{
|
||||
gotURL, err := discovered.ServiceURL("thingy.v1")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected service URL error: %s", err)
|
||||
}
|
||||
if gotURL == nil {
|
||||
t.Fatalf("found no URL for thingy.v1")
|
||||
}
|
||||
if got, want := gotURL.String(), "http://example.net/foo"; got != want {
|
||||
t.Fatalf("wrong result %q; want %q", got, want)
|
||||
}
|
||||
}
|
||||
{
|
||||
gotURL, err := discovered.ServiceURL("wotsit.v2")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected service URL error: %s", err)
|
||||
}
|
||||
if gotURL == nil {
|
||||
t.Fatalf("found no URL for wotsit.v2")
|
||||
}
|
||||
if got, want := gotURL.String(), "https://example.com/foo"; got != want {
|
||||
t.Fatalf("wrong result %q; want %q", got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
t.Run("not JSON", func(t *testing.T) {
|
||||
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
|
||||
w.Header().Add("Content-Type", "application/octet-stream")
|
||||
w.Write(resp)
|
||||
})
|
||||
defer close()
|
||||
|
||||
givenHost := "localhost" + portStr
|
||||
host, err := svchost.ForComparison(givenHost)
|
||||
if err != nil {
|
||||
t.Fatalf("test server hostname is invalid: %s", err)
|
||||
}
|
||||
|
||||
d := New()
|
||||
discovered, err := d.Discover(host)
|
||||
if err == nil {
|
||||
t.Fatalf("expected a discovery error")
|
||||
}
|
||||
|
||||
// Returned discovered should be nil.
|
||||
if discovered != nil {
|
||||
t.Errorf("discovered not nil; should be")
|
||||
}
|
||||
})
|
||||
t.Run("malformed JSON", func(t *testing.T) {
|
||||
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := []byte(`{"thingy.v1": "htt`) // truncated, for example...
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.Write(resp)
|
||||
})
|
||||
defer close()
|
||||
|
||||
givenHost := "localhost" + portStr
|
||||
host, err := svchost.ForComparison(givenHost)
|
||||
if err != nil {
|
||||
t.Fatalf("test server hostname is invalid: %s", err)
|
||||
}
|
||||
|
||||
d := New()
|
||||
discovered, err := d.Discover(host)
|
||||
if err == nil {
|
||||
t.Fatalf("expected a discovery error")
|
||||
}
|
||||
|
||||
// Returned discovered should be nil.
|
||||
if discovered != nil {
|
||||
t.Errorf("discovered not nil; should be")
|
||||
}
|
||||
})
|
||||
t.Run("JSON with redundant charset", func(t *testing.T) {
|
||||
// The JSON RFC defines no parameters for the application/json
|
||||
// MIME type, but some servers have a weird tendency to just add
|
||||
// "charset" to everything, so we'll make sure we ignore it successfully.
|
||||
// (JSON uses content sniffing for encoding detection, not media type params.)
|
||||
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
|
||||
w.Header().Add("Content-Type", "application/json; charset=latin-1")
|
||||
w.Write(resp)
|
||||
})
|
||||
defer close()
|
||||
|
||||
givenHost := "localhost" + portStr
|
||||
host, err := svchost.ForComparison(givenHost)
|
||||
if err != nil {
|
||||
t.Fatalf("test server hostname is invalid: %s", err)
|
||||
}
|
||||
|
||||
d := New()
|
||||
discovered, err := d.Discover(host)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected discovery error: %s", err)
|
||||
}
|
||||
|
||||
if discovered.services == nil {
|
||||
t.Errorf("response is empty; shouldn't be")
|
||||
}
|
||||
})
|
||||
t.Run("no discovery doc", func(t *testing.T) {
|
||||
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(404)
|
||||
})
|
||||
defer close()
|
||||
|
||||
givenHost := "localhost" + portStr
|
||||
host, err := svchost.ForComparison(givenHost)
|
||||
if err != nil {
|
||||
t.Fatalf("test server hostname is invalid: %s", err)
|
||||
}
|
||||
|
||||
d := New()
|
||||
discovered, err := d.Discover(host)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected discovery error: %s", err)
|
||||
}
|
||||
|
||||
// Returned discovered.services should be nil (empty).
|
||||
if discovered.services != nil {
|
||||
t.Errorf("discovered.services not nil (empty); should be")
|
||||
}
|
||||
})
|
||||
t.Run("redirect", func(t *testing.T) {
|
||||
// For this test, we have two servers and one redirects to the other
|
||||
portStr1, close1 := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
// This server is the one that returns a real response.
|
||||
resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
|
||||
w.Write(resp)
|
||||
})
|
||||
portStr2, close2 := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
// This server is the one that redirects.
|
||||
http.Redirect(w, r, "https://127.0.0.1"+portStr1+"/.well-known/terraform.json", 302)
|
||||
})
|
||||
defer close1()
|
||||
defer close2()
|
||||
|
||||
givenHost := "localhost" + portStr2
|
||||
host, err := svchost.ForComparison(givenHost)
|
||||
if err != nil {
|
||||
t.Fatalf("test server hostname is invalid: %s", err)
|
||||
}
|
||||
|
||||
d := New()
|
||||
discovered, err := d.Discover(host)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected discovery error: %s", err)
|
||||
}
|
||||
|
||||
gotURL, err := discovered.ServiceURL("thingy.v1")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected service URL error: %s", err)
|
||||
}
|
||||
if gotURL == nil {
|
||||
t.Fatalf("found no URL for thingy.v1")
|
||||
}
|
||||
if got, want := gotURL.String(), "http://example.com/foo"; got != want {
|
||||
t.Fatalf("wrong result %q; want %q", got, want)
|
||||
}
|
||||
|
||||
// The base URL for the host object should be the URL we redirected to,
|
||||
// rather than the we redirected _from_.
|
||||
gotBaseURL := discovered.discoURL.String()
|
||||
wantBaseURL := "https://127.0.0.1" + portStr1 + "/.well-known/terraform.json"
|
||||
if gotBaseURL != wantBaseURL {
|
||||
t.Errorf("incorrect base url %s; want %s", gotBaseURL, wantBaseURL)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func testServer(h func(w http.ResponseWriter, r *http.Request)) (portStr string, close func()) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
// Test server always returns 404 if the URL isn't what we expect
|
||||
if r.URL.Path != "/.well-known/terraform.json" {
|
||||
w.WriteHeader(404)
|
||||
w.Write([]byte("not found"))
|
||||
return
|
||||
}
|
||||
|
||||
// If the URL is correct then the given hander decides the response
|
||||
h(w, r)
|
||||
},
|
||||
))
|
||||
|
||||
serverURL, _ := url.Parse(server.URL)
|
||||
|
||||
portStr = serverURL.Port()
|
||||
if portStr != "" {
|
||||
portStr = ":" + portStr
|
||||
}
|
||||
|
||||
close = func() {
|
||||
server.Close()
|
||||
}
|
||||
|
||||
return portStr, close
|
||||
}
|
|
@ -1,414 +0,0 @@
|
|||
package disco
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/hashicorp/terraform/httpclient"
|
||||
)
|
||||
|
||||
const versionServiceID = "versions.v1"
|
||||
|
||||
// Host represents a service discovered host.
|
||||
type Host struct {
|
||||
discoURL *url.URL
|
||||
hostname string
|
||||
services map[string]interface{}
|
||||
transport http.RoundTripper
|
||||
}
|
||||
|
||||
// Constraints represents the version constraints of a service.
|
||||
type Constraints struct {
|
||||
Service string `json:"service"`
|
||||
Product string `json:"product"`
|
||||
Minimum string `json:"minimum"`
|
||||
Maximum string `json:"maximum"`
|
||||
Excluding []string `json:"excluding"`
|
||||
}
|
||||
|
||||
// ErrServiceNotProvided is returned when the service is not provided.
|
||||
type ErrServiceNotProvided struct {
|
||||
hostname string
|
||||
service string
|
||||
}
|
||||
|
||||
// Error returns a customized error message.
|
||||
func (e *ErrServiceNotProvided) Error() string {
|
||||
if e.hostname == "" {
|
||||
return fmt.Sprintf("host does not provide a %s service", e.service)
|
||||
}
|
||||
return fmt.Sprintf("host %s does not provide a %s service", e.hostname, e.service)
|
||||
}
|
||||
|
||||
// ErrVersionNotSupported is returned when the version is not supported.
|
||||
type ErrVersionNotSupported struct {
|
||||
hostname string
|
||||
service string
|
||||
version string
|
||||
}
|
||||
|
||||
// Error returns a customized error message.
|
||||
func (e *ErrVersionNotSupported) Error() string {
|
||||
if e.hostname == "" {
|
||||
return fmt.Sprintf("host does not support %s version %s", e.service, e.version)
|
||||
}
|
||||
return fmt.Sprintf("host %s does not support %s version %s", e.hostname, e.service, e.version)
|
||||
}
|
||||
|
||||
// ErrNoVersionConstraints is returned when checkpoint was disabled
|
||||
// or the endpoint to query for version constraints was unavailable.
|
||||
type ErrNoVersionConstraints struct {
|
||||
disabled bool
|
||||
}
|
||||
|
||||
// Error returns a customized error message.
|
||||
func (e *ErrNoVersionConstraints) Error() string {
|
||||
if e.disabled {
|
||||
return "checkpoint disabled"
|
||||
}
|
||||
return "unable to contact versions service"
|
||||
}
|
||||
|
||||
// ServiceURL returns the URL associated with the given service identifier,
|
||||
// which should be of the form "servicename.vN".
|
||||
//
|
||||
// A non-nil result is always an absolute URL with a scheme of either HTTPS
|
||||
// or HTTP.
|
||||
func (h *Host) ServiceURL(id string) (*url.URL, error) {
|
||||
svc, ver, err := parseServiceID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// No services supported for an empty Host.
|
||||
if h == nil || h.services == nil {
|
||||
return nil, &ErrServiceNotProvided{service: svc}
|
||||
}
|
||||
|
||||
urlStr, ok := h.services[id].(string)
|
||||
if !ok {
|
||||
// See if we have a matching service as that would indicate
|
||||
// the service is supported, but not the requested version.
|
||||
for serviceID := range h.services {
|
||||
if strings.HasPrefix(serviceID, svc+".") {
|
||||
return nil, &ErrVersionNotSupported{
|
||||
hostname: h.hostname,
|
||||
service: svc,
|
||||
version: ver.Original(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No discovered services match the requested service.
|
||||
return nil, &ErrServiceNotProvided{hostname: h.hostname, service: svc}
|
||||
}
|
||||
|
||||
u, err := h.parseURL(urlStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to parse service URL: %v", err)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// ServiceOAuthClient returns the OAuth client configuration associated with the
|
||||
// given service identifier, which should be of the form "servicename.vN".
|
||||
//
|
||||
// This is an alternative to ServiceURL for unusual services that require
|
||||
// a full OAuth2 client definition rather than just a URL. Use this only
|
||||
// for services whose specification calls for this sort of definition.
|
||||
func (h *Host) ServiceOAuthClient(id string) (*OAuthClient, error) {
|
||||
svc, ver, err := parseServiceID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// No services supported for an empty Host.
|
||||
if h == nil || h.services == nil {
|
||||
return nil, &ErrServiceNotProvided{service: svc}
|
||||
}
|
||||
|
||||
if _, ok := h.services[id]; !ok {
|
||||
// See if we have a matching service as that would indicate
|
||||
// the service is supported, but not the requested version.
|
||||
for serviceID := range h.services {
|
||||
if strings.HasPrefix(serviceID, svc+".") {
|
||||
return nil, &ErrVersionNotSupported{
|
||||
hostname: h.hostname,
|
||||
service: svc,
|
||||
version: ver.Original(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No discovered services match the requested service.
|
||||
return nil, &ErrServiceNotProvided{hostname: h.hostname, service: svc}
|
||||
}
|
||||
|
||||
var raw map[string]interface{}
|
||||
switch v := h.services[id].(type) {
|
||||
case map[string]interface{}:
|
||||
raw = v // Great!
|
||||
case []map[string]interface{}:
|
||||
// An absolutely infuriating legacy HCL ambiguity.
|
||||
raw = v[0]
|
||||
default:
|
||||
// Debug message because raw Go types don't belong in our UI.
|
||||
log.Printf("[DEBUG] The definition for %s has Go type %T", id, h.services[id])
|
||||
return nil, fmt.Errorf("Service %s must be declared with an object value in the service discovery document", id)
|
||||
}
|
||||
|
||||
var grantTypes OAuthGrantTypeSet
|
||||
if rawGTs, ok := raw["grant_types"]; ok {
|
||||
if gts, ok := rawGTs.([]interface{}); ok {
|
||||
var kws []string
|
||||
for _, gtI := range gts {
|
||||
gt, ok := gtI.(string)
|
||||
if !ok {
|
||||
// We'll ignore this so that we can potentially introduce
|
||||
// other types into this array later if we need to.
|
||||
continue
|
||||
}
|
||||
kws = append(kws, gt)
|
||||
}
|
||||
grantTypes = NewOAuthGrantTypeSet(kws...)
|
||||
} else {
|
||||
return nil, fmt.Errorf("Service %s is defined with invalid grant_types property: must be an array of grant type strings", id)
|
||||
}
|
||||
} else {
|
||||
grantTypes = NewOAuthGrantTypeSet("authz_code")
|
||||
}
|
||||
|
||||
ret := &OAuthClient{
|
||||
SupportedGrantTypes: grantTypes,
|
||||
}
|
||||
if clientIDStr, ok := raw["client"].(string); ok {
|
||||
ret.ID = clientIDStr
|
||||
} else {
|
||||
return nil, fmt.Errorf("Service %s definition is missing required property \"client\"", id)
|
||||
}
|
||||
if urlStr, ok := raw["authz"].(string); ok {
|
||||
u, err := h.parseURL(urlStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to parse authorization URL: %v", err)
|
||||
}
|
||||
ret.AuthorizationURL = u
|
||||
} else {
|
||||
if grantTypes.RequiresAuthorizationEndpoint() {
|
||||
return nil, fmt.Errorf("Service %s definition is missing required property \"authz\"", id)
|
||||
}
|
||||
}
|
||||
if urlStr, ok := raw["token"].(string); ok {
|
||||
u, err := h.parseURL(urlStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to parse token URL: %v", err)
|
||||
}
|
||||
ret.TokenURL = u
|
||||
} else {
|
||||
if grantTypes.RequiresTokenEndpoint() {
|
||||
return nil, fmt.Errorf("Service %s definition is missing required property \"token\"", id)
|
||||
}
|
||||
}
|
||||
if portsRaw, ok := raw["ports"].([]interface{}); ok {
|
||||
if len(portsRaw) != 2 {
|
||||
return nil, fmt.Errorf("Invalid \"ports\" definition for service %s: must be a two-element array", id)
|
||||
}
|
||||
invalidPortsErr := fmt.Errorf("Invalid \"ports\" definition for service %s: both ports must be whole numbers between 1024 and 65535", id)
|
||||
ports := make([]uint16, 2)
|
||||
for i := range ports {
|
||||
switch v := portsRaw[i].(type) {
|
||||
case float64:
|
||||
// JSON unmarshaling always produces float64. HCL 2 might, if
|
||||
// an invalid fractional number were given.
|
||||
if float64(uint16(v)) != v || v < 1024 {
|
||||
return nil, invalidPortsErr
|
||||
}
|
||||
ports[i] = uint16(v)
|
||||
case int:
|
||||
// Legacy HCL produces int. HCL 2 will too, if the given number
|
||||
// is a whole number.
|
||||
if v < 1024 || v > 65535 {
|
||||
return nil, invalidPortsErr
|
||||
}
|
||||
ports[i] = uint16(v)
|
||||
default:
|
||||
// Debug message because raw Go types don't belong in our UI.
|
||||
log.Printf("[DEBUG] Port value %d has Go type %T", i, portsRaw[i])
|
||||
return nil, invalidPortsErr
|
||||
}
|
||||
}
|
||||
if ports[1] < ports[0] {
|
||||
return nil, fmt.Errorf("Invalid \"ports\" definition for service %s: minimum port cannot be greater than maximum port", id)
|
||||
}
|
||||
ret.MinPort = ports[0]
|
||||
ret.MaxPort = ports[1]
|
||||
} else {
|
||||
// Default is to accept any port in the range, for a client that is
|
||||
// able to call back to any localhost port.
|
||||
ret.MinPort = 1024
|
||||
ret.MaxPort = 65535
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (h *Host) parseURL(urlStr string) (*url.URL, error) {
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Make relative URLs absolute using our discovery URL.
|
||||
if !u.IsAbs() {
|
||||
u = h.discoURL.ResolveReference(u)
|
||||
}
|
||||
|
||||
if u.Scheme != "https" && u.Scheme != "http" {
|
||||
return nil, fmt.Errorf("unsupported scheme %s", u.Scheme)
|
||||
}
|
||||
if u.User != nil {
|
||||
return nil, fmt.Errorf("embedded username/password information is not permitted")
|
||||
}
|
||||
|
||||
// Fragment part is irrelevant, since we're not a browser.
|
||||
u.Fragment = ""
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// VersionConstraints returns the contraints for a given service identifier
|
||||
// (which should be of the form "servicename.vN") and product.
|
||||
//
|
||||
// When an exact (service and version) match is found, the constraints for
|
||||
// that service are returned.
|
||||
//
|
||||
// When the requested version is not provided but the service is, we will
|
||||
// search for all alternative versions. If mutliple alternative versions
|
||||
// are found, the contrains of the latest available version are returned.
|
||||
//
|
||||
// When a service is not provided at all an error will be returned instead.
|
||||
//
|
||||
// When checkpoint is disabled or when a 404 is returned after making the
|
||||
// HTTP call, an ErrNoVersionConstraints error will be returned.
|
||||
func (h *Host) VersionConstraints(id, product string) (*Constraints, error) {
|
||||
svc, _, err := parseServiceID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return early if checkpoint is disabled.
|
||||
if disabled := os.Getenv("CHECKPOINT_DISABLE"); disabled != "" {
|
||||
return nil, &ErrNoVersionConstraints{disabled: true}
|
||||
}
|
||||
|
||||
// No services supported for an empty Host.
|
||||
if h == nil || h.services == nil {
|
||||
return nil, &ErrServiceNotProvided{service: svc}
|
||||
}
|
||||
|
||||
// Try to get the service URL for the version service and
|
||||
// return early if the service isn't provided by the host.
|
||||
u, err := h.ServiceURL(versionServiceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if we have an exact (service and version) match.
|
||||
if _, ok := h.services[id].(string); !ok {
|
||||
// If we don't have an exact match, we search for all matching
|
||||
// services and then use the service ID of the latest version.
|
||||
var services []string
|
||||
for serviceID := range h.services {
|
||||
if strings.HasPrefix(serviceID, svc+".") {
|
||||
services = append(services, serviceID)
|
||||
}
|
||||
}
|
||||
|
||||
if len(services) == 0 {
|
||||
// No discovered services match the requested service.
|
||||
return nil, &ErrServiceNotProvided{hostname: h.hostname, service: svc}
|
||||
}
|
||||
|
||||
// Set id to the latest service ID we found.
|
||||
var latest *version.Version
|
||||
for _, serviceID := range services {
|
||||
if _, ver, err := parseServiceID(serviceID); err == nil {
|
||||
if latest == nil || latest.LessThan(ver) {
|
||||
id = serviceID
|
||||
latest = ver
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set a default timeout of 1 sec for the versions request (in milliseconds)
|
||||
timeout := 1000
|
||||
if v, err := strconv.Atoi(os.Getenv("CHECKPOINT_TIMEOUT")); err == nil {
|
||||
timeout = v
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: h.transport,
|
||||
Timeout: time.Duration(timeout) * time.Millisecond,
|
||||
}
|
||||
|
||||
// Prepare the service URL by setting the service and product.
|
||||
v := u.Query()
|
||||
v.Set("product", product)
|
||||
u.Path += id
|
||||
u.RawQuery = v.Encode()
|
||||
|
||||
// Create a new request.
|
||||
req, err := http.NewRequest("GET", u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to create version constraints request: %v", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", httpclient.UserAgentString())
|
||||
|
||||
log.Printf("[DEBUG] Retrieve version constraints for service %s and product %s", id, product)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to request version constraints: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == 404 {
|
||||
return nil, &ErrNoVersionConstraints{disabled: false}
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("Failed to request version constraints: %s", resp.Status)
|
||||
}
|
||||
|
||||
// Parse the constraints from the response body.
|
||||
result := &Constraints{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(result); err != nil {
|
||||
return nil, fmt.Errorf("Error parsing version constraints: %v", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func parseServiceID(id string) (string, *version.Version, error) {
|
||||
parts := strings.SplitN(id, ".", 2)
|
||||
if len(parts) != 2 {
|
||||
return "", nil, fmt.Errorf("Invalid service ID format (i.e. service.vN): %s", id)
|
||||
}
|
||||
|
||||
version, err := version.NewVersion(parts[1])
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("Invalid service version: %v", err)
|
||||
}
|
||||
|
||||
return parts[0], version, nil
|
||||
}
|
|
@ -1,528 +0,0 @@
|
|||
package disco
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestHostServiceURL(t *testing.T) {
|
||||
baseURL, _ := url.Parse("https://example.com/disco/foo.json")
|
||||
host := Host{
|
||||
discoURL: baseURL,
|
||||
hostname: "test-server",
|
||||
services: map[string]interface{}{
|
||||
"absolute.v1": "http://example.net/foo/bar",
|
||||
"absolutewithport.v1": "http://example.net:8080/foo/bar",
|
||||
"relative.v1": "./stu/",
|
||||
"rootrelative.v1": "/baz",
|
||||
"protorelative.v1": "//example.net/",
|
||||
"withfragment.v1": "http://example.org/#foo",
|
||||
"querystring.v1": "https://example.net/baz?foo=bar",
|
||||
"nothttp.v1": "ftp://127.0.0.1/pub/",
|
||||
"invalid.v1": "***not A URL at all!:/<@@@@>***",
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
ID string
|
||||
want string
|
||||
err string
|
||||
}{
|
||||
{"absolute.v1", "http://example.net/foo/bar", ""},
|
||||
{"absolutewithport.v1", "http://example.net:8080/foo/bar", ""},
|
||||
{"relative.v1", "https://example.com/disco/stu/", ""},
|
||||
{"rootrelative.v1", "https://example.com/baz", ""},
|
||||
{"protorelative.v1", "https://example.net/", ""},
|
||||
{"withfragment.v1", "http://example.org/", ""},
|
||||
{"querystring.v1", "https://example.net/baz?foo=bar", ""},
|
||||
{"nothttp.v1", "<nil>", "unsupported scheme"},
|
||||
{"invalid.v1", "<nil>", "Failed to parse service URL"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.ID, func(t *testing.T) {
|
||||
url, err := host.ServiceURL(test.ID)
|
||||
if (err != nil || test.err != "") &&
|
||||
(err == nil || !strings.Contains(err.Error(), test.err)) {
|
||||
t.Fatalf("unexpected service URL error: %s", err)
|
||||
}
|
||||
|
||||
var got string
|
||||
if url != nil {
|
||||
got = url.String()
|
||||
} else {
|
||||
got = "<nil>"
|
||||
}
|
||||
|
||||
if got != test.want {
|
||||
t.Errorf("wrong result\ngot: %s\nwant: %s", got, test.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostServiceOAuthClient(t *testing.T) {
|
||||
baseURL, _ := url.Parse("https://example.com/disco/foo.json")
|
||||
host := Host{
|
||||
discoURL: baseURL,
|
||||
hostname: "test-server",
|
||||
services: map[string]interface{}{
|
||||
"explicitgranttype.v1": map[string]interface{}{
|
||||
"client": "explicitgranttype",
|
||||
"authz": "./authz",
|
||||
"token": "./token",
|
||||
"grant_types": []interface{}{"authz_code", "password", "tbd"},
|
||||
},
|
||||
"customports.v1": map[string]interface{}{
|
||||
"client": "customports",
|
||||
"authz": "./authz",
|
||||
"token": "./token",
|
||||
"ports": []interface{}{1025, 1026},
|
||||
},
|
||||
"invalidports.v1": map[string]interface{}{
|
||||
"client": "invalidports",
|
||||
"authz": "./authz",
|
||||
"token": "./token",
|
||||
"ports": []interface{}{1, 65535},
|
||||
},
|
||||
"missingauthz.v1": map[string]interface{}{
|
||||
"client": "missingauthz",
|
||||
"token": "./token",
|
||||
},
|
||||
"missingtoken.v1": map[string]interface{}{
|
||||
"client": "missingtoken",
|
||||
"authz": "./authz",
|
||||
},
|
||||
"passwordmissingauthz.v1": map[string]interface{}{
|
||||
"client": "passwordmissingauthz",
|
||||
"token": "./token",
|
||||
"grant_types": []interface{}{"password"},
|
||||
},
|
||||
"absolute.v1": map[string]interface{}{
|
||||
"client": "absolute",
|
||||
"authz": "http://example.net/foo/authz",
|
||||
"token": "http://example.net/foo/token",
|
||||
},
|
||||
"absolutewithport.v1": map[string]interface{}{
|
||||
"client": "absolutewithport",
|
||||
"authz": "http://example.net:8000/foo/authz",
|
||||
"token": "http://example.net:8000/foo/token",
|
||||
},
|
||||
"relative.v1": map[string]interface{}{
|
||||
"client": "relative",
|
||||
"authz": "./authz",
|
||||
"token": "./token",
|
||||
},
|
||||
"rootrelative.v1": map[string]interface{}{
|
||||
"client": "rootrelative",
|
||||
"authz": "/authz",
|
||||
"token": "/token",
|
||||
},
|
||||
"protorelative.v1": map[string]interface{}{
|
||||
"client": "protorelative",
|
||||
"authz": "//example.net/authz",
|
||||
"token": "//example.net/token",
|
||||
},
|
||||
"nothttp.v1": map[string]interface{}{
|
||||
"client": "nothttp",
|
||||
"authz": "ftp://127.0.0.1/pub/authz",
|
||||
"token": "ftp://127.0.0.1/pub/token",
|
||||
},
|
||||
"invalidauthz.v1": map[string]interface{}{
|
||||
"client": "invalidauthz",
|
||||
"authz": "***not A URL at all!:/<@@@@>***",
|
||||
"token": "/foo",
|
||||
},
|
||||
"invalidtoken.v1": map[string]interface{}{
|
||||
"client": "invalidauthz",
|
||||
"authz": "/foo",
|
||||
"token": "***not A URL at all!:/<@@@@>***",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mustURL := func(t *testing.T, s string) *url.URL {
|
||||
t.Helper()
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
t.Fatalf("invalid wanted URL %s in test case: %s", s, err)
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
ID string
|
||||
want *OAuthClient
|
||||
err string
|
||||
}{
|
||||
{
|
||||
"explicitgranttype.v1",
|
||||
&OAuthClient{
|
||||
ID: "explicitgranttype",
|
||||
AuthorizationURL: mustURL(t, "https://example.com/disco/authz"),
|
||||
TokenURL: mustURL(t, "https://example.com/disco/token"),
|
||||
MinPort: 1024,
|
||||
MaxPort: 65535,
|
||||
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code", "password", "tbd"),
|
||||
},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"customports.v1",
|
||||
&OAuthClient{
|
||||
ID: "customports",
|
||||
AuthorizationURL: mustURL(t, "https://example.com/disco/authz"),
|
||||
TokenURL: mustURL(t, "https://example.com/disco/token"),
|
||||
MinPort: 1025,
|
||||
MaxPort: 1026,
|
||||
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"),
|
||||
},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"invalidports.v1",
|
||||
nil,
|
||||
`Invalid "ports" definition for service invalidports.v1: both ports must be whole numbers between 1024 and 65535`,
|
||||
},
|
||||
{
|
||||
"missingauthz.v1",
|
||||
nil,
|
||||
`Service missingauthz.v1 definition is missing required property "authz"`,
|
||||
},
|
||||
{
|
||||
"missingtoken.v1",
|
||||
nil,
|
||||
`Service missingtoken.v1 definition is missing required property "token"`,
|
||||
},
|
||||
{
|
||||
"passwordmissingauthz.v1",
|
||||
&OAuthClient{
|
||||
ID: "passwordmissingauthz",
|
||||
TokenURL: mustURL(t, "https://example.com/disco/token"),
|
||||
MinPort: 1024,
|
||||
MaxPort: 65535,
|
||||
SupportedGrantTypes: NewOAuthGrantTypeSet("password"),
|
||||
},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"absolute.v1",
|
||||
&OAuthClient{
|
||||
ID: "absolute",
|
||||
AuthorizationURL: mustURL(t, "http://example.net/foo/authz"),
|
||||
TokenURL: mustURL(t, "http://example.net/foo/token"),
|
||||
MinPort: 1024,
|
||||
MaxPort: 65535,
|
||||
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"),
|
||||
},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"absolutewithport.v1",
|
||||
&OAuthClient{
|
||||
ID: "absolutewithport",
|
||||
AuthorizationURL: mustURL(t, "http://example.net:8000/foo/authz"),
|
||||
TokenURL: mustURL(t, "http://example.net:8000/foo/token"),
|
||||
MinPort: 1024,
|
||||
MaxPort: 65535,
|
||||
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"),
|
||||
},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"relative.v1",
|
||||
&OAuthClient{
|
||||
ID: "relative",
|
||||
AuthorizationURL: mustURL(t, "https://example.com/disco/authz"),
|
||||
TokenURL: mustURL(t, "https://example.com/disco/token"),
|
||||
MinPort: 1024,
|
||||
MaxPort: 65535,
|
||||
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"),
|
||||
},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"rootrelative.v1",
|
||||
&OAuthClient{
|
||||
ID: "rootrelative",
|
||||
AuthorizationURL: mustURL(t, "https://example.com/authz"),
|
||||
TokenURL: mustURL(t, "https://example.com/token"),
|
||||
MinPort: 1024,
|
||||
MaxPort: 65535,
|
||||
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"),
|
||||
},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"protorelative.v1",
|
||||
&OAuthClient{
|
||||
ID: "protorelative",
|
||||
AuthorizationURL: mustURL(t, "https://example.net/authz"),
|
||||
TokenURL: mustURL(t, "https://example.net/token"),
|
||||
MinPort: 1024,
|
||||
MaxPort: 65535,
|
||||
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"),
|
||||
},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"nothttp.v1",
|
||||
nil,
|
||||
"Failed to parse authorization URL: unsupported scheme ftp",
|
||||
},
|
||||
{
|
||||
"invalidauthz.v1",
|
||||
nil,
|
||||
"Failed to parse authorization URL: parse ***not A URL at all!:/<@@@@>***: first path segment in URL cannot contain colon",
|
||||
},
|
||||
{
|
||||
"invalidtoken.v1",
|
||||
nil,
|
||||
"Failed to parse token URL: parse ***not A URL at all!:/<@@@@>***: first path segment in URL cannot contain colon",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.ID, func(t *testing.T) {
|
||||
got, err := host.ServiceOAuthClient(test.ID)
|
||||
if (err != nil || test.err != "") &&
|
||||
(err == nil || !strings.Contains(err.Error(), test.err)) {
|
||||
t.Fatalf("unexpected service URL error: %s", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(test.want, got); diff != "" {
|
||||
t.Errorf("wrong result\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersionConstrains(t *testing.T) {
|
||||
baseURL, _ := url.Parse("https://example.com/disco/foo.json")
|
||||
|
||||
t.Run("exact service version is provided", func(t *testing.T) {
|
||||
portStr, close := testVersionsServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := []byte(`
|
||||
{
|
||||
"service": "%s",
|
||||
"product": "%s",
|
||||
"minimum": "0.11.8",
|
||||
"maximum": "0.12.0"
|
||||
}`)
|
||||
// Add the requested service and product to the response.
|
||||
service := path.Base(r.URL.Path)
|
||||
product := r.URL.Query().Get("product")
|
||||
resp = []byte(fmt.Sprintf(string(resp), service, product))
|
||||
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
|
||||
w.Write(resp)
|
||||
})
|
||||
defer close()
|
||||
|
||||
host := Host{
|
||||
discoURL: baseURL,
|
||||
hostname: "test-server",
|
||||
transport: httpTransport,
|
||||
services: map[string]interface{}{
|
||||
"thingy.v1": "/api/v1/",
|
||||
"thingy.v2": "/api/v2/",
|
||||
"versions.v1": "https://localhost" + portStr + "/v1/versions/",
|
||||
},
|
||||
}
|
||||
|
||||
expected := &Constraints{
|
||||
Service: "thingy.v1",
|
||||
Product: "terraform",
|
||||
Minimum: "0.11.8",
|
||||
Maximum: "0.12.0",
|
||||
}
|
||||
|
||||
actual, err := host.VersionConstraints("thingy.v1", "terraform")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected version constraints error: %s", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Fatalf("expected %#v, got: %#v", expected, actual)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("service provided with different versions", func(t *testing.T) {
|
||||
portStr, close := testVersionsServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := []byte(`
|
||||
{
|
||||
"service": "%s",
|
||||
"product": "%s",
|
||||
"minimum": "0.11.8",
|
||||
"maximum": "0.12.0"
|
||||
}`)
|
||||
// Add the requested service and product to the response.
|
||||
service := path.Base(r.URL.Path)
|
||||
product := r.URL.Query().Get("product")
|
||||
resp = []byte(fmt.Sprintf(string(resp), service, product))
|
||||
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
|
||||
w.Write(resp)
|
||||
})
|
||||
defer close()
|
||||
|
||||
host := Host{
|
||||
discoURL: baseURL,
|
||||
hostname: "test-server",
|
||||
transport: httpTransport,
|
||||
services: map[string]interface{}{
|
||||
"thingy.v2": "/api/v2/",
|
||||
"thingy.v3": "/api/v3/",
|
||||
"versions.v1": "https://localhost" + portStr + "/v1/versions/",
|
||||
},
|
||||
}
|
||||
|
||||
expected := &Constraints{
|
||||
Service: "thingy.v3",
|
||||
Product: "terraform",
|
||||
Minimum: "0.11.8",
|
||||
Maximum: "0.12.0",
|
||||
}
|
||||
|
||||
actual, err := host.VersionConstraints("thingy.v1", "terraform")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected version constraints error: %s", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Fatalf("expected %#v, got: %#v", expected, actual)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("service not provided", func(t *testing.T) {
|
||||
host := Host{
|
||||
discoURL: baseURL,
|
||||
hostname: "test-server",
|
||||
transport: httpTransport,
|
||||
services: map[string]interface{}{
|
||||
"versions.v1": "https://localhost/v1/versions/",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := host.VersionConstraints("thingy.v1", "terraform")
|
||||
if _, ok := err.(*ErrServiceNotProvided); !ok {
|
||||
t.Fatalf("expected service not provided error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("versions service returns a 404", func(t *testing.T) {
|
||||
portStr, close := testVersionsServer(nil)
|
||||
defer close()
|
||||
|
||||
host := Host{
|
||||
discoURL: baseURL,
|
||||
hostname: "test-server",
|
||||
transport: httpTransport,
|
||||
services: map[string]interface{}{
|
||||
"thingy.v1": "/api/v1/",
|
||||
"versions.v1": "https://localhost" + portStr + "/v1/non-existent/",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := host.VersionConstraints("thingy.v1", "terraform")
|
||||
if _, ok := err.(*ErrNoVersionConstraints); !ok {
|
||||
t.Fatalf("expected service not provided error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("checkpoint is disabled", func(t *testing.T) {
|
||||
if err := os.Setenv("CHECKPOINT_DISABLE", "1"); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer os.Unsetenv("CHECKPOINT_DISABLE")
|
||||
|
||||
host := Host{
|
||||
discoURL: baseURL,
|
||||
hostname: "test-server",
|
||||
transport: httpTransport,
|
||||
services: map[string]interface{}{
|
||||
"thingy.v1": "/api/v1/",
|
||||
"versions.v1": "https://localhost/v1/versions/",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := host.VersionConstraints("thingy.v1", "terraform")
|
||||
if _, ok := err.(*ErrNoVersionConstraints); !ok {
|
||||
t.Fatalf("expected service not provided error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("versions service not discovered", func(t *testing.T) {
|
||||
host := Host{
|
||||
discoURL: baseURL,
|
||||
hostname: "test-server",
|
||||
transport: httpTransport,
|
||||
services: map[string]interface{}{
|
||||
"thingy.v1": "/api/v1/",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := host.VersionConstraints("thingy.v1", "terraform")
|
||||
if _, ok := err.(*ErrServiceNotProvided); !ok {
|
||||
t.Fatalf("expected service not provided error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("versions service version not discovered", func(t *testing.T) {
|
||||
host := Host{
|
||||
discoURL: baseURL,
|
||||
hostname: "test-server",
|
||||
transport: httpTransport,
|
||||
services: map[string]interface{}{
|
||||
"thingy.v1": "/api/v1/",
|
||||
"versions.v2": "https://localhost/v2/versions/",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := host.VersionConstraints("thingy.v1", "terraform")
|
||||
if _, ok := err.(*ErrVersionNotSupported); !ok {
|
||||
t.Fatalf("expected service not provided error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func testVersionsServer(h func(w http.ResponseWriter, r *http.Request)) (portStr string, close func()) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
// Test server always returns 404 if the URL isn't what we expect
|
||||
if !strings.HasPrefix(r.URL.Path, "/v1/versions/") {
|
||||
w.WriteHeader(404)
|
||||
w.Write([]byte("not found"))
|
||||
return
|
||||
}
|
||||
|
||||
// If the URL is correct then the given hander decides the response
|
||||
h(w, r)
|
||||
},
|
||||
))
|
||||
|
||||
serverURL, _ := url.Parse(server.URL)
|
||||
|
||||
portStr = serverURL.Port()
|
||||
if portStr != "" {
|
||||
portStr = ":" + portStr
|
||||
}
|
||||
|
||||
close = func() {
|
||||
server.Close()
|
||||
}
|
||||
|
||||
return portStr, close
|
||||
}
|
|
@ -1,178 +0,0 @@
|
|||
package disco
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// OAuthClient represents an OAuth client configuration, which is used for
|
||||
// unusual services that require an entire OAuth client configuration as part
|
||||
// of their service discovery, rather than just a URL.
|
||||
type OAuthClient struct {
|
||||
// ID is the identifier for the client, to be used as "client_id" in
|
||||
// OAuth requests.
|
||||
ID string
|
||||
|
||||
// Authorization URL is the URL of the authorization endpoint that must
|
||||
// be used for this OAuth client, as defined in the OAuth2 specifications.
|
||||
//
|
||||
// Not all grant types use the authorization endpoint, so it may be omitted
|
||||
// if none of the grant types in SupportedGrantTypes require it.
|
||||
AuthorizationURL *url.URL
|
||||
|
||||
// Token URL is the URL of the token endpoint that must be used for this
|
||||
// OAuth client, as defined in the OAuth2 specifications.
|
||||
//
|
||||
// Not all grant types use the token endpoint, so it may be omitted
|
||||
// if none of the grant types in SupportedGrantTypes require it.
|
||||
TokenURL *url.URL
|
||||
|
||||
// MinPort and MaxPort define a range of TCP ports on localhost that this
|
||||
// client is able to use as redirect_uri in an authorization request.
|
||||
// Terraform will select a port from this range for the temporary HTTP
|
||||
// server it creates to receive the authorization response, giving
|
||||
// a URL like http://localhost:NNN/ where NNN is the selected port number.
|
||||
//
|
||||
// Terraform will reject any port numbers in this range less than 1024,
|
||||
// to respect the common convention (enforced on some operating systems)
|
||||
// that lower port numbers are reserved for "privileged" services.
|
||||
MinPort, MaxPort uint16
|
||||
|
||||
// SupportedGrantTypes is a set of the grant types that the client may
|
||||
// choose from. This includes an entry for each distinct type advertised
|
||||
// by the server, even if a particular keyword is not supported by the
|
||||
// current version of Terraform.
|
||||
SupportedGrantTypes OAuthGrantTypeSet
|
||||
}
|
||||
|
||||
// Endpoint returns an oauth2.Endpoint value ready to be used with the oauth2
|
||||
// library, representing the URLs from the receiver.
|
||||
func (c *OAuthClient) Endpoint() oauth2.Endpoint {
|
||||
ep := oauth2.Endpoint{
|
||||
// We don't actually auth because we're not a server-based OAuth client,
|
||||
// so this instead just means that we include client_id as an argument
|
||||
// in our requests.
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
}
|
||||
|
||||
if c.AuthorizationURL != nil {
|
||||
ep.AuthURL = c.AuthorizationURL.String()
|
||||
}
|
||||
if c.TokenURL != nil {
|
||||
ep.TokenURL = c.TokenURL.String()
|
||||
}
|
||||
|
||||
return ep
|
||||
}
|
||||
|
||||
// OAuthGrantType is an enumeration of grant type strings that a host can
|
||||
// advertise support for.
|
||||
//
|
||||
// Values of this type don't necessarily match with a known constant of the
|
||||
// type, because they may represent grant type keywords defined in a later
|
||||
// version of Terraform which this version doesn't yet know about.
|
||||
type OAuthGrantType string
|
||||
|
||||
const (
|
||||
// OAuthAuthzCodeGrant represents an authorization code grant, as
|
||||
// defined in IETF RFC 6749 section 4.1.
|
||||
OAuthAuthzCodeGrant = OAuthGrantType("authz_code")
|
||||
|
||||
// OAuthOwnerPasswordGrant represents a resource owner password
|
||||
// credentials grant, as defined in IETF RFC 6749 section 4.3.
|
||||
OAuthOwnerPasswordGrant = OAuthGrantType("password")
|
||||
)
|
||||
|
||||
// UsesAuthorizationEndpoint returns true if the receiving grant type makes
|
||||
// use of the authorization endpoint from the client configuration, and thus
|
||||
// if the authorization endpoint ought to be required.
|
||||
func (t OAuthGrantType) UsesAuthorizationEndpoint() bool {
|
||||
switch t {
|
||||
case OAuthAuthzCodeGrant:
|
||||
return true
|
||||
case OAuthOwnerPasswordGrant:
|
||||
return false
|
||||
default:
|
||||
// We'll default to false so that we don't impose any requirements
|
||||
// on any grant type keywords that might be defined for future
|
||||
// versions of Terraform.
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// UsesTokenEndpoint returns true if the receiving grant type makes
|
||||
// use of the token endpoint from the client configuration, and thus
|
||||
// if the authorization endpoint ought to be required.
|
||||
func (t OAuthGrantType) UsesTokenEndpoint() bool {
|
||||
switch t {
|
||||
case OAuthAuthzCodeGrant:
|
||||
return true
|
||||
case OAuthOwnerPasswordGrant:
|
||||
return true
|
||||
default:
|
||||
// We'll default to false so that we don't impose any requirements
|
||||
// on any grant type keywords that might be defined for future
|
||||
// versions of Terraform.
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// OAuthGrantTypeSet represents a set of OAuthGrantType values.
|
||||
type OAuthGrantTypeSet map[OAuthGrantType]struct{}
|
||||
|
||||
// NewOAuthGrantTypeSet constructs a new grant type set from the given list
|
||||
// of grant type keyword strings. Any duplicates in the list are ignored.
|
||||
func NewOAuthGrantTypeSet(keywords ...string) OAuthGrantTypeSet {
|
||||
ret := make(OAuthGrantTypeSet, len(keywords))
|
||||
for _, kw := range keywords {
|
||||
ret[OAuthGrantType(kw)] = struct{}{}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// Has returns true if the given grant type is in the receiving set.
|
||||
func (s OAuthGrantTypeSet) Has(t OAuthGrantType) bool {
|
||||
_, ok := s[t]
|
||||
return ok
|
||||
}
|
||||
|
||||
// RequiresAuthorizationEndpoint returns true if any of the grant types in
|
||||
// the set are known to require an authorization endpoint.
|
||||
func (s OAuthGrantTypeSet) RequiresAuthorizationEndpoint() bool {
|
||||
for t := range s {
|
||||
if t.UsesAuthorizationEndpoint() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RequiresTokenEndpoint returns true if any of the grant types in
|
||||
// the set are known to require a token endpoint.
|
||||
func (s OAuthGrantTypeSet) RequiresTokenEndpoint() bool {
|
||||
for t := range s {
|
||||
if t.UsesTokenEndpoint() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GoString implements fmt.GoStringer.
|
||||
func (s OAuthGrantTypeSet) GoString() string {
|
||||
var buf strings.Builder
|
||||
i := 0
|
||||
buf.WriteString("disco.NewOAuthGrantTypeSet(")
|
||||
for t := range s {
|
||||
if i > 0 {
|
||||
buf.WriteString(", ")
|
||||
}
|
||||
fmt.Fprintf(&buf, "%q", string(t))
|
||||
i++
|
||||
}
|
||||
buf.WriteString(")")
|
||||
return buf.String()
|
||||
}
|
|
@ -1,69 +0,0 @@
|
|||
package svchost
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// A labelIter allows iterating over domain name labels.
|
||||
//
|
||||
// This type is copied from golang.org/x/net/idna, where it is used
|
||||
// to segment hostnames into their separate labels for analysis. We use
|
||||
// it for the same purpose here, in ForComparison.
|
||||
type labelIter struct {
|
||||
orig string
|
||||
slice []string
|
||||
curStart int
|
||||
curEnd int
|
||||
i int
|
||||
}
|
||||
|
||||
func (l *labelIter) reset() {
|
||||
l.curStart = 0
|
||||
l.curEnd = 0
|
||||
l.i = 0
|
||||
}
|
||||
|
||||
func (l *labelIter) done() bool {
|
||||
return l.curStart >= len(l.orig)
|
||||
}
|
||||
|
||||
func (l *labelIter) result() string {
|
||||
if l.slice != nil {
|
||||
return strings.Join(l.slice, ".")
|
||||
}
|
||||
return l.orig
|
||||
}
|
||||
|
||||
func (l *labelIter) label() string {
|
||||
if l.slice != nil {
|
||||
return l.slice[l.i]
|
||||
}
|
||||
p := strings.IndexByte(l.orig[l.curStart:], '.')
|
||||
l.curEnd = l.curStart + p
|
||||
if p == -1 {
|
||||
l.curEnd = len(l.orig)
|
||||
}
|
||||
return l.orig[l.curStart:l.curEnd]
|
||||
}
|
||||
|
||||
// next sets the value to the next label. It skips the last label if it is empty.
|
||||
func (l *labelIter) next() {
|
||||
l.i++
|
||||
if l.slice != nil {
|
||||
if l.i >= len(l.slice) || l.i == len(l.slice)-1 && l.slice[l.i] == "" {
|
||||
l.curStart = len(l.orig)
|
||||
}
|
||||
} else {
|
||||
l.curStart = l.curEnd + 1
|
||||
if l.curStart == len(l.orig)-1 && l.orig[l.curStart] == '.' {
|
||||
l.curStart = len(l.orig)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *labelIter) set(s string) {
|
||||
if l.slice == nil {
|
||||
l.slice = strings.Split(l.orig, ".")
|
||||
}
|
||||
l.slice[l.i] = s
|
||||
}
|
|
@ -1,207 +0,0 @@
|
|||
// Package svchost deals with the representations of the so-called "friendly
|
||||
// hostnames" that we use to represent systems that provide Terraform-native
|
||||
// remote services, such as module registry, remote operations, etc.
|
||||
//
|
||||
// Friendly hostnames are specified such that, as much as possible, they
|
||||
// are consistent with how web browsers think of hostnames, so that users
|
||||
// can bring their intuitions about how hostnames behave when they access
|
||||
// a Terraform Enterprise instance's web UI (or indeed any other website)
|
||||
// and have this behave in a similar way.
|
||||
package svchost
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
// Hostname is specialized name for string that indicates that the string
|
||||
// has been converted to (or was already in) the storage and comparison form.
|
||||
//
|
||||
// Hostname values are not suitable for display in the user-interface. Use
|
||||
// the ForDisplay method to obtain a form suitable for display in the UI.
|
||||
//
|
||||
// Unlike user-supplied hostnames, strings of type Hostname (assuming they
|
||||
// were constructed by a function within this package) can be compared for
|
||||
// equality using the standard Go == operator.
|
||||
type Hostname string
|
||||
|
||||
// acePrefix is the ASCII Compatible Encoding prefix, used to indicate that
|
||||
// a domain name label is in "punycode" form.
|
||||
const acePrefix = "xn--"
|
||||
|
||||
// displayProfile is a very liberal idna profile that we use to do
|
||||
// normalization for display without imposing validation rules.
|
||||
var displayProfile = idna.New(
|
||||
idna.MapForLookup(),
|
||||
idna.Transitional(true),
|
||||
)
|
||||
|
||||
// ForDisplay takes a user-specified hostname and returns a normalized form of
|
||||
// it suitable for display in the UI.
|
||||
//
|
||||
// If the input is so invalid that no normalization can be performed then
|
||||
// this will return the input, assuming that the caller still wants to
|
||||
// display _something_. This function is, however, more tolerant than the
|
||||
// other functions in this package and will make a best effort to prepare
|
||||
// _any_ given hostname for display.
|
||||
//
|
||||
// For validation, use either IsValid (for explicit validation) or
|
||||
// ForComparison (which implicitly validates, returning an error if invalid).
|
||||
func ForDisplay(given string) string {
|
||||
var portPortion string
|
||||
if colonPos := strings.Index(given, ":"); colonPos != -1 {
|
||||
given, portPortion = given[:colonPos], given[colonPos:]
|
||||
}
|
||||
portPortion, _ = normalizePortPortion(portPortion)
|
||||
|
||||
ascii, err := displayProfile.ToASCII(given)
|
||||
if err != nil {
|
||||
return given + portPortion
|
||||
}
|
||||
display, err := displayProfile.ToUnicode(ascii)
|
||||
if err != nil {
|
||||
return given + portPortion
|
||||
}
|
||||
return display + portPortion
|
||||
}
|
||||
|
||||
// IsValid returns true if the given user-specified hostname is a valid
|
||||
// service hostname.
|
||||
//
|
||||
// Validity is determined by complying with the RFC 5891 requirements for
|
||||
// names that are valid for domain lookup (section 5), with the additional
|
||||
// requirement that user-supplied forms must not _already_ contain
|
||||
// Punycode segments.
|
||||
func IsValid(given string) bool {
|
||||
_, err := ForComparison(given)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// ForComparison takes a user-specified hostname and returns a normalized
|
||||
// form of it suitable for storage and comparison. The result is not suitable
|
||||
// for display to end-users because it uses Punycode to represent non-ASCII
|
||||
// characters, and this form is unreadable for non-ASCII-speaking humans.
|
||||
//
|
||||
// The result is typed as Hostname -- a specialized name for string -- so that
|
||||
// other APIs can make it clear within the type system whether they expect a
|
||||
// user-specified or display-form hostname or a value already normalized for
|
||||
// comparison.
|
||||
//
|
||||
// The returned Hostname is not valid if the returned error is non-nil.
|
||||
func ForComparison(given string) (Hostname, error) {
|
||||
var portPortion string
|
||||
if colonPos := strings.Index(given, ":"); colonPos != -1 {
|
||||
given, portPortion = given[:colonPos], given[colonPos:]
|
||||
}
|
||||
|
||||
var err error
|
||||
portPortion, err = normalizePortPortion(portPortion)
|
||||
if err != nil {
|
||||
return Hostname(""), err
|
||||
}
|
||||
|
||||
if given == "" {
|
||||
return Hostname(""), fmt.Errorf("empty string is not a valid hostname")
|
||||
}
|
||||
|
||||
// First we'll apply our additional constraint that Punycode must not
|
||||
// be given directly by the user. This is not an IDN specification
|
||||
// requirement, but we prohibit it to force users to use human-readable
|
||||
// hostname forms within Terraform configuration.
|
||||
labels := labelIter{orig: given}
|
||||
for ; !labels.done(); labels.next() {
|
||||
label := labels.label()
|
||||
if label == "" {
|
||||
return Hostname(""), fmt.Errorf(
|
||||
"hostname contains empty label (two consecutive periods)",
|
||||
)
|
||||
}
|
||||
if strings.HasPrefix(label, acePrefix) {
|
||||
return Hostname(""), fmt.Errorf(
|
||||
"hostname label %q specified in punycode format; service hostnames must be given in unicode",
|
||||
label,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
result, err := idna.Lookup.ToASCII(given)
|
||||
if err != nil {
|
||||
return Hostname(""), err
|
||||
}
|
||||
return Hostname(result + portPortion), nil
|
||||
}
|
||||
|
||||
// ForDisplay returns a version of the receiver that is appropriate for display
|
||||
// in the UI. This includes converting any punycode labels to their
|
||||
// corresponding Unicode characters.
|
||||
//
|
||||
// A round-trip through ForComparison and this ForDisplay method does not
|
||||
// guarantee the same result as calling this package's top-level ForDisplay
|
||||
// function, since a round-trip through the Hostname type implies stricter
|
||||
// handling than we do when doing basic display-only processing.
|
||||
func (h Hostname) ForDisplay() string {
|
||||
given := string(h)
|
||||
var portPortion string
|
||||
if colonPos := strings.Index(given, ":"); colonPos != -1 {
|
||||
given, portPortion = given[:colonPos], given[colonPos:]
|
||||
}
|
||||
// We don't normalize the port portion here because we assume it's
|
||||
// already been normalized on the way in.
|
||||
|
||||
result, err := idna.Lookup.ToUnicode(given)
|
||||
if err != nil {
|
||||
// Should never happen, since type Hostname indicates that a string
|
||||
// passed through our validation rules.
|
||||
panic(fmt.Errorf("ForDisplay called on invalid Hostname: %s", err))
|
||||
}
|
||||
return result + portPortion
|
||||
}
|
||||
|
||||
func (h Hostname) String() string {
|
||||
return string(h)
|
||||
}
|
||||
|
||||
func (h Hostname) GoString() string {
|
||||
return fmt.Sprintf("svchost.Hostname(%q)", string(h))
|
||||
}
|
||||
|
||||
// normalizePortPortion attempts to normalize the "port portion" of a hostname,
|
||||
// which begins with the first colon in the hostname and should be followed
|
||||
// by a string of decimal digits.
|
||||
//
|
||||
// If the port portion is valid, a normalized version of it is returned along
|
||||
// with a nil error.
|
||||
//
|
||||
// If the port portion is invalid, the input string is returned verbatim along
|
||||
// with a non-nil error.
|
||||
//
|
||||
// An empty string is a valid port portion representing the absence of a port.
|
||||
// If non-empty, the first character must be a colon.
|
||||
func normalizePortPortion(s string) (string, error) {
|
||||
if s == "" {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
if s[0] != ':' {
|
||||
// should never happen, since caller tends to guarantee the presence
|
||||
// of a colon due to how it's extracted from the string.
|
||||
return s, errors.New("port portion is missing its initial colon")
|
||||
}
|
||||
|
||||
numStr := s[1:]
|
||||
num, err := strconv.Atoi(numStr)
|
||||
if err != nil {
|
||||
return s, errors.New("port portion contains non-digit characters")
|
||||
}
|
||||
if num == 443 {
|
||||
return "", nil // ":443" is the default
|
||||
}
|
||||
if num > 65535 {
|
||||
return s, errors.New("port number is greater than 65535")
|
||||
}
|
||||
return fmt.Sprintf(":%d", num), nil
|
||||
}
|
|
@ -1,218 +0,0 @@
|
|||
package svchost
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestForDisplay(t *testing.T) {
|
||||
tests := []struct {
|
||||
Input string
|
||||
Want string
|
||||
}{
|
||||
{
|
||||
"",
|
||||
"",
|
||||
},
|
||||
{
|
||||
"example.com",
|
||||
"example.com",
|
||||
},
|
||||
{
|
||||
"invalid",
|
||||
"invalid",
|
||||
},
|
||||
{
|
||||
"localhost",
|
||||
"localhost",
|
||||
},
|
||||
{
|
||||
"localhost:1211",
|
||||
"localhost:1211",
|
||||
},
|
||||
{
|
||||
"HashiCorp.com",
|
||||
"hashicorp.com",
|
||||
},
|
||||
{
|
||||
"Испытание.com",
|
||||
"испытание.com",
|
||||
},
|
||||
{
|
||||
"münchen.de", // this is a precomposed u with diaeresis
|
||||
"münchen.de", // this is a precomposed u with diaeresis
|
||||
},
|
||||
{
|
||||
"münchen.de", // this is a separate u and combining diaeresis
|
||||
"münchen.de", // this is a precomposed u with diaeresis
|
||||
},
|
||||
{
|
||||
"example.com:443",
|
||||
"example.com",
|
||||
},
|
||||
{
|
||||
"example.com:81",
|
||||
"example.com:81",
|
||||
},
|
||||
{
|
||||
"example.com:boo",
|
||||
"example.com:boo", // invalid, but tolerated for display purposes
|
||||
},
|
||||
{
|
||||
"example.com:boo:boo",
|
||||
"example.com:boo:boo", // invalid, but tolerated for display purposes
|
||||
},
|
||||
{
|
||||
"example.com:081",
|
||||
"example.com:81",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.Input, func(t *testing.T) {
|
||||
got := ForDisplay(test.Input)
|
||||
if got != test.Want {
|
||||
t.Errorf("wrong result\ninput: %s\ngot: %s\nwant: %s", test.Input, got, test.Want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestForComparison(t *testing.T) {
|
||||
tests := []struct {
|
||||
Input string
|
||||
Want string
|
||||
Err bool
|
||||
}{
|
||||
{
|
||||
"",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"example.com",
|
||||
"example.com",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"example.com:443",
|
||||
"example.com",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"example.com:81",
|
||||
"example.com:81",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"example.com:081",
|
||||
"example.com:81",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"invalid",
|
||||
"invalid",
|
||||
false, // the "invalid" TLD is, confusingly, a valid hostname syntactically
|
||||
},
|
||||
{
|
||||
"localhost", // supported for local testing only
|
||||
"localhost",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"localhost:1211", // supported for local testing only
|
||||
"localhost:1211",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"HashiCorp.com",
|
||||
"hashicorp.com",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"1example.com",
|
||||
"1example.com",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"Испытание.com",
|
||||
"xn--80akhbyknj4f.com",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"münchen.de", // this is a precomposed u with diaeresis
|
||||
"xn--mnchen-3ya.de",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"münchen.de", // this is a separate u and combining diaeresis
|
||||
"xn--mnchen-3ya.de",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"blah..blah",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"example.com:boo",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"example.com:80:boo",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.Input, func(t *testing.T) {
|
||||
got, err := ForComparison(test.Input)
|
||||
if (err != nil) != test.Err {
|
||||
if test.Err {
|
||||
t.Error("unexpected success; want error")
|
||||
} else {
|
||||
t.Errorf("unexpected error; want success\nerror: %s", err)
|
||||
}
|
||||
}
|
||||
if string(got) != test.Want {
|
||||
t.Errorf("wrong result\ninput: %s\ngot: %s\nwant: %s", test.Input, got, test.Want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostnameForDisplay(t *testing.T) {
|
||||
tests := []struct {
|
||||
Input string
|
||||
Want string
|
||||
}{
|
||||
{
|
||||
"example.com",
|
||||
"example.com",
|
||||
},
|
||||
{
|
||||
"example.com:81",
|
||||
"example.com:81",
|
||||
},
|
||||
{
|
||||
"xn--80akhbyknj4f.com",
|
||||
"испытание.com",
|
||||
},
|
||||
{
|
||||
"xn--80akhbyknj4f.com:8080",
|
||||
"испытание.com:8080",
|
||||
},
|
||||
{
|
||||
"xn--mnchen-3ya.de",
|
||||
"münchen.de", // this is a precomposed u with diaeresis
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.Input, func(t *testing.T) {
|
||||
got := Hostname(test.Input).ForDisplay()
|
||||
if got != test.Want {
|
||||
t.Errorf("wrong result\ninput: %s\ngot: %s\nwant: %s", test.Input, got, test.Want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue