svchost/disco: lookup of service URLs within a discovered map
This package implements our Terraform-native Service discovery protocol, which allows us to find the base URL for a particular service given a hostname that was already validated and normalized by the svchost package.
This commit is contained in:
parent
db08ee4ac5
commit
6cd9a8f9c2
|
@ -0,0 +1,177 @@
|
|||
// Package disco handles Terraform's remote service discovery protocol.
|
||||
//
|
||||
// This protocol allows mapping from a service hostname, as produced by the
|
||||
// svchost package, to a set of services supported by that host and the
|
||||
// endpoint information for each supported service.
|
||||
package disco
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
cleanhttp "github.com/hashicorp/go-cleanhttp"
|
||||
"github.com/hashicorp/terraform/svchost"
|
||||
"github.com/hashicorp/terraform/terraform"
|
||||
)
|
||||
|
||||
const (
|
||||
discoPath = "/.well-known/terraform.json"
|
||||
maxRedirects = 3 // arbitrary-but-small number to prevent runaway redirect loops
|
||||
discoTimeout = 4 * time.Second // arbitrary-but-small time limit to prevent UI "hangs" during discovery
|
||||
maxDiscoDocBytes = 1 * 1024 * 1024 // 1MB - to prevent abusive services from using loads of our memory
|
||||
)
|
||||
|
||||
var userAgent = fmt.Sprintf("Terraform/%s (service discovery)", terraform.VersionString())
|
||||
var httpTransport = cleanhttp.DefaultPooledTransport() // overridden during tests, to skip TLS verification
|
||||
|
||||
// Disco is the main type in this package, which allows discovery on given
|
||||
// hostnames and caches the results by hostname to avoid repeated requests
|
||||
// for the same information.
|
||||
type Disco struct {
|
||||
hostCache map[svchost.Hostname]Host
|
||||
}
|
||||
|
||||
func NewDisco() *Disco {
|
||||
return &Disco{}
|
||||
}
|
||||
|
||||
// Discover runs the discovery protocol against the given hostname (which must
|
||||
// already have been validated and prepared with svchost.ForComparison) and
|
||||
// returns an object describing the services available at that host.
|
||||
//
|
||||
// If a given hostname supports no Terraform services at all, a non-nil but
|
||||
// empty Host object is returned. When giving feedback to the end user about
|
||||
// such situations, we say e.g. "the host <name> doesn't provide a module
|
||||
// registry", regardless of whether that is due to that service specifically
|
||||
// being absent or due to the host not providing Terraform services at all,
|
||||
// since we don't wish to expose the detail of whole-host discovery to an
|
||||
// end-user.
|
||||
func (d *Disco) Discover(host svchost.Hostname) Host {
|
||||
if d.hostCache == nil {
|
||||
d.hostCache = map[svchost.Hostname]Host{}
|
||||
}
|
||||
if cache, cached := d.hostCache[host]; cached {
|
||||
return cache
|
||||
}
|
||||
|
||||
ret := d.discover(host)
|
||||
d.hostCache[host] = ret
|
||||
return ret
|
||||
}
|
||||
|
||||
// DiscoverServiceURL is a convenience wrapper for discovery on a given
|
||||
// hostname and then looking up a particular service in the result.
|
||||
func (d *Disco) DiscoverServiceURL(host svchost.Hostname, serviceID string) *url.URL {
|
||||
return d.Discover(host).ServiceURL(serviceID)
|
||||
}
|
||||
|
||||
// discover implements the actual discovery process, with its result cached
|
||||
// by the public-facing Discover method.
|
||||
func (d *Disco) discover(host svchost.Hostname) Host {
|
||||
discoURL := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: string(host),
|
||||
Path: discoPath,
|
||||
}
|
||||
client := &http.Client{
|
||||
Transport: httpTransport,
|
||||
Timeout: discoTimeout,
|
||||
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
log.Printf("[DEBUG] Service discovery redirected to %s", req.URL)
|
||||
if len(via) > maxRedirects {
|
||||
return errors.New("too many redirects") // (this error message will never actually be seen)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var header = http.Header{}
|
||||
header.Set("User-Agent", userAgent)
|
||||
// TODO: look up credentials and add them to the header if we have them
|
||||
|
||||
req := &http.Request{
|
||||
Method: "GET",
|
||||
URL: discoURL,
|
||||
Header: header,
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG] Service discovery for %s at %s", host, discoURL)
|
||||
|
||||
ret := Host{
|
||||
discoURL: discoURL,
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("[WARNING] Failed to request discovery document: %s", err)
|
||||
return ret // empty
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
log.Printf("[WARNING] Failed to request discovery document: %s", resp.Status)
|
||||
return ret // empty
|
||||
}
|
||||
|
||||
// If the client followed any redirects, we will have a new URL to use
|
||||
// as our base for relative resolution.
|
||||
ret.discoURL = resp.Request.URL
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
log.Printf("[WARNING] Discovery URL has malformed Content-Type %q", contentType)
|
||||
return ret // empty
|
||||
}
|
||||
if mediaType != "application/json" {
|
||||
log.Printf("[DEBUG] Discovery URL returned Content-Type %q, rather than application/json", mediaType)
|
||||
return ret // empty
|
||||
}
|
||||
|
||||
// (this doesn't catch chunked encoding, because ContentLength is -1 in that case...)
|
||||
if resp.ContentLength > maxDiscoDocBytes {
|
||||
// Size limit here is not a contractual requirement and so we may
|
||||
// adjust it over time if we find a different limit is warranted.
|
||||
log.Printf("[WARNING] Discovery doc response is too large (got %d bytes; limit %d)", resp.ContentLength, maxDiscoDocBytes)
|
||||
return ret // empty
|
||||
}
|
||||
|
||||
// If the response is using chunked encoding then we can't predict
|
||||
// its size, but we'll at least prevent reading the entire thing into
|
||||
// memory.
|
||||
lr := io.LimitReader(resp.Body, maxDiscoDocBytes)
|
||||
|
||||
servicesBytes, err := ioutil.ReadAll(lr)
|
||||
if err != nil {
|
||||
log.Printf("[WARNING] Error reading discovery document body: %s", err)
|
||||
return ret // empty
|
||||
}
|
||||
|
||||
var services map[string]interface{}
|
||||
err = json.Unmarshal(servicesBytes, &services)
|
||||
if err != nil {
|
||||
log.Printf("[WARNING] Failed to decode discovery document as a JSON object: %s", err)
|
||||
return ret // empty
|
||||
}
|
||||
|
||||
ret.services = services
|
||||
return ret
|
||||
}
|
||||
|
||||
// Forget invalidates any cached record of the given hostname. If the host
|
||||
// has no cache entry then this is a no-op.
|
||||
func (d *Disco) Forget(host svchost.Hostname) {
|
||||
delete(d.hostCache, host)
|
||||
}
|
||||
|
||||
// ForgetAll is like Forget, but for all of the hostnames that have cache entries.
|
||||
func (d *Disco) ForgetAll() {
|
||||
d.hostCache = nil
|
||||
}
|
|
@ -0,0 +1,255 @@
|
|||
package disco
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/terraform/svchost"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// During all tests we override the HTTP transport we use for discovery
|
||||
// so it'll tolerate the locally-generated TLS certificates we use
|
||||
// for test URLs.
|
||||
httpTransport = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestDiscover(t *testing.T) {
|
||||
t.Run("happy path", func(t *testing.T) {
|
||||
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := []byte(`
|
||||
{
|
||||
"thingy.v1": "http://example.com/foo",
|
||||
"wotsit.v2": "http://example.net/bar"
|
||||
}
|
||||
`)
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
|
||||
w.Write(resp)
|
||||
})
|
||||
defer close()
|
||||
|
||||
givenHost := "localhost" + portStr
|
||||
host, err := svchost.ForComparison(givenHost)
|
||||
if err != nil {
|
||||
t.Fatalf("test server hostname is invalid: %s", err)
|
||||
}
|
||||
|
||||
d := NewDisco()
|
||||
discovered := d.Discover(host)
|
||||
gotURL := discovered.ServiceURL("thingy.v1")
|
||||
if gotURL == nil {
|
||||
t.Fatalf("found no URL for thingy.v1")
|
||||
}
|
||||
if got, want := gotURL.String(), "http://example.com/foo"; got != want {
|
||||
t.Fatalf("wrong result %q; want %q", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("chunked encoding", func(t *testing.T) {
|
||||
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := []byte(`
|
||||
{
|
||||
"thingy.v1": "http://example.com/foo",
|
||||
"wotsit.v2": "http://example.net/bar"
|
||||
}
|
||||
`)
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
// We're going to force chunked encoding here -- and thus prevent
|
||||
// the server from predicting the length -- so we can make sure
|
||||
// our client is tolerant of servers using this encoding.
|
||||
w.Write(resp[:5])
|
||||
w.(http.Flusher).Flush()
|
||||
w.Write(resp[5:])
|
||||
w.(http.Flusher).Flush()
|
||||
})
|
||||
defer close()
|
||||
|
||||
givenHost := "localhost" + portStr
|
||||
host, err := svchost.ForComparison(givenHost)
|
||||
if err != nil {
|
||||
t.Fatalf("test server hostname is invalid: %s", err)
|
||||
}
|
||||
|
||||
d := NewDisco()
|
||||
discovered := d.Discover(host)
|
||||
gotURL := discovered.ServiceURL("wotsit.v2")
|
||||
if gotURL == nil {
|
||||
t.Fatalf("found no URL for wotsit.v2")
|
||||
}
|
||||
if got, want := gotURL.String(), "http://example.net/bar"; got != want {
|
||||
t.Fatalf("wrong result %q; want %q", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("not JSON", func(t *testing.T) {
|
||||
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
|
||||
w.Header().Add("Content-Type", "application/octet-stream")
|
||||
w.Write(resp)
|
||||
})
|
||||
defer close()
|
||||
|
||||
givenHost := "localhost" + portStr
|
||||
host, err := svchost.ForComparison(givenHost)
|
||||
if err != nil {
|
||||
t.Fatalf("test server hostname is invalid: %s", err)
|
||||
}
|
||||
|
||||
d := NewDisco()
|
||||
discovered := d.Discover(host)
|
||||
|
||||
// result should be empty, which we can verify only by reaching into
|
||||
// its internals.
|
||||
if discovered.services != nil {
|
||||
t.Errorf("response not empty; should be")
|
||||
}
|
||||
})
|
||||
t.Run("malformed JSON", func(t *testing.T) {
|
||||
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := []byte(`{"thingy.v1": "htt`) // truncated, for example...
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.Write(resp)
|
||||
})
|
||||
defer close()
|
||||
|
||||
givenHost := "localhost" + portStr
|
||||
host, err := svchost.ForComparison(givenHost)
|
||||
if err != nil {
|
||||
t.Fatalf("test server hostname is invalid: %s", err)
|
||||
}
|
||||
|
||||
d := NewDisco()
|
||||
discovered := d.Discover(host)
|
||||
|
||||
// result should be empty, which we can verify only by reaching into
|
||||
// its internals.
|
||||
if discovered.services != nil {
|
||||
t.Errorf("response not empty; should be")
|
||||
}
|
||||
})
|
||||
t.Run("JSON with redundant charset", func(t *testing.T) {
|
||||
// The JSON RFC defines no parameters for the application/json
|
||||
// MIME type, but some servers have a weird tendency to just add
|
||||
// "charset" to everything, so we'll make sure we ignore it successfully.
|
||||
// (JSON uses content sniffing for encoding detection, not media type params.)
|
||||
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
|
||||
w.Header().Add("Content-Type", "application/json; charset=latin-1")
|
||||
w.Write(resp)
|
||||
})
|
||||
defer close()
|
||||
|
||||
givenHost := "localhost" + portStr
|
||||
host, err := svchost.ForComparison(givenHost)
|
||||
if err != nil {
|
||||
t.Fatalf("test server hostname is invalid: %s", err)
|
||||
}
|
||||
|
||||
d := NewDisco()
|
||||
discovered := d.Discover(host)
|
||||
|
||||
if discovered.services == nil {
|
||||
t.Errorf("response is empty; shouldn't be")
|
||||
}
|
||||
})
|
||||
t.Run("no discovery doc", func(t *testing.T) {
|
||||
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(404)
|
||||
})
|
||||
defer close()
|
||||
|
||||
givenHost := "localhost" + portStr
|
||||
host, err := svchost.ForComparison(givenHost)
|
||||
if err != nil {
|
||||
t.Fatalf("test server hostname is invalid: %s", err)
|
||||
}
|
||||
|
||||
d := NewDisco()
|
||||
discovered := d.Discover(host)
|
||||
|
||||
// result should be empty, which we can verify only by reaching into
|
||||
// its internals.
|
||||
if discovered.services != nil {
|
||||
t.Errorf("response not empty; should be")
|
||||
}
|
||||
})
|
||||
t.Run("redirect", func(t *testing.T) {
|
||||
// For this test, we have two servers and one redirects to the other
|
||||
portStr1, close1 := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
// This server is the one that returns a real response.
|
||||
resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
|
||||
w.Write(resp)
|
||||
})
|
||||
portStr2, close2 := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
// This server is the one that redirects.
|
||||
http.Redirect(w, r, "https://127.0.0.1"+portStr1+"/.well-known/terraform.json", 302)
|
||||
})
|
||||
defer close1()
|
||||
defer close2()
|
||||
|
||||
givenHost := "localhost" + portStr2
|
||||
host, err := svchost.ForComparison(givenHost)
|
||||
if err != nil {
|
||||
t.Fatalf("test server hostname is invalid: %s", err)
|
||||
}
|
||||
|
||||
d := NewDisco()
|
||||
discovered := d.Discover(host)
|
||||
|
||||
gotURL := discovered.ServiceURL("thingy.v1")
|
||||
if gotURL == nil {
|
||||
t.Fatalf("found no URL for thingy.v1")
|
||||
}
|
||||
if got, want := gotURL.String(), "http://example.com/foo"; got != want {
|
||||
t.Fatalf("wrong result %q; want %q", got, want)
|
||||
}
|
||||
|
||||
// The base URL for the host object should be the URL we redirected to,
|
||||
// rather than the we redirected _from_.
|
||||
gotBaseURL := discovered.discoURL.String()
|
||||
wantBaseURL := "https://127.0.0.1" + portStr1 + "/.well-known/terraform.json"
|
||||
if gotBaseURL != wantBaseURL {
|
||||
t.Errorf("incorrect base url %s; want %s", gotBaseURL, wantBaseURL)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func testServer(h func(w http.ResponseWriter, r *http.Request)) (portStr string, close func()) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
// Test server always returns 404 if the URL isn't what we expect
|
||||
if r.URL.Path != "/.well-known/terraform.json" {
|
||||
w.WriteHeader(404)
|
||||
w.Write([]byte("not found"))
|
||||
return
|
||||
}
|
||||
|
||||
// If the URL is correct then the given hander decides the response
|
||||
h(w, r)
|
||||
},
|
||||
))
|
||||
|
||||
serverURL, _ := url.Parse(server.URL)
|
||||
|
||||
portStr = serverURL.Port()
|
||||
if portStr != "" {
|
||||
portStr = ":" + portStr
|
||||
}
|
||||
|
||||
close = func() {
|
||||
server.Close()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
package disco
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type Host struct {
|
||||
discoURL *url.URL
|
||||
services map[string]interface{}
|
||||
}
|
||||
|
||||
// ServiceURL returns the URL associated with the given service identifier,
|
||||
// which should be of the form "servicename.vN".
|
||||
//
|
||||
// A non-nil result is always an absolute URL with a scheme of either https
|
||||
// or http.
|
||||
//
|
||||
// If the requested service is not supported by the host, this method returns
|
||||
// a nil URL.
|
||||
//
|
||||
// If the discovery document entry for the given service is invalid (not a URL),
|
||||
// it is treated as absent, also returning a nil URL.
|
||||
func (h Host) ServiceURL(id string) *url.URL {
|
||||
if h.services == nil {
|
||||
return nil // no services supported for an empty Host
|
||||
}
|
||||
|
||||
urlStr, ok := h.services[id].(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
ret, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if !ret.IsAbs() {
|
||||
ret = h.discoURL.ResolveReference(ret) // make absolute using our discovery doc URL
|
||||
}
|
||||
if ret.Scheme != "https" && ret.Scheme != "http" {
|
||||
return nil
|
||||
}
|
||||
if ret.User != nil {
|
||||
// embedded username/password information is not permitted; credentials
|
||||
// are handled out of band.
|
||||
return nil
|
||||
}
|
||||
ret.Fragment = "" // fragment part is irrelevant, since we're not a browser
|
||||
|
||||
return h.discoURL.ResolveReference(ret)
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
package disco
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHostServiceURL(t *testing.T) {
|
||||
baseURL, _ := url.Parse("https://example.com/disco/foo.json")
|
||||
host := Host{
|
||||
discoURL: baseURL,
|
||||
services: map[string]interface{}{
|
||||
"absolute.v1": "http://example.net/foo/bar",
|
||||
"absolutewithport.v1": "http://example.net:8080/foo/bar",
|
||||
"relative.v1": "./stu/",
|
||||
"rootrelative.v1": "/baz",
|
||||
"protorelative.v1": "//example.net/",
|
||||
"withfragment.v1": "http://example.org/#foo",
|
||||
"querystring.v1": "https://example.net/baz?foo=bar",
|
||||
"nothttp.v1": "ftp://127.0.0.1/pub/",
|
||||
"invalid.v1": "***not A URL at all!:/<@@@@>***",
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
ID string
|
||||
Want string
|
||||
}{
|
||||
{"absolute.v1", "http://example.net/foo/bar"},
|
||||
{"absolutewithport.v1", "http://example.net:8080/foo/bar"},
|
||||
{"relative.v1", "https://example.com/disco/stu/"},
|
||||
{"rootrelative.v1", "https://example.com/baz"},
|
||||
{"protorelative.v1", "https://example.net/"},
|
||||
{"withfragment.v1", "http://example.org/"},
|
||||
{"querystring.v1", "https://example.net/baz?foo=bar"}, // most callers will disregard query string
|
||||
{"nothttp.v1", "<nil>"},
|
||||
{"invalid.v1", "<nil>"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.ID, func(t *testing.T) {
|
||||
url := host.ServiceURL(test.ID)
|
||||
var got string
|
||||
if url != nil {
|
||||
got = url.String()
|
||||
} else {
|
||||
got = "<nil>"
|
||||
}
|
||||
|
||||
if got != test.Want {
|
||||
t.Errorf("wrong result\ngot: %s\nwant: %s", got, test.Want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue