core: enhance service discovery

This PR improves the error handling so we can provide better feedback about any service discovery errors that occured.

Additionally it adds logic to test for specific versions when discovering a service using `service.vN`. This will enable more informational errors which can indicate any version incompatibilities.
This commit is contained in:
Sander van Harmelen 2018-12-10 11:06:05 +01:00
parent c77fe806f5
commit a5a2156584
8 changed files with 265 additions and 167 deletions

View File

@ -302,9 +302,9 @@ func (b *Remote) discover(hostname string) (*url.URL, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
service := b.services.DiscoverServiceURL(host, serviceID) service, err := b.services.DiscoverServiceURL(host, serviceID)
if service == nil { if err != nil {
return nil, fmt.Errorf("host %s does not provide a remote backend API", host) return nil, err
} }
return service, nil return service, nil
} }

View File

@ -56,7 +56,7 @@ func TestRemote_config(t *testing.T) {
"prefix": cty.NullVal(cty.String), "prefix": cty.NullVal(cty.String),
}), }),
}), }),
confErr: "Host nonexisting.local does not provide a remote backend API", confErr: "Failed to request discovery document",
}, },
"with_a_name": { "with_a_name": {
config: cty.ObjectVal(map[string]cty.Value{ config: cty.ObjectVal(map[string]cty.Value{
@ -112,8 +112,8 @@ func TestRemote_config(t *testing.T) {
// Validate // Validate
valDiags := b.ValidateConfig(tc.config) valDiags := b.ValidateConfig(tc.config)
if (valDiags.Err() == nil && tc.valErr != "") || if (valDiags.Err() != nil || tc.valErr != "") &&
(valDiags.Err() != nil && !strings.Contains(valDiags.Err().Error(), tc.valErr)) { (valDiags.Err() == nil || !strings.Contains(valDiags.Err().Error(), tc.valErr)) {
t.Fatalf("%s: unexpected validation result: %v", name, valDiags.Err()) t.Fatalf("%s: unexpected validation result: %v", name, valDiags.Err())
} }

View File

@ -59,15 +59,15 @@ func NewClient(services *disco.Disco, client *http.Client) *Client {
} }
// Discover queries the host, and returns the url for the registry. // Discover queries the host, and returns the url for the registry.
func (c *Client) Discover(host svchost.Hostname, serviceID string) *url.URL { func (c *Client) Discover(host svchost.Hostname, serviceID string) (*url.URL, error) {
service := c.services.DiscoverServiceURL(host, serviceID) service, err := c.services.DiscoverServiceURL(host, serviceID)
if service == nil { if err != nil {
return nil return nil, err
} }
if !strings.HasSuffix(service.Path, "/") { if !strings.HasSuffix(service.Path, "/") {
service.Path += "/" service.Path += "/"
} }
return service return service, nil
} }
// ModuleVersions queries the registry for a module, and returns the available versions. // ModuleVersions queries the registry for a module, and returns the available versions.
@ -77,9 +77,9 @@ func (c *Client) ModuleVersions(module *regsrc.Module) (*response.ModuleVersions
return nil, err return nil, err
} }
service := c.Discover(host, modulesServiceID) service, err := c.Discover(host, modulesServiceID)
if service == nil { if err != nil {
return nil, &errServiceNotProvided{host: host.ForDisplay(), service: "modules"} return nil, err
} }
p, err := url.Parse(path.Join(module.Module(), "versions")) p, err := url.Parse(path.Join(module.Module(), "versions"))
@ -150,9 +150,9 @@ func (c *Client) ModuleLocation(module *regsrc.Module, version string) (string,
return "", err return "", err
} }
service := c.Discover(host, modulesServiceID) service, err := c.Discover(host, modulesServiceID)
if service == nil { if err != nil {
return "", &errServiceNotProvided{host: host.ForDisplay(), service: "modules"} return "", err
} }
var p *url.URL var p *url.URL
@ -234,9 +234,9 @@ func (c *Client) TerraformProviderVersions(provider *regsrc.TerraformProvider) (
return nil, err return nil, err
} }
service := c.Discover(host, providersServiceID) service, err := c.Discover(host, providersServiceID)
if service == nil { if err != nil {
return nil, &errServiceNotProvided{host: host.ForDisplay(), service: "providers"} return nil, err
} }
p, err := url.Parse(path.Join(provider.TerraformProvider(), "versions")) p, err := url.Parse(path.Join(provider.TerraformProvider(), "versions"))
@ -288,9 +288,9 @@ func (c *Client) TerraformProviderLocation(provider *regsrc.TerraformProvider, v
return nil, err return nil, err
} }
service := c.Discover(host, providersServiceID) service, err := c.Discover(host, providersServiceID)
if service == nil { if err != nil {
return nil, &errServiceNotProvided{host: host.ForDisplay(), service: "providers"} return nil, err
} }
p, err := url.Parse(path.Join( p, err := url.Parse(path.Join(

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"github.com/hashicorp/terraform/registry/regsrc" "github.com/hashicorp/terraform/registry/regsrc"
"github.com/hashicorp/terraform/svchost/disco"
) )
type errModuleNotFound struct { type errModuleNotFound struct {
@ -42,15 +43,6 @@ func IsProviderNotFound(err error) bool {
// error. This allows callers to recognize this particular error condition // error. This allows callers to recognize this particular error condition
// as distinct from operational errors such as poor network connectivity. // as distinct from operational errors such as poor network connectivity.
func IsServiceNotProvided(err error) bool { func IsServiceNotProvided(err error) bool {
_, ok := err.(*errServiceNotProvided) _, ok := err.(*disco.ErrServiceNotProvided)
return ok return ok
} }
type errServiceNotProvided struct {
host string
service string
}
func (e *errServiceNotProvided) Error() string {
return fmt.Sprintf("host %s does not provide %s", e.host, e.service)
}

View File

@ -8,6 +8,7 @@ package disco
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
@ -22,19 +23,27 @@ import (
) )
const ( const (
// Fixed path to the discovery manifest.
discoPath = "/.well-known/terraform.json" discoPath = "/.well-known/terraform.json"
maxRedirects = 3 // arbitrary-but-small number to prevent runaway redirect loops
discoTimeout = 11 * time.Second // arbitrary-but-small time limit to prevent UI "hangs" during discovery // Arbitrary-but-small number to prevent runaway redirect loops.
maxDiscoDocBytes = 1 * 1024 * 1024 // 1MB - to prevent abusive services from using loads of our memory maxRedirects = 3
// Arbitrary-but-small time limit to prevent UI "hangs" during discovery.
discoTimeout = 11 * time.Second
// 1MB - to prevent abusive services from using loads of our memory.
maxDiscoDocBytes = 1 * 1024 * 1024
) )
var httpTransport = cleanhttp.DefaultPooledTransport() // overridden during tests, to skip TLS verification // httpTransport is overridden during tests, to skip TLS verification.
var httpTransport = cleanhttp.DefaultPooledTransport()
// Disco is the main type in this package, which allows discovery on given // Disco is the main type in this package, which allows discovery on given
// hostnames and caches the results by hostname to avoid repeated requests // hostnames and caches the results by hostname to avoid repeated requests
// for the same information. // for the same information.
type Disco struct { type Disco struct {
hostCache map[svchost.Hostname]Host hostCache map[svchost.Hostname]*Host
credsSrc auth.CredentialsSource credsSrc auth.CredentialsSource
// Transport is a custom http.RoundTripper to use. // Transport is a custom http.RoundTripper to use.
@ -50,7 +59,10 @@ func New() *Disco {
// NewWithCredentialsSource returns a new discovery object initialized with // NewWithCredentialsSource returns a new discovery object initialized with
// the given credentials source. // the given credentials source.
func NewWithCredentialsSource(credsSrc auth.CredentialsSource) *Disco { func NewWithCredentialsSource(credsSrc auth.CredentialsSource) *Disco {
return &Disco{credsSrc: credsSrc} return &Disco{
hostCache: make(map[svchost.Hostname]*Host),
credsSrc: credsSrc,
}
} }
// SetCredentialsSource provides a credentials source that will be used to // SetCredentialsSource provides a credentials source that will be used to
@ -64,11 +76,11 @@ func (d *Disco) SetCredentialsSource(src auth.CredentialsSource) {
// CredentialsForHost returns a non-nil HostCredentials if the embedded source has // CredentialsForHost returns a non-nil HostCredentials if the embedded source has
// credentials available for the host, and a nil HostCredentials if it does not. // credentials available for the host, and a nil HostCredentials if it does not.
func (d *Disco) CredentialsForHost(host svchost.Hostname) (auth.HostCredentials, error) { func (d *Disco) CredentialsForHost(hostname svchost.Hostname) (auth.HostCredentials, error) {
if d.credsSrc == nil { if d.credsSrc == nil {
return nil, nil return nil, nil
} }
return d.credsSrc.ForHost(host) return d.credsSrc.ForHost(hostname)
} }
// ForceHostServices provides a pre-defined set of services for a given // ForceHostServices provides a pre-defined set of services for a given
@ -81,19 +93,17 @@ func (d *Disco) CredentialsForHost(host svchost.Hostname) (auth.HostCredentials,
// discovery, yielding the same results as if the given map were published // discovery, yielding the same results as if the given map were published
// at the host's default discovery URL, though using absolute URLs is strongly // at the host's default discovery URL, though using absolute URLs is strongly
// recommended to make the configured behavior more explicit. // recommended to make the configured behavior more explicit.
func (d *Disco) ForceHostServices(host svchost.Hostname, services map[string]interface{}) { func (d *Disco) ForceHostServices(hostname svchost.Hostname, services map[string]interface{}) {
if d.hostCache == nil {
d.hostCache = map[svchost.Hostname]Host{}
}
if services == nil { if services == nil {
services = map[string]interface{}{} services = map[string]interface{}{}
} }
d.hostCache[host] = Host{ d.hostCache[hostname] = &Host{
discoURL: &url.URL{ discoURL: &url.URL{
Scheme: "https", Scheme: "https",
Host: string(host), Host: string(hostname),
Path: discoPath, Path: discoPath,
}, },
hostname: hostname.ForDisplay(),
services: services, services: services,
} }
} }
@ -104,36 +114,40 @@ func (d *Disco) ForceHostServices(host svchost.Hostname, services map[string]int
// //
// If a given hostname supports no Terraform services at all, a non-nil but // If a given hostname supports no Terraform services at all, a non-nil but
// empty Host object is returned. When giving feedback to the end user about // empty Host object is returned. When giving feedback to the end user about
// such situations, we say e.g. "the host <name> doesn't provide a module // such situations, we say "host <name> does not provide a <service> service",
// registry", regardless of whether that is due to that service specifically // regardless of whether that is due to that service specifically being absent
// being absent or due to the host not providing Terraform services at all, // or due to the host not providing Terraform services at all, since we don't
// since we don't wish to expose the detail of whole-host discovery to an // wish to expose the detail of whole-host discovery to an end-user.
// end-user. func (d *Disco) Discover(hostname svchost.Hostname) (*Host, error) {
func (d *Disco) Discover(host svchost.Hostname) Host { if host, cached := d.hostCache[hostname]; cached {
if d.hostCache == nil { return host, nil
d.hostCache = map[svchost.Hostname]Host{}
}
if cache, cached := d.hostCache[host]; cached {
return cache
} }
ret := d.discover(host) host, err := d.discover(hostname)
d.hostCache[host] = ret if err != nil {
return ret return nil, err
}
d.hostCache[hostname] = host
return host, nil
} }
// DiscoverServiceURL is a convenience wrapper for discovery on a given // DiscoverServiceURL is a convenience wrapper for discovery on a given
// hostname and then looking up a particular service in the result. // hostname and then looking up a particular service in the result.
func (d *Disco) DiscoverServiceURL(host svchost.Hostname, serviceID string) *url.URL { func (d *Disco) DiscoverServiceURL(hostname svchost.Hostname, serviceID string) (*url.URL, error) {
return d.Discover(host).ServiceURL(serviceID) host, err := d.Discover(hostname)
if err != nil {
return nil, err
}
return host.ServiceURL(serviceID)
} }
// discover implements the actual discovery process, with its result cached // discover implements the actual discovery process, with its result cached
// by the public-facing Discover method. // by the public-facing Discover method.
func (d *Disco) discover(host svchost.Hostname) Host { func (d *Disco) discover(hostname svchost.Hostname) (*Host, error) {
discoURL := &url.URL{ discoURL := &url.URL{
Scheme: "https", Scheme: "https",
Host: host.String(), Host: hostname.String(),
Path: discoPath, Path: discoPath,
} }
@ -149,7 +163,7 @@ func (d *Disco) discover(host svchost.Hostname) Host {
CheckRedirect: func(req *http.Request, via []*http.Request) error { CheckRedirect: func(req *http.Request, via []*http.Request) error {
log.Printf("[DEBUG] Service discovery redirected to %s", req.URL) log.Printf("[DEBUG] Service discovery redirected to %s", req.URL)
if len(via) > maxRedirects { if len(via) > maxRedirects {
return errors.New("too many redirects") // (this error message will never actually be seen) return errors.New("too many redirects") // this error will never actually be seen
} }
return nil return nil
}, },
@ -160,82 +174,84 @@ func (d *Disco) discover(host svchost.Hostname) Host {
URL: discoURL, URL: discoURL,
} }
if creds, err := d.CredentialsForHost(host); err != nil { creds, err := d.CredentialsForHost(hostname)
log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", host, err) if err != nil {
} else if creds != nil { log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", hostname, err)
creds.PrepareRequest(req) // alters req to include credentials }
if creds != nil {
// Update the request to include credentials.
creds.PrepareRequest(req)
} }
log.Printf("[DEBUG] Service discovery for %s at %s", host, discoURL) log.Printf("[DEBUG] Service discovery for %s at %s", hostname, discoURL)
ret := Host{
discoURL: discoURL,
}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
log.Printf("[WARN] Failed to request discovery document: %s", err) return nil, fmt.Errorf("Failed to request discovery document: %v", err)
return ret // empty
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != 200 { host := &Host{
log.Printf("[WARN] Failed to request discovery document: %s", resp.Status) // Use the discovery URL from resp.Request in
return ret // empty // case the client followed any redirects.
discoURL: resp.Request.URL,
hostname: hostname.ForDisplay(),
} }
// If the client followed any redirects, we will have a new URL to use // Return the host without any services.
// as our base for relative resolution. if resp.StatusCode == 404 {
ret.discoURL = resp.Request.URL return host, nil
}
if resp.StatusCode != 200 {
return nil, fmt.Errorf("Failed to request discovery document: %s", resp.Status)
}
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
mediaType, _, err := mime.ParseMediaType(contentType) mediaType, _, err := mime.ParseMediaType(contentType)
if err != nil { if err != nil {
log.Printf("[WARN] Discovery URL has malformed Content-Type %q", contentType) return nil, fmt.Errorf("Discovery URL has a malformed Content-Type %q", contentType)
return ret // empty
} }
if mediaType != "application/json" { if mediaType != "application/json" {
log.Printf("[DEBUG] Discovery URL returned Content-Type %q, rather than application/json", mediaType) return nil, fmt.Errorf("Discovery URL returned an unsupported Content-Type %q", mediaType)
return ret // empty
} }
// (this doesn't catch chunked encoding, because ContentLength is -1 in that case...) // This doesn't catch chunked encoding, because ContentLength is -1 in that case.
if resp.ContentLength > maxDiscoDocBytes { if resp.ContentLength > maxDiscoDocBytes {
// Size limit here is not a contractual requirement and so we may // Size limit here is not a contractual requirement and so we may
// adjust it over time if we find a different limit is warranted. // adjust it over time if we find a different limit is warranted.
log.Printf("[WARN] Discovery doc response is too large (got %d bytes; limit %d)", resp.ContentLength, maxDiscoDocBytes) return nil, fmt.Errorf(
return ret // empty "Discovery doc response is too large (got %d bytes; limit %d)",
resp.ContentLength, maxDiscoDocBytes,
)
} }
// If the response is using chunked encoding then we can't predict // If the response is using chunked encoding then we can't predict its
// its size, but we'll at least prevent reading the entire thing into // size, but we'll at least prevent reading the entire thing into memory.
// memory.
lr := io.LimitReader(resp.Body, maxDiscoDocBytes) lr := io.LimitReader(resp.Body, maxDiscoDocBytes)
servicesBytes, err := ioutil.ReadAll(lr) servicesBytes, err := ioutil.ReadAll(lr)
if err != nil { if err != nil {
log.Printf("[WARN] Error reading discovery document body: %s", err) return nil, fmt.Errorf("Error reading discovery document body: %v", err)
return ret // empty
} }
var services map[string]interface{} var services map[string]interface{}
err = json.Unmarshal(servicesBytes, &services) err = json.Unmarshal(servicesBytes, &services)
if err != nil { if err != nil {
log.Printf("[WARN] Failed to decode discovery document as a JSON object: %s", err) return nil, fmt.Errorf("Failed to decode discovery document as a JSON object: %v", err)
return ret // empty
} }
host.services = services
ret.services = services return host, nil
return ret
} }
// Forget invalidates any cached record of the given hostname. If the host // Forget invalidates any cached record of the given hostname. If the host
// has no cache entry then this is a no-op. // has no cache entry then this is a no-op.
func (d *Disco) Forget(host svchost.Hostname) { func (d *Disco) Forget(hostname svchost.Hostname) {
delete(d.hostCache, host) delete(d.hostCache, hostname)
} }
// ForgetAll is like Forget, but for all of the hostnames that have cache entries. // ForgetAll is like Forget, but for all of the hostnames that have cache entries.
func (d *Disco) ForgetAll() { func (d *Disco) ForgetAll() {
d.hostCache = nil d.hostCache = make(map[svchost.Hostname]*Host)
} }

View File

@ -46,8 +46,15 @@ func TestDiscover(t *testing.T) {
} }
d := New() d := New()
discovered := d.Discover(host) discovered, err := d.Discover(host)
gotURL := discovered.ServiceURL("thingy.v1") if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}
gotURL, err := discovered.ServiceURL("thingy.v1")
if err != nil {
t.Fatalf("unexpected service URL error: %s", err)
}
if gotURL == nil { if gotURL == nil {
t.Fatalf("found no URL for thingy.v1") t.Fatalf("found no URL for thingy.v1")
} }
@ -81,8 +88,15 @@ func TestDiscover(t *testing.T) {
} }
d := New() d := New()
discovered := d.Discover(host) discovered, err := d.Discover(host)
gotURL := discovered.ServiceURL("wotsit.v2") if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}
gotURL, err := discovered.ServiceURL("wotsit.v2")
if err != nil {
t.Fatalf("unexpected service URL error: %s", err)
}
if gotURL == nil { if gotURL == nil {
t.Fatalf("found no URL for wotsit.v2") t.Fatalf("found no URL for wotsit.v2")
} }
@ -133,9 +147,15 @@ func TestDiscover(t *testing.T) {
t.Fatalf("test server hostname is invalid: %s", err) t.Fatalf("test server hostname is invalid: %s", err)
} }
discovered := d.Discover(host) discovered, err := d.Discover(host)
if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}
{ {
gotURL := discovered.ServiceURL("thingy.v1") gotURL, err := discovered.ServiceURL("thingy.v1")
if err != nil {
t.Fatalf("unexpected service URL error: %s", err)
}
if gotURL == nil { if gotURL == nil {
t.Fatalf("found no URL for thingy.v1") t.Fatalf("found no URL for thingy.v1")
} }
@ -144,7 +164,10 @@ func TestDiscover(t *testing.T) {
} }
} }
{ {
gotURL := discovered.ServiceURL("wotsit.v2") gotURL, err := discovered.ServiceURL("wotsit.v2")
if err != nil {
t.Fatalf("unexpected service URL error: %s", err)
}
if gotURL == nil { if gotURL == nil {
t.Fatalf("found no URL for wotsit.v2") t.Fatalf("found no URL for wotsit.v2")
} }
@ -168,12 +191,14 @@ func TestDiscover(t *testing.T) {
} }
d := New() d := New()
discovered := d.Discover(host) discovered, err := d.Discover(host)
if err == nil {
t.Fatalf("expected a discovery error")
}
// result should be empty, which we can verify only by reaching into // Returned discovered should be nil.
// its internals. if discovered != nil {
if discovered.services != nil { t.Errorf("discovered not nil; should be")
t.Errorf("response not empty; should be")
} }
}) })
t.Run("malformed JSON", func(t *testing.T) { t.Run("malformed JSON", func(t *testing.T) {
@ -191,12 +216,14 @@ func TestDiscover(t *testing.T) {
} }
d := New() d := New()
discovered := d.Discover(host) discovered, err := d.Discover(host)
if err == nil {
t.Fatalf("expected a discovery error")
}
// result should be empty, which we can verify only by reaching into // Returned discovered should be nil.
// its internals. if discovered != nil {
if discovered.services != nil { t.Errorf("discovered not nil; should be")
t.Errorf("response not empty; should be")
} }
}) })
t.Run("JSON with redundant charset", func(t *testing.T) { t.Run("JSON with redundant charset", func(t *testing.T) {
@ -218,7 +245,10 @@ func TestDiscover(t *testing.T) {
} }
d := New() d := New()
discovered := d.Discover(host) discovered, err := d.Discover(host)
if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}
if discovered.services == nil { if discovered.services == nil {
t.Errorf("response is empty; shouldn't be") t.Errorf("response is empty; shouldn't be")
@ -237,12 +267,14 @@ func TestDiscover(t *testing.T) {
} }
d := New() d := New()
discovered := d.Discover(host) discovered, err := d.Discover(host)
if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}
// result should be empty, which we can verify only by reaching into // Returned discovered.services should be nil (empty).
// its internals.
if discovered.services != nil { if discovered.services != nil {
t.Errorf("response not empty; should be") t.Errorf("discovered.services not nil (empty); should be")
} }
}) })
t.Run("redirect", func(t *testing.T) { t.Run("redirect", func(t *testing.T) {
@ -268,9 +300,15 @@ func TestDiscover(t *testing.T) {
} }
d := New() d := New()
discovered := d.Discover(host) discovered, err := d.Discover(host)
if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}
gotURL := discovered.ServiceURL("thingy.v1") gotURL, err := discovered.ServiceURL("thingy.v1")
if err != nil {
t.Fatalf("unexpected service URL error: %s", err)
}
if gotURL == nil { if gotURL == nil {
t.Fatalf("found no URL for thingy.v1") t.Fatalf("found no URL for thingy.v1")
} }

View File

@ -1,51 +1,95 @@
package disco package disco
import ( import (
"fmt"
"net/url" "net/url"
"strings"
) )
// Host represents a service discovered host.
type Host struct { type Host struct {
discoURL *url.URL discoURL *url.URL
hostname string
services map[string]interface{} services map[string]interface{}
} }
// ErrServiceNotProvided is returned when the service is not provided.
type ErrServiceNotProvided struct {
hostname string
service string
}
// Error returns a customized error message.
func (e *ErrServiceNotProvided) Error() string {
return fmt.Sprintf("host %s does not provide a %s service", e.hostname, e.service)
}
// ErrVersionNotSupported is returned when the version is not supported.
type ErrVersionNotSupported struct {
hostname string
service string
version string
}
// Error returns a customized error message.
func (e *ErrVersionNotSupported) Error() string {
return fmt.Sprintf("host %s does not support %s version %s", e.hostname, e.service, e.version)
}
// ServiceURL returns the URL associated with the given service identifier, // ServiceURL returns the URL associated with the given service identifier,
// which should be of the form "servicename.vN". // which should be of the form "servicename.vN".
// //
// A non-nil result is always an absolute URL with a scheme of either https // A non-nil result is always an absolute URL with a scheme of either HTTPS
// or http. // or HTTP.
// func (h *Host) ServiceURL(id string) (*url.URL, error) {
// If the requested service is not supported by the host, this method returns parts := strings.SplitN(id, ".", 2)
// a nil URL. if len(parts) != 2 {
// return nil, fmt.Errorf("Invalid service ID format (i.e. service.vN): %s", id)
// If the discovery document entry for the given service is invalid (not a URL), }
// it is treated as absent, also returning a nil URL. service, version := parts[0], parts[1]
func (h Host) ServiceURL(id string) *url.URL {
if h.services == nil { // No services supported for an empty Host.
return nil // no services supported for an empty Host if h == nil || h.services == nil {
return nil, &ErrServiceNotProvided{hostname: "<unknown>", service: service}
} }
urlStr, ok := h.services[id].(string) urlStr, ok := h.services[id].(string)
if !ok { if !ok {
return nil // See if we have a matching service as that would indicate
// the service is supported, but not the requested version.
for serviceID := range h.services {
if strings.HasPrefix(serviceID, service) {
return nil, &ErrVersionNotSupported{
hostname: h.hostname,
service: service,
version: version,
}
}
} }
ret, err := url.Parse(urlStr) // No discovered services match the requested service ID.
return nil, &ErrServiceNotProvided{hostname: h.hostname, service: service}
}
u, err := url.Parse(urlStr)
if err != nil { if err != nil {
return nil return nil, fmt.Errorf("Failed to parse service URL: %v", err)
} }
if !ret.IsAbs() {
ret = h.discoURL.ResolveReference(ret) // make absolute using our discovery doc URL
}
if ret.Scheme != "https" && ret.Scheme != "http" {
return nil
}
if ret.User != nil {
// embedded username/password information is not permitted; credentials
// are handled out of band.
return nil
}
ret.Fragment = "" // fragment part is irrelevant, since we're not a browser
return h.discoURL.ResolveReference(ret) // Make relative URLs absolute using our discovery URL.
if !u.IsAbs() {
u = h.discoURL.ResolveReference(u)
}
if u.Scheme != "https" && u.Scheme != "http" {
return nil, fmt.Errorf("Service URL is using an unsupported scheme: %s", u.Scheme)
}
if u.User != nil {
return nil, fmt.Errorf("Embedded username/password information is not permitted")
}
// Fragment part is irrelevant, since we're not a browser.
u.Fragment = ""
return h.discoURL.ResolveReference(u), nil
} }

View File

@ -2,6 +2,7 @@ package disco
import ( import (
"net/url" "net/url"
"strings"
"testing" "testing"
) )
@ -9,6 +10,7 @@ func TestHostServiceURL(t *testing.T) {
baseURL, _ := url.Parse("https://example.com/disco/foo.json") baseURL, _ := url.Parse("https://example.com/disco/foo.json")
host := Host{ host := Host{
discoURL: baseURL, discoURL: baseURL,
hostname: "test-server",
services: map[string]interface{}{ services: map[string]interface{}{
"absolute.v1": "http://example.net/foo/bar", "absolute.v1": "http://example.net/foo/bar",
"absolutewithport.v1": "http://example.net:8080/foo/bar", "absolutewithport.v1": "http://example.net:8080/foo/bar",
@ -24,22 +26,28 @@ func TestHostServiceURL(t *testing.T) {
tests := []struct { tests := []struct {
ID string ID string
Want string want string
err string
}{ }{
{"absolute.v1", "http://example.net/foo/bar"}, {"absolute.v1", "http://example.net/foo/bar", ""},
{"absolutewithport.v1", "http://example.net:8080/foo/bar"}, {"absolutewithport.v1", "http://example.net:8080/foo/bar", ""},
{"relative.v1", "https://example.com/disco/stu/"}, {"relative.v1", "https://example.com/disco/stu/", ""},
{"rootrelative.v1", "https://example.com/baz"}, {"rootrelative.v1", "https://example.com/baz", ""},
{"protorelative.v1", "https://example.net/"}, {"protorelative.v1", "https://example.net/", ""},
{"withfragment.v1", "http://example.org/"}, {"withfragment.v1", "http://example.org/", ""},
{"querystring.v1", "https://example.net/baz?foo=bar"}, // most callers will disregard query string {"querystring.v1", "https://example.net/baz?foo=bar", ""},
{"nothttp.v1", "<nil>"}, {"nothttp.v1", "<nil>", "unsupported scheme"},
{"invalid.v1", "<nil>"}, {"invalid.v1", "<nil>", "Failed to parse service URL"},
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.ID, func(t *testing.T) { t.Run(test.ID, func(t *testing.T) {
url := host.ServiceURL(test.ID) url, err := host.ServiceURL(test.ID)
if (err != nil || test.err != "") &&
(err == nil || !strings.Contains(err.Error(), test.err)) {
t.Fatalf("unexpected service URL error: %s", err)
}
var got string var got string
if url != nil { if url != nil {
got = url.String() got = url.String()
@ -47,8 +55,8 @@ func TestHostServiceURL(t *testing.T) {
got = "<nil>" got = "<nil>"
} }
if got != test.Want { if got != test.want {
t.Errorf("wrong result\ngot: %s\nwant: %s", got, test.Want) t.Errorf("wrong result\ngot: %s\nwant: %s", got, test.want)
} }
}) })
} }