svchost/disco: lookup of service URLs within a discovered map

This package implements our Terraform-native service discovery protocol,
which allows us to find the base URL for a particular service given a
hostname that was already validated and normalized by the svchost package.
This commit is contained in:
Martin Atkins 2017-10-17 15:25:16 -07:00
parent db08ee4ac5
commit 6cd9a8f9c2
4 changed files with 538 additions and 0 deletions

177
svchost/disco/disco.go Normal file
View File

@ -0,0 +1,177 @@
// Package disco handles Terraform's remote service discovery protocol.
//
// This protocol allows mapping from a service hostname, as produced by the
// svchost package, to a set of services supported by that host and the
// endpoint information for each supported service.
package disco
import (
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"mime"
"net/http"
"net/url"
"time"
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/terraform/svchost"
"github.com/hashicorp/terraform/terraform"
)
// Fixed parameters of the discovery client.
const (
	// discoPath is the URL path requested on the host being discovered.
	discoPath = "/.well-known/terraform.json"

	// maxRedirects is an arbitrary-but-small number to prevent runaway
	// redirect loops.
	maxRedirects = 3

	// discoTimeout is an arbitrary-but-small time limit to prevent UI
	// "hangs" during discovery.
	discoTimeout = 4 * time.Second

	// maxDiscoDocBytes is 1MB, to prevent abusive services from using loads
	// of our memory.
	maxDiscoDocBytes = 1 << 20
)
var (
	// userAgent is the User-Agent header value sent with discovery requests;
	// it embeds the string returned by terraform.VersionString.
	userAgent = fmt.Sprintf("Terraform/%s (service discovery)", terraform.VersionString())

	// httpTransport is the transport used by the discovery HTTP client. It is
	// overridden during tests, to skip TLS verification.
	httpTransport = cleanhttp.DefaultPooledTransport()
)
// Disco is the entry point of this package: it performs discovery against
// hostnames and remembers each result, keyed by hostname, so that asking
// about the same host again does not repeat the request.
type Disco struct {
	// hostCache holds the result of each discovery performed so far. It is
	// created lazily by Discover and reset to nil by ForgetAll.
	hostCache map[svchost.Hostname]Host
}

// NewDisco returns a new Disco with nothing cached yet.
func NewDisco() *Disco {
	return new(Disco)
}
// Discover returns an object describing the services available at the given
// hostname, which must already have been validated and prepared with
// svchost.ForComparison. The discovery protocol is run the first time a
// hostname is seen; later calls for the same hostname return the cached
// result.
//
// A hostname that supports no Terraform services at all produces a non-nil
// but empty Host. End-user feedback deliberately does not distinguish that
// case from a single missing service: we say e.g. "the host <name> doesn't
// provide a module registry" either way, since we don't wish to expose the
// detail of whole-host discovery to an end-user.
func (d *Disco) Discover(host svchost.Hostname) Host {
	// Looking up a key in a nil map is safe, so the cache only needs to be
	// created once there is something to store in it.
	if known, ok := d.hostCache[host]; ok {
		return known
	}

	result := d.discover(host)
	if d.hostCache == nil {
		d.hostCache = make(map[svchost.Hostname]Host)
	}
	d.hostCache[host] = result
	return result
}
// DiscoverServiceURL is shorthand for calling Discover on the given hostname
// and then asking the resulting Host for the URL of the service with the
// given identifier.
func (d *Disco) DiscoverServiceURL(host svchost.Hostname, serviceID string) *url.URL {
	discovered := d.Discover(host)
	return discovered.ServiceURL(serviceID)
}
// discover implements the actual discovery process, with its result cached
// by the public-facing Discover method.
//
// It requests discoPath from the given host over HTTPS and decodes the
// response as a JSON object mapping service identifiers to service
// information. The returned Host always carries the URL the document was
// requested from (or redirected to, once a 200 response has been received).
// Its services are populated only if the response has status 200, an
// application/json media type, a declared length no larger than
// maxDiscoDocBytes and a body that decodes as a JSON object. Every failure
// is logged and yields an empty Host rather than an error.
func (d *Disco) discover(host svchost.Hostname) Host {
	discoURL := &url.URL{
		Scheme: "https",
		Host:   string(host),
		Path:   discoPath,
	}
	client := &http.Client{
		Transport: httpTransport,
		Timeout:   discoTimeout,

		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			log.Printf("[DEBUG] Service discovery redirected to %s", req.URL)
			if len(via) > maxRedirects {
				return errors.New("too many redirects") // (this error message will never actually be seen)
			}
			return nil
		},
	}

	var header = http.Header{}
	header.Set("User-Agent", userAgent)
	// TODO: look up credentials and add them to the header if we have them

	req := &http.Request{
		Method: "GET",
		URL:    discoURL,
		Header: header,
	}

	log.Printf("[DEBUG] Service discovery for %s at %s", host, discoURL)

	ret := Host{
		discoURL: discoURL,
	}

	resp, err := client.Do(req)
	if err != nil {
		log.Printf("[WARNING] Failed to request discovery document: %s", err)
		return ret // empty
	}
	// The caller of client.Do owns the response body; close it on every
	// return path below so the underlying connection is released.
	defer resp.Body.Close()

	if resp.StatusCode != 200 {
		log.Printf("[WARNING] Failed to request discovery document: %s", resp.Status)
		return ret // empty
	}

	// If the client followed any redirects, we will have a new URL to use
	// as our base for relative resolution.
	ret.discoURL = resp.Request.URL

	contentType := resp.Header.Get("Content-Type")
	mediaType, _, err := mime.ParseMediaType(contentType)
	if err != nil {
		log.Printf("[WARNING] Discovery URL has malformed Content-Type %q", contentType)
		return ret // empty
	}
	if mediaType != "application/json" {
		log.Printf("[DEBUG] Discovery URL returned Content-Type %q, rather than application/json", mediaType)
		return ret // empty
	}

	// (this doesn't catch chunked encoding, because ContentLength is -1 in that case...)
	if resp.ContentLength > maxDiscoDocBytes {
		// Size limit here is not a contractual requirement and so we may
		// adjust it over time if we find a different limit is warranted.
		log.Printf("[WARNING] Discovery doc response is too large (got %d bytes; limit %d)", resp.ContentLength, maxDiscoDocBytes)
		return ret // empty
	}

	// If the response is using chunked encoding then we can't predict
	// its size, but we'll at least prevent reading the entire thing into
	// memory. A body cut short by this limit is left to fail JSON decoding
	// below.
	lr := io.LimitReader(resp.Body, maxDiscoDocBytes)

	servicesBytes, err := ioutil.ReadAll(lr)
	if err != nil {
		log.Printf("[WARNING] Error reading discovery document body: %s", err)
		return ret // empty
	}

	var services map[string]interface{}
	err = json.Unmarshal(servicesBytes, &services)
	if err != nil {
		log.Printf("[WARNING] Failed to decode discovery document as a JSON object: %s", err)
		return ret // empty
	}

	ret.services = services
	return ret
}
// Forget drops the cached discovery result for the given hostname, if there
// is one, so that the next Discover call for it runs discovery again. It is
// a no-op for hostnames that have no cache entry.
func (d *Disco) Forget(host svchost.Hostname) {
	if len(d.hostCache) == 0 {
		return // nothing is cached, so there is nothing to invalidate
	}
	delete(d.hostCache, host)
}
// ForgetAll drops every cached discovery result at once, as if Forget had
// been called for each hostname that has a cache entry.
func (d *Disco) ForgetAll() {
	// Discover re-creates the map on demand, so nil is a valid empty cache.
	d.hostCache = nil
}

255
svchost/disco/disco_test.go Normal file
View File

@ -0,0 +1,255 @@
package disco
import (
"crypto/tls"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strconv"
"testing"
"github.com/hashicorp/terraform/svchost"
)
// TestMain replaces the HTTP transport used for discovery, for the duration
// of all tests, with one that skips TLS certificate verification, so that
// it tolerates the locally-generated certificates used for test URLs.
func TestMain(m *testing.M) {
	insecure := &tls.Config{InsecureSkipVerify: true}
	httpTransport = &http.Transport{TLSClientConfig: insecure}

	exitCode := m.Run()
	os.Exit(exitCode)
}
func TestDiscover(t *testing.T) {
t.Run("happy path", func(t *testing.T) {
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
resp := []byte(`
{
"thingy.v1": "http://example.com/foo",
"wotsit.v2": "http://example.net/bar"
}
`)
w.Header().Add("Content-Type", "application/json")
w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
w.Write(resp)
})
defer close()
givenHost := "localhost" + portStr
host, err := svchost.ForComparison(givenHost)
if err != nil {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := NewDisco()
discovered := d.Discover(host)
gotURL := discovered.ServiceURL("thingy.v1")
if gotURL == nil {
t.Fatalf("found no URL for thingy.v1")
}
if got, want := gotURL.String(), "http://example.com/foo"; got != want {
t.Fatalf("wrong result %q; want %q", got, want)
}
})
t.Run("chunked encoding", func(t *testing.T) {
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
resp := []byte(`
{
"thingy.v1": "http://example.com/foo",
"wotsit.v2": "http://example.net/bar"
}
`)
w.Header().Add("Content-Type", "application/json")
// We're going to force chunked encoding here -- and thus prevent
// the server from predicting the length -- so we can make sure
// our client is tolerant of servers using this encoding.
w.Write(resp[:5])
w.(http.Flusher).Flush()
w.Write(resp[5:])
w.(http.Flusher).Flush()
})
defer close()
givenHost := "localhost" + portStr
host, err := svchost.ForComparison(givenHost)
if err != nil {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := NewDisco()
discovered := d.Discover(host)
gotURL := discovered.ServiceURL("wotsit.v2")
if gotURL == nil {
t.Fatalf("found no URL for wotsit.v2")
}
if got, want := gotURL.String(), "http://example.net/bar"; got != want {
t.Fatalf("wrong result %q; want %q", got, want)
}
})
t.Run("not JSON", func(t *testing.T) {
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
w.Header().Add("Content-Type", "application/octet-stream")
w.Write(resp)
})
defer close()
givenHost := "localhost" + portStr
host, err := svchost.ForComparison(givenHost)
if err != nil {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := NewDisco()
discovered := d.Discover(host)
// result should be empty, which we can verify only by reaching into
// its internals.
if discovered.services != nil {
t.Errorf("response not empty; should be")
}
})
t.Run("malformed JSON", func(t *testing.T) {
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
resp := []byte(`{"thingy.v1": "htt`) // truncated, for example...
w.Header().Add("Content-Type", "application/json")
w.Write(resp)
})
defer close()
givenHost := "localhost" + portStr
host, err := svchost.ForComparison(givenHost)
if err != nil {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := NewDisco()
discovered := d.Discover(host)
// result should be empty, which we can verify only by reaching into
// its internals.
if discovered.services != nil {
t.Errorf("response not empty; should be")
}
})
t.Run("JSON with redundant charset", func(t *testing.T) {
// The JSON RFC defines no parameters for the application/json
// MIME type, but some servers have a weird tendency to just add
// "charset" to everything, so we'll make sure we ignore it successfully.
// (JSON uses content sniffing for encoding detection, not media type params.)
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
w.Header().Add("Content-Type", "application/json; charset=latin-1")
w.Write(resp)
})
defer close()
givenHost := "localhost" + portStr
host, err := svchost.ForComparison(givenHost)
if err != nil {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := NewDisco()
discovered := d.Discover(host)
if discovered.services == nil {
t.Errorf("response is empty; shouldn't be")
}
})
t.Run("no discovery doc", func(t *testing.T) {
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
})
defer close()
givenHost := "localhost" + portStr
host, err := svchost.ForComparison(givenHost)
if err != nil {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := NewDisco()
discovered := d.Discover(host)
// result should be empty, which we can verify only by reaching into
// its internals.
if discovered.services != nil {
t.Errorf("response not empty; should be")
}
})
t.Run("redirect", func(t *testing.T) {
// For this test, we have two servers and one redirects to the other
portStr1, close1 := testServer(func(w http.ResponseWriter, r *http.Request) {
// This server is the one that returns a real response.
resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
w.Header().Add("Content-Type", "application/json")
w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
w.Write(resp)
})
portStr2, close2 := testServer(func(w http.ResponseWriter, r *http.Request) {
// This server is the one that redirects.
http.Redirect(w, r, "https://127.0.0.1"+portStr1+"/.well-known/terraform.json", 302)
})
defer close1()
defer close2()
givenHost := "localhost" + portStr2
host, err := svchost.ForComparison(givenHost)
if err != nil {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := NewDisco()
discovered := d.Discover(host)
gotURL := discovered.ServiceURL("thingy.v1")
if gotURL == nil {
t.Fatalf("found no URL for thingy.v1")
}
if got, want := gotURL.String(), "http://example.com/foo"; got != want {
t.Fatalf("wrong result %q; want %q", got, want)
}
// The base URL for the host object should be the URL we redirected to,
// rather than the we redirected _from_.
gotBaseURL := discovered.discoURL.String()
wantBaseURL := "https://127.0.0.1" + portStr1 + "/.well-known/terraform.json"
if gotBaseURL != wantBaseURL {
t.Errorf("incorrect base url %s; want %s", gotBaseURL, wantBaseURL)
}
})
}
// testServer starts a TLS test server that passes requests for the discovery
// document path to h and answers every other path with a 404 "not found"
// response.
//
// It returns the server's port as a ":port" suffix (or an empty string if
// the server URL carries no port), along with a function that shuts the
// server down.
func testServer(h func(w http.ResponseWriter, r *http.Request)) (portStr string, close func()) {
	router := func(w http.ResponseWriter, r *http.Request) {
		// If the URL is correct then the given handler decides the response.
		if r.URL.Path == "/.well-known/terraform.json" {
			h(w, r)
			return
		}

		// Test server always returns 404 if the URL isn't what we expect.
		w.WriteHeader(404)
		w.Write([]byte("not found"))
	}
	server := httptest.NewTLSServer(http.HandlerFunc(router))

	serverURL, _ := url.Parse(server.URL)
	if port := serverURL.Port(); port != "" {
		portStr = ":" + port
	}

	return portStr, server.Close
}

