Merge pull request #19589 from hashicorp/svh/f-service-discovery
core: enhance service discovery
This commit is contained in:
commit
ecc93c9889
|
@ -302,9 +302,9 @@ func (b *Remote) discover(hostname string) (*url.URL, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
service := b.services.DiscoverServiceURL(host, serviceID)
|
||||
if service == nil {
|
||||
return nil, fmt.Errorf("host %s does not provide a remote backend API", host)
|
||||
service, err := b.services.DiscoverServiceURL(host, serviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return service, nil
|
||||
}
|
||||
|
|
|
@ -56,7 +56,7 @@ func TestRemote_config(t *testing.T) {
|
|||
"prefix": cty.NullVal(cty.String),
|
||||
}),
|
||||
}),
|
||||
confErr: "Host nonexisting.local does not provide a remote backend API",
|
||||
confErr: "Failed to request discovery document",
|
||||
},
|
||||
"with_a_name": {
|
||||
config: cty.ObjectVal(map[string]cty.Value{
|
||||
|
@ -112,8 +112,8 @@ func TestRemote_config(t *testing.T) {
|
|||
|
||||
// Validate
|
||||
valDiags := b.ValidateConfig(tc.config)
|
||||
if (valDiags.Err() == nil && tc.valErr != "") ||
|
||||
(valDiags.Err() != nil && !strings.Contains(valDiags.Err().Error(), tc.valErr)) {
|
||||
if (valDiags.Err() != nil || tc.valErr != "") &&
|
||||
(valDiags.Err() == nil || !strings.Contains(valDiags.Err().Error(), tc.valErr)) {
|
||||
t.Fatalf("%s: unexpected validation result: %v", name, valDiags.Err())
|
||||
}
|
||||
|
||||
|
|
|
@ -59,15 +59,15 @@ func NewClient(services *disco.Disco, client *http.Client) *Client {
|
|||
}
|
||||
|
||||
// Discover queries the host, and returns the url for the registry.
|
||||
func (c *Client) Discover(host svchost.Hostname, serviceID string) *url.URL {
|
||||
service := c.services.DiscoverServiceURL(host, serviceID)
|
||||
if service == nil {
|
||||
return nil
|
||||
func (c *Client) Discover(host svchost.Hostname, serviceID string) (*url.URL, error) {
|
||||
service, err := c.services.DiscoverServiceURL(host, serviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !strings.HasSuffix(service.Path, "/") {
|
||||
service.Path += "/"
|
||||
}
|
||||
return service
|
||||
return service, nil
|
||||
}
|
||||
|
||||
// ModuleVersions queries the registry for a module, and returns the available versions.
|
||||
|
@ -77,9 +77,9 @@ func (c *Client) ModuleVersions(module *regsrc.Module) (*response.ModuleVersions
|
|||
return nil, err
|
||||
}
|
||||
|
||||
service := c.Discover(host, modulesServiceID)
|
||||
if service == nil {
|
||||
return nil, &errServiceNotProvided{host: host.ForDisplay(), service: "modules"}
|
||||
service, err := c.Discover(host, modulesServiceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p, err := url.Parse(path.Join(module.Module(), "versions"))
|
||||
|
@ -150,9 +150,9 @@ func (c *Client) ModuleLocation(module *regsrc.Module, version string) (string,
|
|||
return "", err
|
||||
}
|
||||
|
||||
service := c.Discover(host, modulesServiceID)
|
||||
if service == nil {
|
||||
return "", &errServiceNotProvided{host: host.ForDisplay(), service: "modules"}
|
||||
service, err := c.Discover(host, modulesServiceID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var p *url.URL
|
||||
|
@ -234,9 +234,9 @@ func (c *Client) TerraformProviderVersions(provider *regsrc.TerraformProvider) (
|
|||
return nil, err
|
||||
}
|
||||
|
||||
service := c.Discover(host, providersServiceID)
|
||||
if service == nil {
|
||||
return nil, &errServiceNotProvided{host: host.ForDisplay(), service: "providers"}
|
||||
service, err := c.Discover(host, providersServiceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p, err := url.Parse(path.Join(provider.TerraformProvider(), "versions"))
|
||||
|
@ -288,9 +288,9 @@ func (c *Client) TerraformProviderLocation(provider *regsrc.TerraformProvider, v
|
|||
return nil, err
|
||||
}
|
||||
|
||||
service := c.Discover(host, providersServiceID)
|
||||
if service == nil {
|
||||
return nil, &errServiceNotProvided{host: host.ForDisplay(), service: "providers"}
|
||||
service, err := c.Discover(host, providersServiceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p, err := url.Parse(path.Join(
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/hashicorp/terraform/registry/regsrc"
|
||||
"github.com/hashicorp/terraform/svchost/disco"
|
||||
)
|
||||
|
||||
type errModuleNotFound struct {
|
||||
|
@ -42,15 +43,6 @@ func IsProviderNotFound(err error) bool {
|
|||
// error. This allows callers to recognize this particular error condition
|
||||
// as distinct from operational errors such as poor network connectivity.
|
||||
func IsServiceNotProvided(err error) bool {
|
||||
_, ok := err.(*errServiceNotProvided)
|
||||
_, ok := err.(*disco.ErrServiceNotProvided)
|
||||
return ok
|
||||
}
|
||||
|
||||
type errServiceNotProvided struct {
|
||||
host string
|
||||
service string
|
||||
}
|
||||
|
||||
func (e *errServiceNotProvided) Error() string {
|
||||
return fmt.Sprintf("host %s does not provide %s", e.host, e.service)
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ package disco
|
|||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
|
@ -22,19 +23,27 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
// Fixed path to the discovery manifest.
|
||||
discoPath = "/.well-known/terraform.json"
|
||||
maxRedirects = 3 // arbitrary-but-small number to prevent runaway redirect loops
|
||||
discoTimeout = 11 * time.Second // arbitrary-but-small time limit to prevent UI "hangs" during discovery
|
||||
maxDiscoDocBytes = 1 * 1024 * 1024 // 1MB - to prevent abusive services from using loads of our memory
|
||||
|
||||
// Arbitrary-but-small number to prevent runaway redirect loops.
|
||||
maxRedirects = 3
|
||||
|
||||
// Arbitrary-but-small time limit to prevent UI "hangs" during discovery.
|
||||
discoTimeout = 11 * time.Second
|
||||
|
||||
// 1MB - to prevent abusive services from using loads of our memory.
|
||||
maxDiscoDocBytes = 1 * 1024 * 1024
|
||||
)
|
||||
|
||||
var httpTransport = cleanhttp.DefaultPooledTransport() // overridden during tests, to skip TLS verification
|
||||
// httpTransport is overridden during tests, to skip TLS verification.
|
||||
var httpTransport = cleanhttp.DefaultPooledTransport()
|
||||
|
||||
// Disco is the main type in this package, which allows discovery on given
|
||||
// hostnames and caches the results by hostname to avoid repeated requests
|
||||
// for the same information.
|
||||
type Disco struct {
|
||||
hostCache map[svchost.Hostname]Host
|
||||
hostCache map[svchost.Hostname]*Host
|
||||
credsSrc auth.CredentialsSource
|
||||
|
||||
// Transport is a custom http.RoundTripper to use.
|
||||
|
@ -50,7 +59,10 @@ func New() *Disco {
|
|||
// NewWithCredentialsSource returns a new discovery object initialized with
|
||||
// the given credentials source.
|
||||
func NewWithCredentialsSource(credsSrc auth.CredentialsSource) *Disco {
|
||||
return &Disco{credsSrc: credsSrc}
|
||||
return &Disco{
|
||||
hostCache: make(map[svchost.Hostname]*Host),
|
||||
credsSrc: credsSrc,
|
||||
}
|
||||
}
|
||||
|
||||
// SetCredentialsSource provides a credentials source that will be used to
|
||||
|
@ -64,11 +76,11 @@ func (d *Disco) SetCredentialsSource(src auth.CredentialsSource) {
|
|||
|
||||
// CredentialsForHost returns a non-nil HostCredentials if the embedded source has
|
||||
// credentials available for the host, and a nil HostCredentials if it does not.
|
||||
func (d *Disco) CredentialsForHost(host svchost.Hostname) (auth.HostCredentials, error) {
|
||||
func (d *Disco) CredentialsForHost(hostname svchost.Hostname) (auth.HostCredentials, error) {
|
||||
if d.credsSrc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return d.credsSrc.ForHost(host)
|
||||
return d.credsSrc.ForHost(hostname)
|
||||
}
|
||||
|
||||
// ForceHostServices provides a pre-defined set of services for a given
|
||||
|
@ -81,19 +93,17 @@ func (d *Disco) CredentialsForHost(host svchost.Hostname) (auth.HostCredentials,
|
|||
// discovery, yielding the same results as if the given map were published
|
||||
// at the host's default discovery URL, though using absolute URLs is strongly
|
||||
// recommended to make the configured behavior more explicit.
|
||||
func (d *Disco) ForceHostServices(host svchost.Hostname, services map[string]interface{}) {
|
||||
if d.hostCache == nil {
|
||||
d.hostCache = map[svchost.Hostname]Host{}
|
||||
}
|
||||
func (d *Disco) ForceHostServices(hostname svchost.Hostname, services map[string]interface{}) {
|
||||
if services == nil {
|
||||
services = map[string]interface{}{}
|
||||
}
|
||||
d.hostCache[host] = Host{
|
||||
d.hostCache[hostname] = &Host{
|
||||
discoURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: string(host),
|
||||
Host: string(hostname),
|
||||
Path: discoPath,
|
||||
},
|
||||
hostname: hostname.ForDisplay(),
|
||||
services: services,
|
||||
}
|
||||
}
|
||||
|
@ -104,36 +114,40 @@ func (d *Disco) ForceHostServices(host svchost.Hostname, services map[string]int
|
|||
//
|
||||
// If a given hostname supports no Terraform services at all, a non-nil but
|
||||
// empty Host object is returned. When giving feedback to the end user about
|
||||
// such situations, we say e.g. "the host <name> doesn't provide a module
|
||||
// registry", regardless of whether that is due to that service specifically
|
||||
// being absent or due to the host not providing Terraform services at all,
|
||||
// since we don't wish to expose the detail of whole-host discovery to an
|
||||
// end-user.
|
||||
func (d *Disco) Discover(host svchost.Hostname) Host {
|
||||
if d.hostCache == nil {
|
||||
d.hostCache = map[svchost.Hostname]Host{}
|
||||
}
|
||||
if cache, cached := d.hostCache[host]; cached {
|
||||
return cache
|
||||
// such situations, we say "host <name> does not provide a <service> service",
|
||||
// regardless of whether that is due to that service specifically being absent
|
||||
// or due to the host not providing Terraform services at all, since we don't
|
||||
// wish to expose the detail of whole-host discovery to an end-user.
|
||||
func (d *Disco) Discover(hostname svchost.Hostname) (*Host, error) {
|
||||
if host, cached := d.hostCache[hostname]; cached {
|
||||
return host, nil
|
||||
}
|
||||
|
||||
ret := d.discover(host)
|
||||
d.hostCache[host] = ret
|
||||
return ret
|
||||
host, err := d.discover(hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.hostCache[hostname] = host
|
||||
|
||||
return host, nil
|
||||
}
|
||||
|
||||
// DiscoverServiceURL is a convenience wrapper for discovery on a given
|
||||
// hostname and then looking up a particular service in the result.
|
||||
func (d *Disco) DiscoverServiceURL(host svchost.Hostname, serviceID string) *url.URL {
|
||||
return d.Discover(host).ServiceURL(serviceID)
|
||||
func (d *Disco) DiscoverServiceURL(hostname svchost.Hostname, serviceID string) (*url.URL, error) {
|
||||
host, err := d.Discover(hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return host.ServiceURL(serviceID)
|
||||
}
|
||||
|
||||
// discover implements the actual discovery process, with its result cached
|
||||
// by the public-facing Discover method.
|
||||
func (d *Disco) discover(host svchost.Hostname) Host {
|
||||
func (d *Disco) discover(hostname svchost.Hostname) (*Host, error) {
|
||||
discoURL := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: host.String(),
|
||||
Host: hostname.String(),
|
||||
Path: discoPath,
|
||||
}
|
||||
|
||||
|
@ -149,7 +163,7 @@ func (d *Disco) discover(host svchost.Hostname) Host {
|
|||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
log.Printf("[DEBUG] Service discovery redirected to %s", req.URL)
|
||||
if len(via) > maxRedirects {
|
||||
return errors.New("too many redirects") // (this error message will never actually be seen)
|
||||
return errors.New("too many redirects") // this error will never actually be seen
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
@ -160,82 +174,84 @@ func (d *Disco) discover(host svchost.Hostname) Host {
|
|||
URL: discoURL,
|
||||
}
|
||||
|
||||
if creds, err := d.CredentialsForHost(host); err != nil {
|
||||
log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", host, err)
|
||||
} else if creds != nil {
|
||||
creds.PrepareRequest(req) // alters req to include credentials
|
||||
creds, err := d.CredentialsForHost(hostname)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", hostname, err)
|
||||
}
|
||||
if creds != nil {
|
||||
// Update the request to include credentials.
|
||||
creds.PrepareRequest(req)
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG] Service discovery for %s at %s", host, discoURL)
|
||||
|
||||
ret := Host{
|
||||
discoURL: discoURL,
|
||||
}
|
||||
log.Printf("[DEBUG] Service discovery for %s at %s", hostname, discoURL)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Failed to request discovery document: %s", err)
|
||||
return ret // empty
|
||||
return nil, fmt.Errorf("Failed to request discovery document: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
log.Printf("[WARN] Failed to request discovery document: %s", resp.Status)
|
||||
return ret // empty
|
||||
host := &Host{
|
||||
// Use the discovery URL from resp.Request in
|
||||
// case the client followed any redirects.
|
||||
discoURL: resp.Request.URL,
|
||||
hostname: hostname.ForDisplay(),
|
||||
}
|
||||
|
||||
// If the client followed any redirects, we will have a new URL to use
|
||||
// as our base for relative resolution.
|
||||
ret.discoURL = resp.Request.URL
|
||||
// Return the host without any services.
|
||||
if resp.StatusCode == 404 {
|
||||
return host, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("Failed to request discovery document: %s", resp.Status)
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Discovery URL has malformed Content-Type %q", contentType)
|
||||
return ret // empty
|
||||
return nil, fmt.Errorf("Discovery URL has a malformed Content-Type %q", contentType)
|
||||
}
|
||||
if mediaType != "application/json" {
|
||||
log.Printf("[DEBUG] Discovery URL returned Content-Type %q, rather than application/json", mediaType)
|
||||
return ret // empty
|
||||
return nil, fmt.Errorf("Discovery URL returned an unsupported Content-Type %q", mediaType)
|
||||
}
|
||||
|
||||
// (this doesn't catch chunked encoding, because ContentLength is -1 in that case...)
|
||||
// This doesn't catch chunked encoding, because ContentLength is -1 in that case.
|
||||
if resp.ContentLength > maxDiscoDocBytes {
|
||||
// Size limit here is not a contractual requirement and so we may
|
||||
// adjust it over time if we find a different limit is warranted.
|
||||
log.Printf("[WARN] Discovery doc response is too large (got %d bytes; limit %d)", resp.ContentLength, maxDiscoDocBytes)
|
||||
return ret // empty
|
||||
return nil, fmt.Errorf(
|
||||
"Discovery doc response is too large (got %d bytes; limit %d)",
|
||||
resp.ContentLength, maxDiscoDocBytes,
|
||||
)
|
||||
}
|
||||
|
||||
// If the response is using chunked encoding then we can't predict
|
||||
// its size, but we'll at least prevent reading the entire thing into
|
||||
// memory.
|
||||
// If the response is using chunked encoding then we can't predict its
|
||||
// size, but we'll at least prevent reading the entire thing into memory.
|
||||
lr := io.LimitReader(resp.Body, maxDiscoDocBytes)
|
||||
|
||||
servicesBytes, err := ioutil.ReadAll(lr)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Error reading discovery document body: %s", err)
|
||||
return ret // empty
|
||||
return nil, fmt.Errorf("Error reading discovery document body: %v", err)
|
||||
}
|
||||
|
||||
var services map[string]interface{}
|
||||
err = json.Unmarshal(servicesBytes, &services)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Failed to decode discovery document as a JSON object: %s", err)
|
||||
return ret // empty
|
||||
return nil, fmt.Errorf("Failed to decode discovery document as a JSON object: %v", err)
|
||||
}
|
||||
host.services = services
|
||||
|
||||
ret.services = services
|
||||
return ret
|
||||
return host, nil
|
||||
}
|
||||
|
||||
// Forget invalidates any cached record of the given hostname. If the host
|
||||
// has no cache entry then this is a no-op.
|
||||
func (d *Disco) Forget(host svchost.Hostname) {
|
||||
delete(d.hostCache, host)
|
||||
func (d *Disco) Forget(hostname svchost.Hostname) {
|
||||
delete(d.hostCache, hostname)
|
||||
}
|
||||
|
||||
// ForgetAll is like Forget, but for all of the hostnames that have cache entries.
|
||||
func (d *Disco) ForgetAll() {
|
||||
d.hostCache = nil
|
||||
d.hostCache = make(map[svchost.Hostname]*Host)
|
||||
}
|
||||
|
|
|
@ -46,8 +46,15 @@ func TestDiscover(t *testing.T) {
|
|||
}
|
||||
|
||||
d := New()
|
||||
discovered := d.Discover(host)
|
||||
gotURL := discovered.ServiceURL("thingy.v1")
|
||||
discovered, err := d.Discover(host)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected discovery error: %s", err)
|
||||
}
|
||||
|
||||
gotURL, err := discovered.ServiceURL("thingy.v1")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected service URL error: %s", err)
|
||||
}
|
||||
if gotURL == nil {
|
||||
t.Fatalf("found no URL for thingy.v1")
|
||||
}
|
||||
|
@ -81,8 +88,15 @@ func TestDiscover(t *testing.T) {
|
|||
}
|
||||
|
||||
d := New()
|
||||
discovered := d.Discover(host)
|
||||
gotURL := discovered.ServiceURL("wotsit.v2")
|
||||
discovered, err := d.Discover(host)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected discovery error: %s", err)
|
||||
}
|
||||
|
||||
gotURL, err := discovered.ServiceURL("wotsit.v2")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected service URL error: %s", err)
|
||||
}
|
||||
if gotURL == nil {
|
||||
t.Fatalf("found no URL for wotsit.v2")
|
||||
}
|
||||
|
@ -133,9 +147,15 @@ func TestDiscover(t *testing.T) {
|
|||
t.Fatalf("test server hostname is invalid: %s", err)
|
||||
}
|
||||
|
||||
discovered := d.Discover(host)
|
||||
discovered, err := d.Discover(host)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected discovery error: %s", err)
|
||||
}
|
||||
{
|
||||
gotURL := discovered.ServiceURL("thingy.v1")
|
||||
gotURL, err := discovered.ServiceURL("thingy.v1")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected service URL error: %s", err)
|
||||
}
|
||||
if gotURL == nil {
|
||||
t.Fatalf("found no URL for thingy.v1")
|
||||
}
|
||||
|
@ -144,7 +164,10 @@ func TestDiscover(t *testing.T) {
|
|||
}
|
||||
}
|
||||
{
|
||||
gotURL := discovered.ServiceURL("wotsit.v2")
|
||||
gotURL, err := discovered.ServiceURL("wotsit.v2")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected service URL error: %s", err)
|
||||
}
|
||||
if gotURL == nil {
|
||||
t.Fatalf("found no URL for wotsit.v2")
|
||||
}
|
||||
|
@ -168,12 +191,14 @@ func TestDiscover(t *testing.T) {
|
|||
}
|
||||
|
||||
d := New()
|
||||
discovered := d.Discover(host)
|
||||
discovered, err := d.Discover(host)
|
||||
if err == nil {
|
||||
t.Fatalf("expected a discovery error")
|
||||
}
|
||||
|
||||
// result should be empty, which we can verify only by reaching into
|
||||
// its internals.
|
||||
if discovered.services != nil {
|
||||
t.Errorf("response not empty; should be")
|
||||
// Returned discovered should be nil.
|
||||
if discovered != nil {
|
||||
t.Errorf("discovered not nil; should be")
|
||||
}
|
||||
})
|
||||
t.Run("malformed JSON", func(t *testing.T) {
|
||||
|
@ -191,12 +216,14 @@ func TestDiscover(t *testing.T) {
|
|||
}
|
||||
|
||||
d := New()
|
||||
discovered := d.Discover(host)
|
||||
discovered, err := d.Discover(host)
|
||||
if err == nil {
|
||||
t.Fatalf("expected a discovery error")
|
||||
}
|
||||
|
||||
// result should be empty, which we can verify only by reaching into
|
||||
// its internals.
|
||||
if discovered.services != nil {
|
||||
t.Errorf("response not empty; should be")
|
||||
// Returned discovered should be nil.
|
||||
if discovered != nil {
|
||||
t.Errorf("discovered not nil; should be")
|
||||
}
|
||||
})
|
||||
t.Run("JSON with redundant charset", func(t *testing.T) {
|
||||
|
@ -218,7 +245,10 @@ func TestDiscover(t *testing.T) {
|
|||
}
|
||||
|
||||
d := New()
|
||||
discovered := d.Discover(host)
|
||||
discovered, err := d.Discover(host)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected discovery error: %s", err)
|
||||
}
|
||||
|
||||
if discovered.services == nil {
|
||||
t.Errorf("response is empty; shouldn't be")
|
||||
|
@ -237,12 +267,14 @@ func TestDiscover(t *testing.T) {
|
|||
}
|
||||
|
||||
d := New()
|
||||
discovered := d.Discover(host)
|
||||
discovered, err := d.Discover(host)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected discovery error: %s", err)
|
||||
}
|
||||
|
||||
// result should be empty, which we can verify only by reaching into
|
||||
// its internals.
|
||||
// Returned discovered.services should be nil (empty).
|
||||
if discovered.services != nil {
|
||||
t.Errorf("response not empty; should be")
|
||||
t.Errorf("discovered.services not nil (empty); should be")
|
||||
}
|
||||
})
|
||||
t.Run("redirect", func(t *testing.T) {
|
||||
|
@ -268,9 +300,15 @@ func TestDiscover(t *testing.T) {
|
|||
}
|
||||
|
||||
d := New()
|
||||
discovered := d.Discover(host)
|
||||
discovered, err := d.Discover(host)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected discovery error: %s", err)
|
||||
}
|
||||
|
||||
gotURL := discovered.ServiceURL("thingy.v1")
|
||||
gotURL, err := discovered.ServiceURL("thingy.v1")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected service URL error: %s", err)
|
||||
}
|
||||
if gotURL == nil {
|
||||
t.Fatalf("found no URL for thingy.v1")
|
||||
}
|
||||
|
|
|
@ -1,51 +1,95 @@
|
|||
package disco
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Host represents a service discovered host.
|
||||
type Host struct {
|
||||
discoURL *url.URL
|
||||
hostname string
|
||||
services map[string]interface{}
|
||||
}
|
||||
|
||||
// ErrServiceNotProvided is returned when the service is not provided.
|
||||
type ErrServiceNotProvided struct {
|
||||
hostname string
|
||||
service string
|
||||
}
|
||||
|
||||
// Error returns a customized error message.
|
||||
func (e *ErrServiceNotProvided) Error() string {
|
||||
return fmt.Sprintf("host %s does not provide a %s service", e.hostname, e.service)
|
||||
}
|
||||
|
||||
// ErrVersionNotSupported is returned when the version is not supported.
|
||||
type ErrVersionNotSupported struct {
|
||||
hostname string
|
||||
service string
|
||||
version string
|
||||
}
|
||||
|
||||
// Error returns a customized error message.
|
||||
func (e *ErrVersionNotSupported) Error() string {
|
||||
return fmt.Sprintf("host %s does not support %s version %s", e.hostname, e.service, e.version)
|
||||
}
|
||||
|
||||
// ServiceURL returns the URL associated with the given service identifier,
|
||||
// which should be of the form "servicename.vN".
|
||||
//
|
||||
// A non-nil result is always an absolute URL with a scheme of either https
|
||||
// or http.
|
||||
//
|
||||
// If the requested service is not supported by the host, this method returns
|
||||
// a nil URL.
|
||||
//
|
||||
// If the discovery document entry for the given service is invalid (not a URL),
|
||||
// it is treated as absent, also returning a nil URL.
|
||||
func (h Host) ServiceURL(id string) *url.URL {
|
||||
if h.services == nil {
|
||||
return nil // no services supported for an empty Host
|
||||
// A non-nil result is always an absolute URL with a scheme of either HTTPS
|
||||
// or HTTP.
|
||||
func (h *Host) ServiceURL(id string) (*url.URL, error) {
|
||||
parts := strings.SplitN(id, ".", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("Invalid service ID format (i.e. service.vN): %s", id)
|
||||
}
|
||||
service, version := parts[0], parts[1]
|
||||
|
||||
// No services supported for an empty Host.
|
||||
if h == nil || h.services == nil {
|
||||
return nil, &ErrServiceNotProvided{hostname: "<unknown>", service: service}
|
||||
}
|
||||
|
||||
urlStr, ok := h.services[id].(string)
|
||||
if !ok {
|
||||
return nil
|
||||
// See if we have a matching service as that would indicate
|
||||
// the service is supported, but not the requested version.
|
||||
for serviceID := range h.services {
|
||||
if strings.HasPrefix(serviceID, service) {
|
||||
return nil, &ErrVersionNotSupported{
|
||||
hostname: h.hostname,
|
||||
service: service,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ret, err := url.Parse(urlStr)
|
||||
// No discovered services match the requested service ID.
|
||||
return nil, &ErrServiceNotProvided{hostname: h.hostname, service: service}
|
||||
}
|
||||
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, fmt.Errorf("Failed to parse service URL: %v", err)
|
||||
}
|
||||
if !ret.IsAbs() {
|
||||
ret = h.discoURL.ResolveReference(ret) // make absolute using our discovery doc URL
|
||||
}
|
||||
if ret.Scheme != "https" && ret.Scheme != "http" {
|
||||
return nil
|
||||
}
|
||||
if ret.User != nil {
|
||||
// embedded username/password information is not permitted; credentials
|
||||
// are handled out of band.
|
||||
return nil
|
||||
}
|
||||
ret.Fragment = "" // fragment part is irrelevant, since we're not a browser
|
||||
|
||||
return h.discoURL.ResolveReference(ret)
|
||||
// Make relative URLs absolute using our discovery URL.
|
||||
if !u.IsAbs() {
|
||||
u = h.discoURL.ResolveReference(u)
|
||||
}
|
||||
|
||||
if u.Scheme != "https" && u.Scheme != "http" {
|
||||
return nil, fmt.Errorf("Service URL is using an unsupported scheme: %s", u.Scheme)
|
||||
}
|
||||
if u.User != nil {
|
||||
return nil, fmt.Errorf("Embedded username/password information is not permitted")
|
||||
}
|
||||
|
||||
// Fragment part is irrelevant, since we're not a browser.
|
||||
u.Fragment = ""
|
||||
|
||||
return h.discoURL.ResolveReference(u), nil
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package disco
|
|||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
@ -9,6 +10,7 @@ func TestHostServiceURL(t *testing.T) {
|
|||
baseURL, _ := url.Parse("https://example.com/disco/foo.json")
|
||||
host := Host{
|
||||
discoURL: baseURL,
|
||||
hostname: "test-server",
|
||||
services: map[string]interface{}{
|
||||
"absolute.v1": "http://example.net/foo/bar",
|
||||
"absolutewithport.v1": "http://example.net:8080/foo/bar",
|
||||
|
@ -24,22 +26,28 @@ func TestHostServiceURL(t *testing.T) {
|
|||
|
||||
tests := []struct {
|
||||
ID string
|
||||
Want string
|
||||
want string
|
||||
err string
|
||||
}{
|
||||
{"absolute.v1", "http://example.net/foo/bar"},
|
||||
{"absolutewithport.v1", "http://example.net:8080/foo/bar"},
|
||||
{"relative.v1", "https://example.com/disco/stu/"},
|
||||
{"rootrelative.v1", "https://example.com/baz"},
|
||||
{"protorelative.v1", "https://example.net/"},
|
||||
{"withfragment.v1", "http://example.org/"},
|
||||
{"querystring.v1", "https://example.net/baz?foo=bar"}, // most callers will disregard query string
|
||||
{"nothttp.v1", "<nil>"},
|
||||
{"invalid.v1", "<nil>"},
|
||||
{"absolute.v1", "http://example.net/foo/bar", ""},
|
||||
{"absolutewithport.v1", "http://example.net:8080/foo/bar", ""},
|
||||
{"relative.v1", "https://example.com/disco/stu/", ""},
|
||||
{"rootrelative.v1", "https://example.com/baz", ""},
|
||||
{"protorelative.v1", "https://example.net/", ""},
|
||||
{"withfragment.v1", "http://example.org/", ""},
|
||||
{"querystring.v1", "https://example.net/baz?foo=bar", ""},
|
||||
{"nothttp.v1", "<nil>", "unsupported scheme"},
|
||||
{"invalid.v1", "<nil>", "Failed to parse service URL"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.ID, func(t *testing.T) {
|
||||
url := host.ServiceURL(test.ID)
|
||||
url, err := host.ServiceURL(test.ID)
|
||||
if (err != nil || test.err != "") &&
|
||||
(err == nil || !strings.Contains(err.Error(), test.err)) {
|
||||
t.Fatalf("unexpected service URL error: %s", err)
|
||||
}
|
||||
|
||||
var got string
|
||||
if url != nil {
|
||||
got = url.String()
|
||||
|
@ -47,8 +55,8 @@ func TestHostServiceURL(t *testing.T) {
|
|||
got = "<nil>"
|
||||
}
|
||||
|
||||
if got != test.Want {
|
||||
t.Errorf("wrong result\ngot: %s\nwant: %s", got, test.Want)
|
||||
if got != test.want {
|
||||
t.Errorf("wrong result\ngot: %s\nwant: %s", got, test.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue