svchost/disco: add credentials, if available, to disco requests
Although service discovery metadata is usually not sensitive, a service host may wish to produce different results depending on the requesting user, such as if users are migrating between two different implementations that are both running concurrently for some period.
This commit is contained in:
parent
fcff4cbc95
commit
83b098344b
|
@ -19,6 +19,7 @@ import (
|
||||||
|
|
||||||
cleanhttp "github.com/hashicorp/go-cleanhttp"
|
cleanhttp "github.com/hashicorp/go-cleanhttp"
|
||||||
"github.com/hashicorp/terraform/svchost"
|
"github.com/hashicorp/terraform/svchost"
|
||||||
|
"github.com/hashicorp/terraform/svchost/auth"
|
||||||
"github.com/hashicorp/terraform/terraform"
|
"github.com/hashicorp/terraform/terraform"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -37,12 +38,22 @@ var httpTransport = cleanhttp.DefaultPooledTransport() // overridden during test
|
||||||
// for the same information.
|
// for the same information.
|
||||||
type Disco struct {
|
type Disco struct {
|
||||||
hostCache map[svchost.Hostname]Host
|
hostCache map[svchost.Hostname]Host
|
||||||
|
credsSrc auth.CredentialsSource
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDisco() *Disco {
|
func NewDisco() *Disco {
|
||||||
return &Disco{}
|
return &Disco{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetCredentialsSource provides a credentials source that will be used to
|
||||||
|
// add credentials to outgoing discovery requests, where available.
|
||||||
|
//
|
||||||
|
// If this method is never called, no outgoing discovery requests will have
|
||||||
|
// credentials.
|
||||||
|
func (d *Disco) SetCredentialsSource(src auth.CredentialsSource) {
|
||||||
|
d.credsSrc = src
|
||||||
|
}
|
||||||
|
|
||||||
// Discover runs the discovery protocol against the given hostname (which must
|
// Discover runs the discovery protocol against the given hostname (which must
|
||||||
// already have been validated and prepared with svchost.ForComparison) and
|
// already have been validated and prepared with svchost.ForComparison) and
|
||||||
// returns an object describing the services available at that host.
|
// returns an object describing the services available at that host.
|
||||||
|
@ -96,7 +107,6 @@ func (d *Disco) discover(host svchost.Hostname) Host {
|
||||||
|
|
||||||
var header = http.Header{}
|
var header = http.Header{}
|
||||||
header.Set("User-Agent", userAgent)
|
header.Set("User-Agent", userAgent)
|
||||||
// TODO: look up credentials and add them to the header if we have them
|
|
||||||
|
|
||||||
req := &http.Request{
|
req := &http.Request{
|
||||||
Method: "GET",
|
Method: "GET",
|
||||||
|
@ -104,6 +114,17 @@ func (d *Disco) discover(host svchost.Hostname) Host {
|
||||||
Header: header,
|
Header: header,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if d.credsSrc != nil {
|
||||||
|
creds, err := d.credsSrc.ForHost(host)
|
||||||
|
if err == nil {
|
||||||
|
if creds != nil {
|
||||||
|
creds.PrepareRequest(req) // alters req to include credentials
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Printf("[WARNING] Failed to get credentials for %s: %s (ignoring)", host, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
log.Printf("[DEBUG] Service discovery for %s at %s", host, discoURL)
|
log.Printf("[DEBUG] Service discovery for %s at %s", host, discoURL)
|
||||||
|
|
||||||
ret := Host{
|
ret := Host{
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/hashicorp/terraform/svchost"
|
"github.com/hashicorp/terraform/svchost"
|
||||||
|
"github.com/hashicorp/terraform/svchost/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
|
@ -89,6 +90,34 @@ func TestDiscover(t *testing.T) {
|
||||||
t.Fatalf("wrong result %q; want %q", got, want)
|
t.Fatalf("wrong result %q; want %q", got, want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
t.Run("with credentials", func(t *testing.T) {
|
||||||
|
var authHeaderText string
|
||||||
|
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
resp := []byte(`{}`)
|
||||||
|
authHeaderText = r.Header.Get("Authorization")
|
||||||
|
w.Header().Add("Content-Type", "application/json")
|
||||||
|
w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
|
||||||
|
w.Write(resp)
|
||||||
|
})
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
givenHost := "localhost" + portStr
|
||||||
|
host, err := svchost.ForComparison(givenHost)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("test server hostname is invalid: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
d := NewDisco()
|
||||||
|
d.SetCredentialsSource(auth.StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{
|
||||||
|
host: map[string]interface{}{
|
||||||
|
"token": "abc123",
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
d.Discover(host)
|
||||||
|
if got, want := authHeaderText, "Bearer abc123"; got != want {
|
||||||
|
t.Fatalf("wrong Authorization header\ngot: %s\nwant: %s", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
t.Run("not JSON", func(t *testing.T) {
|
t.Run("not JSON", func(t *testing.T) {
|
||||||
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
|
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||||
resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
|
resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
|
||||||
|
|
Loading…
Reference in New Issue