51
svchost/disco/host.go Normal file
View File

@ -0,0 +1,51 @@
package disco
import (
"net/url"
)
// Host holds the outcome of discovery for one hostname: the URL that
// relative service URLs are resolved against, and the services found in the
// discovery document. A Host whose services are nil is "empty" and has no
// URL for any service.
type Host struct {
	// discoURL is the base used to make relative service URLs absolute.
	discoURL *url.URL

	// services maps service identifiers to their entries in the discovery
	// document; only string entries are usable as service URLs.
	services map[string]interface{}
}

// ServiceURL looks up the URL for the given service identifier, which should
// be of the form "servicename.vN".
//
// A non-nil result is always an absolute URL whose scheme is either https or
// http, with any fragment removed. Relative entries are resolved against the
// URL of the discovery document.
//
// The result is nil if the host does not support the requested service, or
// if its entry is invalid: not a string, not parseable as a URL, using some
// other scheme, or carrying embedded user information.
func (h Host) ServiceURL(id string) *url.URL {
	if h.services == nil {
		return nil // an empty Host supports no services at all
	}

	rawURL, isString := h.services[id].(string)
	if !isString {
		return nil
	}

	parsed, err := url.Parse(rawURL)
	if err != nil {
		return nil
	}
	if !parsed.IsAbs() {
		// make absolute using our discovery doc URL
		parsed = h.discoURL.ResolveReference(parsed)
	}

	switch parsed.Scheme {
	case "https", "http":
		// the only schemes we accept
	default:
		return nil
	}

	// Embedded username/password information is not permitted; credentials
	// are handled out of band.
	if parsed.User != nil {
		return nil
	}

	// The fragment part is irrelevant, since we're not a browser.
	parsed.Fragment = ""

	return h.discoURL.ResolveReference(parsed)
}

View File

@ -0,0 +1,55 @@
package disco
import (
"net/url"
"testing"
)
// TestHostServiceURL checks how Host.ServiceURL treats a range of discovery
// document entries: absolute URLs, the various relative forms, fragments and
// query strings, and entries that must be rejected.
func TestHostServiceURL(t *testing.T) {
	baseURL, _ := url.Parse("https://example.com/disco/foo.json")

	// Each case gives a service ID, its raw entry in the discovery document,
	// and the expected result, where "<nil>" stands for a nil URL.
	cases := []struct {
		ID   string
		Raw  string
		Want string
	}{
		{"absolute.v1", "http://example.net/foo/bar", "http://example.net/foo/bar"},
		{"absolutewithport.v1", "http://example.net:8080/foo/bar", "http://example.net:8080/foo/bar"},
		{"relative.v1", "./stu/", "https://example.com/disco/stu/"},
		{"rootrelative.v1", "/baz", "https://example.com/baz"},
		{"protorelative.v1", "//example.net/", "https://example.net/"},
		{"withfragment.v1", "http://example.org/#foo", "http://example.org/"},
		// most callers will disregard query string
		{"querystring.v1", "https://example.net/baz?foo=bar", "https://example.net/baz?foo=bar"},
		{"nothttp.v1", "ftp://127.0.0.1/pub/", "<nil>"},
		{"invalid.v1", "***not A URL at all!:/<@@@@>***", "<nil>"},
	}

	services := make(map[string]interface{}, len(cases))
	for _, c := range cases {
		services[c.ID] = c.Raw
	}
	host := Host{
		discoURL: baseURL,
		services: services,
	}

	for _, c := range cases {
		c := c // keep a per-iteration copy for the closure below
		t.Run(c.ID, func(t *testing.T) {
			got := "<nil>"
			if serviceURL := host.ServiceURL(c.ID); serviceURL != nil {
				got = serviceURL.String()
			}
			if got != c.Want {
				t.Errorf("wrong result\ngot: %s\nwant: %s", got, c.Want)
			}
		})
	}
}