Merge pull request #19589 from hashicorp/svh/f-service-discovery

core: enhance service discovery
This commit is contained in:
Sander van Harmelen 2018-12-10 21:09:56 +01:00 committed by GitHub
commit ecc93c9889
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 265 additions and 167 deletions

View File

@ -302,9 +302,9 @@ func (b *Remote) discover(hostname string) (*url.URL, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
service := b.services.DiscoverServiceURL(host, serviceID) service, err := b.services.DiscoverServiceURL(host, serviceID)
if service == nil { if err != nil {
return nil, fmt.Errorf("host %s does not provide a remote backend API", host) return nil, err
} }
return service, nil return service, nil
} }

View File

@ -56,7 +56,7 @@ func TestRemote_config(t *testing.T) {
"prefix": cty.NullVal(cty.String), "prefix": cty.NullVal(cty.String),
}), }),
}), }),
confErr: "Host nonexisting.local does not provide a remote backend API", confErr: "Failed to request discovery document",
}, },
"with_a_name": { "with_a_name": {
config: cty.ObjectVal(map[string]cty.Value{ config: cty.ObjectVal(map[string]cty.Value{
@ -112,8 +112,8 @@ func TestRemote_config(t *testing.T) {
// Validate // Validate
valDiags := b.ValidateConfig(tc.config) valDiags := b.ValidateConfig(tc.config)
if (valDiags.Err() == nil && tc.valErr != "") || if (valDiags.Err() != nil || tc.valErr != "") &&
(valDiags.Err() != nil && !strings.Contains(valDiags.Err().Error(), tc.valErr)) { (valDiags.Err() == nil || !strings.Contains(valDiags.Err().Error(), tc.valErr)) {
t.Fatalf("%s: unexpected validation result: %v", name, valDiags.Err()) t.Fatalf("%s: unexpected validation result: %v", name, valDiags.Err())
} }

View File

@ -59,15 +59,15 @@ func NewClient(services *disco.Disco, client *http.Client) *Client {
} }
// Discover queries the host, and returns the url for the registry. // Discover queries the host, and returns the url for the registry.
func (c *Client) Discover(host svchost.Hostname, serviceID string) *url.URL { func (c *Client) Discover(host svchost.Hostname, serviceID string) (*url.URL, error) {
service := c.services.DiscoverServiceURL(host, serviceID) service, err := c.services.DiscoverServiceURL(host, serviceID)
if service == nil { if err != nil {
return nil return nil, err
} }
if !strings.HasSuffix(service.Path, "/") { if !strings.HasSuffix(service.Path, "/") {
service.Path += "/" service.Path += "/"
} }
return service return service, nil
} }
// ModuleVersions queries the registry for a module, and returns the available versions. // ModuleVersions queries the registry for a module, and returns the available versions.
@ -77,9 +77,9 @@ func (c *Client) ModuleVersions(module *regsrc.Module) (*response.ModuleVersions
return nil, err return nil, err
} }
service := c.Discover(host, modulesServiceID) service, err := c.Discover(host, modulesServiceID)
if service == nil { if err != nil {
return nil, &errServiceNotProvided{host: host.ForDisplay(), service: "modules"} return nil, err
} }
p, err := url.Parse(path.Join(module.Module(), "versions")) p, err := url.Parse(path.Join(module.Module(), "versions"))
@ -150,9 +150,9 @@ func (c *Client) ModuleLocation(module *regsrc.Module, version string) (string,
return "", err return "", err
} }
service := c.Discover(host, modulesServiceID) service, err := c.Discover(host, modulesServiceID)
if service == nil { if err != nil {
return "", &errServiceNotProvided{host: host.ForDisplay(), service: "modules"} return "", err
} }
var p *url.URL var p *url.URL
@ -234,9 +234,9 @@ func (c *Client) TerraformProviderVersions(provider *regsrc.TerraformProvider) (
return nil, err return nil, err
} }
service := c.Discover(host, providersServiceID) service, err := c.Discover(host, providersServiceID)
if service == nil { if err != nil {
return nil, &errServiceNotProvided{host: host.ForDisplay(), service: "providers"} return nil, err
} }
p, err := url.Parse(path.Join(provider.TerraformProvider(), "versions")) p, err := url.Parse(path.Join(provider.TerraformProvider(), "versions"))
@ -288,9 +288,9 @@ func (c *Client) TerraformProviderLocation(provider *regsrc.TerraformProvider, v
return nil, err return nil, err
} }
service := c.Discover(host, providersServiceID) service, err := c.Discover(host, providersServiceID)
if service == nil { if err != nil {
return nil, &errServiceNotProvided{host: host.ForDisplay(), service: "providers"} return nil, err
} }
p, err := url.Parse(path.Join( p, err := url.Parse(path.Join(

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"github.com/hashicorp/terraform/registry/regsrc" "github.com/hashicorp/terraform/registry/regsrc"
"github.com/hashicorp/terraform/svchost/disco"
) )
type errModuleNotFound struct { type errModuleNotFound struct {
@ -42,15 +43,6 @@ func IsProviderNotFound(err error) bool {
// error. This allows callers to recognize this particular error condition // error. This allows callers to recognize this particular error condition
// as distinct from operational errors such as poor network connectivity. // as distinct from operational errors such as poor network connectivity.
func IsServiceNotProvided(err error) bool { func IsServiceNotProvided(err error) bool {
_, ok := err.(*errServiceNotProvided) _, ok := err.(*disco.ErrServiceNotProvided)
return ok return ok
} }
type errServiceNotProvided struct {
host string
service string
}
func (e *errServiceNotProvided) Error() string {
return fmt.Sprintf("host %s does not provide %s", e.host, e.service)
}

View File

@ -8,6 +8,7 @@ package disco
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
@ -22,19 +23,27 @@ import (
) )
const ( const (
// Fixed path to the discovery manifest.
discoPath = "/.well-known/terraform.json" discoPath = "/.well-known/terraform.json"
maxRedirects = 3 // arbitrary-but-small number to prevent runaway redirect loops
discoTimeout = 11 * time.Second // arbitrary-but-small time limit to prevent UI "hangs" during discovery // Arbitrary-but-small number to prevent runaway redirect loops.
maxDiscoDocBytes = 1 * 1024 * 1024 // 1MB - to prevent abusive services from using loads of our memory maxRedirects = 3
// Arbitrary-but-small time limit to prevent UI "hangs" during discovery.
discoTimeout = 11 * time.Second
// 1MB - to prevent abusive services from using loads of our memory.
maxDiscoDocBytes = 1 * 1024 * 1024
) )
var httpTransport = cleanhttp.DefaultPooledTransport() // overridden during tests, to skip TLS verification // httpTransport is overridden during tests, to skip TLS verification.
var httpTransport = cleanhttp.DefaultPooledTransport()
// Disco is the main type in this package, which allows discovery on given // Disco is the main type in this package, which allows discovery on given
// hostnames and caches the results by hostname to avoid repeated requests // hostnames and caches the results by hostname to avoid repeated requests
// for the same information. // for the same information.
type Disco struct { type Disco struct {
hostCache map[svchost.Hostname]Host hostCache map[svchost.Hostname]*Host
credsSrc auth.CredentialsSource credsSrc auth.CredentialsSource
// Transport is a custom http.RoundTripper to use. // Transport is a custom http.RoundTripper to use.
@ -50,7 +59,10 @@ func New() *Disco {
// NewWithCredentialsSource returns a new discovery object initialized with // NewWithCredentialsSource returns a new discovery object initialized with
// the given credentials source. // the given credentials source.
func NewWithCredentialsSource(credsSrc auth.CredentialsSource) *Disco { func NewWithCredentialsSource(credsSrc auth.CredentialsSource) *Disco {
return &Disco{credsSrc: credsSrc} return &Disco{
hostCache: make(map[svchost.Hostname]*Host),
credsSrc: credsSrc,
}
} }
// SetCredentialsSource provides a credentials source that will be used to // SetCredentialsSource provides a credentials source that will be used to
@ -64,11 +76,11 @@ func (d *Disco) SetCredentialsSource(src auth.CredentialsSource) {
// CredentialsForHost returns a non-nil HostCredentials if the embedded source has // CredentialsForHost returns a non-nil HostCredentials if the embedded source has
// credentials available for the host, and a nil HostCredentials if it does not. // credentials available for the host, and a nil HostCredentials if it does not.
func (d *Disco) CredentialsForHost(host svchost.Hostname) (auth.HostCredentials, error) { func (d *Disco) CredentialsForHost(hostname svchost.Hostname) (auth.HostCredentials, error) {
if d.credsSrc == nil { if d.credsSrc == nil {
return nil, nil return nil, nil
} }
return d.credsSrc.ForHost(host) return d.credsSrc.ForHost(hostname)
} }
// ForceHostServices provides a pre-defined set of services for a given // ForceHostServices provides a pre-defined set of services for a given
@ -81,19 +93,17 @@ func (d *Disco) CredentialsForHost(host svchost.Hostname) (auth.HostCredentials,
// discovery, yielding the same results as if the given map were published // discovery, yielding the same results as if the given map were published
// at the host's default discovery URL, though using absolute URLs is strongly // at the host's default discovery URL, though using absolute URLs is strongly
// recommended to make the configured behavior more explicit. // recommended to make the configured behavior more explicit.
func (d *Disco) ForceHostServices(host svchost.Hostname, services map[string]interface{}) { func (d *Disco) ForceHostServices(hostname svchost.Hostname, services map[string]interface{}) {
if d.hostCache == nil {
d.hostCache = map[svchost.Hostname]Host{}
}
if services == nil { if services == nil {
services = map[string]interface{}{} services = map[string]interface{}{}
} }
d.hostCache[host] = Host{ d.hostCache[hostname] = &Host{
discoURL: &url.URL{ discoURL: &url.URL{
Scheme: "https", Scheme: "https",
Host: string(host), Host: string(hostname),
Path: discoPath, Path: discoPath,
}, },
hostname: hostname.ForDisplay(),
services: services, services: services,
} }
} }
@ -104,36 +114,40 @@ func (d *Disco) ForceHostServices(host svchost.Hostname, services map[string]int
// //
// If a given hostname supports no Terraform services at all, a non-nil but // If a given hostname supports no Terraform services at all, a non-nil but
// empty Host object is returned. When giving feedback to the end user about // empty Host object is returned. When giving feedback to the end user about
// such situations, we say e.g. "the host <name> doesn't provide a module // such situations, we say "host <name> does not provide a <service> service",
// registry", regardless of whether that is due to that service specifically // regardless of whether that is due to that service specifically being absent
// being absent or due to the host not providing Terraform services at all, // or due to the host not providing Terraform services at all, since we don't
// since we don't wish to expose the detail of whole-host discovery to an // wish to expose the detail of whole-host discovery to an end-user.
// end-user. func (d *Disco) Discover(hostname svchost.Hostname) (*Host, error) {
func (d *Disco) Discover(host svchost.Hostname) Host { if host, cached := d.hostCache[hostname]; cached {
if d.hostCache == nil { return host, nil
d.hostCache = map[svchost.Hostname]Host{}
}
if cache, cached := d.hostCache[host]; cached {
return cache
} }
ret := d.discover(host) host, err := d.discover(hostname)
d.hostCache[host] = ret if err != nil {
return ret return nil, err
}
d.hostCache[hostname] = host
return host, nil
} }
// DiscoverServiceURL is a convenience wrapper for discovery on a given // DiscoverServiceURL is a convenience wrapper for discovery on a given
// hostname and then looking up a particular service in the result. // hostname and then looking up a particular service in the result.
func (d *Disco) DiscoverServiceURL(host svchost.Hostname, serviceID string) *url.URL { func (d *Disco) DiscoverServiceURL(hostname svchost.Hostname, serviceID string) (*url.URL, error) {
return d.Discover(host).ServiceURL(serviceID) host, err := d.Discover(hostname)
if err != nil {
return nil, err
}
return host.ServiceURL(serviceID)
} }
// discover implements the actual discovery process, with its result cached // discover implements the actual discovery process, with its result cached
// by the public-facing Discover method. // by the public-facing Discover method.
func (d *Disco) discover(host svchost.Hostname) Host { func (d *Disco) discover(hostname svchost.Hostname) (*Host, error) {
discoURL := &url.URL{ discoURL := &url.URL{
Scheme: "https", Scheme: "https",
Host: host.String(), Host: hostname.String(),
Path: discoPath, Path: discoPath,
} }
@ -149,7 +163,7 @@ func (d *Disco) discover(host svchost.Hostname) Host {
CheckRedirect: func(req *http.Request, via []*http.Request) error { CheckRedirect: func(req *http.Request, via []*http.Request) error {
log.Printf("[DEBUG] Service discovery redirected to %s", req.URL) log.Printf("[DEBUG] Service discovery redirected to %s", req.URL)
if len(via) > maxRedirects { if len(via) > maxRedirects {
return errors.New("too many redirects") // (this error message will never actually be seen) return errors.New("too many redirects") // this error will never actually be seen
} }
return nil return nil
}, },
@ -160,82 +174,84 @@ func (d *Disco) discover(host svchost.Hostname) Host {
URL: discoURL, URL: discoURL,
} }
if creds, err := d.CredentialsForHost(host); err != nil { creds, err := d.CredentialsForHost(hostname)
log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", host, err) if err != nil {
} else if creds != nil { log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", hostname, err)
creds.PrepareRequest(req) // alters req to include credentials }
if creds != nil {
// Update the request to include credentials.
creds.PrepareRequest(req)
} }
log.Printf("[DEBUG] Service discovery for %s at %s", host, discoURL) log.Printf("[DEBUG] Service discovery for %s at %s", hostname, discoURL)
ret := Host{
discoURL: discoURL,
}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
log.Printf("[WARN] Failed to request discovery document: %s", err) return nil, fmt.Errorf("Failed to request discovery document: %v", err)
return ret // empty
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != 200 { host := &Host{
log.Printf("[WARN] Failed to request discovery document: %s", resp.Status) // Use the discovery URL from resp.Request in
return ret // empty // case the client followed any redirects.
discoURL: resp.Request.URL,
hostname: hostname.ForDisplay(),
} }
// If the client followed any redirects, we will have a new URL to use // Return the host without any services.
// as our base for relative resolution. if resp.StatusCode == 404 {
ret.discoURL = resp.Request.URL return host, nil
}
if resp.StatusCode != 200 {
return nil, fmt.Errorf("Failed to request discovery document: %s", resp.Status)
}
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
mediaType, _, err := mime.ParseMediaType(contentType) mediaType, _, err := mime.ParseMediaType(contentType)
if err != nil { if err != nil {
log.Printf("[WARN] Discovery URL has malformed Content-Type %q", contentType) return nil, fmt.Errorf("Discovery URL has a malformed Content-Type %q", contentType)
return ret // empty
} }
if mediaType != "application/json" { if mediaType != "application/json" {
log.Printf("[DEBUG] Discovery URL returned Content-Type %q, rather than application/json", mediaType) return nil, fmt.Errorf("Discovery URL returned an unsupported Content-Type %q", mediaType)
return ret // empty
} }
// (this doesn't catch chunked encoding, because ContentLength is -1 in that case...) // This doesn't catch chunked encoding, because ContentLength is -1 in that case.
if resp.ContentLength > maxDiscoDocBytes { if resp.ContentLength > maxDiscoDocBytes {
// Size limit here is not a contractual requirement and so we may // Size limit here is not a contractual requirement and so we may
// adjust it over time if we find a different limit is warranted. // adjust it over time if we find a different limit is warranted.
log.Printf("[WARN] Discovery doc response is too large (got %d bytes; limit %d)", resp.ContentLength, maxDiscoDocBytes) return nil, fmt.Errorf(
return ret // empty "Discovery doc response is too large (got %d bytes; limit %d)",
resp.ContentLength, maxDiscoDocBytes,
)
} }
// If the response is using chunked encoding then we can't predict // If the response is using chunked encoding then we can't predict its
// its size, but we'll at least prevent reading the entire thing into // size, but we'll at least prevent reading the entire thing into memory.
// memory.
lr := io.LimitReader(resp.Body, maxDiscoDocBytes) lr := io.LimitReader(resp.Body, maxDiscoDocBytes)
servicesBytes, err := ioutil.ReadAll(lr) servicesBytes, err := ioutil.ReadAll(lr)
if err != nil { if err != nil {
log.Printf("[WARN] Error reading discovery document body: %s", err) return nil, fmt.Errorf("Error reading discovery document body: %v", err)
return ret // empty
} }
var services map[string]interface{} var services map[string]interface{}
err = json.Unmarshal(servicesBytes, &services) err = json.Unmarshal(servicesBytes, &services)
if err != nil { if err != nil {
log.Printf("[WARN] Failed to decode discovery document as a JSON object: %s", err) return nil, fmt.Errorf("Failed to decode discovery document as a JSON object: %v", err)
return ret // empty
} }
host.services = services
ret.services = services return host, nil
return ret
} }
// Forget invalidates any cached record of the given hostname. If the host // Forget invalidates any cached record of the given hostname. If the host
// has no cache entry then this is a no-op. // has no cache entry then this is a no-op.
func (d *Disco) Forget(host svchost.Hostname) { func (d *Disco) Forget(hostname svchost.Hostname) {
delete(d.hostCache, host) delete(d.hostCache, hostname)
} }
// ForgetAll is like Forget, but for all of the hostnames that have cache entries. // ForgetAll is like Forget, but for all of the hostnames that have cache entries.
func (d *Disco) ForgetAll() { func (d *Disco) ForgetAll() {
d.hostCache = nil d.hostCache = make(map[svchost.Hostname]*Host)
} }

View File

@ -46,8 +46,15 @@ func TestDiscover(t *testing.T) {
} }
d := New() d := New()
discovered := d.Discover(host) discovered, err := d.Discover(host)
gotURL := discovered.ServiceURL("thingy.v1") if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}
gotURL, err := discovered.ServiceURL("thingy.v1")
if err != nil {
t.Fatalf("unexpected service URL error: %s", err)
}
if gotURL == nil { if gotURL == nil {
t.Fatalf("found no URL for thingy.v1") t.Fatalf("found no URL for thingy.v1")
} }
@ -81,8 +88,15 @@ func TestDiscover(t *testing.T) {
} }
d := New() d := New()
discovered := d.Discover(host) discovered, err := d.Discover(host)
gotURL := discovered.ServiceURL("wotsit.v2") if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}
gotURL, err := discovered.ServiceURL("wotsit.v2")
if err != nil {
t.Fatalf("unexpected service URL error: %s", err)
}
if gotURL == nil { if gotURL == nil {
t.Fatalf("found no URL for wotsit.v2") t.Fatalf("found no URL for wotsit.v2")
} }
@ -133,9 +147,15 @@ func TestDiscover(t *testing.T) {
t.Fatalf("test server hostname is invalid: %s", err) t.Fatalf("test server hostname is invalid: %s", err)
} }
discovered := d.Discover(host) discovered, err := d.Discover(host)
if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}
{ {
gotURL := discovered.ServiceURL("thingy.v1") gotURL, err := discovered.ServiceURL("thingy.v1")
if err != nil {
t.Fatalf("unexpected service URL error: %s", err)
}
if gotURL == nil { if gotURL == nil {
t.Fatalf("found no URL for thingy.v1") t.Fatalf("found no URL for thingy.v1")
} }
@ -144,7 +164,10 @@ func TestDiscover(t *testing.T) {
} }
} }
{ {
gotURL := discovered.ServiceURL("wotsit.v2") gotURL, err := discovered.ServiceURL("wotsit.v2")
if err != nil {
t.Fatalf("unexpected service URL error: %s", err)
}
if gotURL == nil { if gotURL == nil {
t.Fatalf("found no URL for wotsit.v2") t.Fatalf("found no URL for wotsit.v2")
} }
@ -168,12 +191,14 @@ func TestDiscover(t *testing.T) {
} }
d := New() d := New()
discovered := d.Discover(host) discovered, err := d.Discover(host)
if err == nil {
t.Fatalf("expected a discovery error")
}
// result should be empty, which we can verify only by reaching into // Returned discovered should be nil.
// its internals. if discovered != nil {
if discovered.services != nil { t.Errorf("discovered not nil; should be")
t.Errorf("response not empty; should be")
} }
}) })
t.Run("malformed JSON", func(t *testing.T) { t.Run("malformed JSON", func(t *testing.T) {
@ -191,12 +216,14 @@ func TestDiscover(t *testing.T) {
} }
d := New() d := New()
discovered := d.Discover(host) discovered, err := d.Discover(host)
if err == nil {
t.Fatalf("expected a discovery error")
}
// result should be empty, which we can verify only by reaching into // Returned discovered should be nil.
// its internals. if discovered != nil {
if discovered.services != nil { t.Errorf("discovered not nil; should be")
t.Errorf("response not empty; should be")
} }
}) })
t.Run("JSON with redundant charset", func(t *testing.T) { t.Run("JSON with redundant charset", func(t *testing.T) {
@ -218,7 +245,10 @@ func TestDiscover(t *testing.T) {
} }
d := New() d := New()
discovered := d.Discover(host) discovered, err := d.Discover(host)
if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}
if discovered.services == nil { if discovered.services == nil {
t.Errorf("response is empty; shouldn't be") t.Errorf("response is empty; shouldn't be")
@ -237,12 +267,14 @@ func TestDiscover(t *testing.T) {
} }
d := New() d := New()
discovered := d.Discover(host) discovered, err := d.Discover(host)
if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}
// result should be empty, which we can verify only by reaching into // Returned discovered.services should be nil (empty).
// its internals.
if discovered.services != nil { if discovered.services != nil {
t.Errorf("response not empty; should be") t.Errorf("discovered.services not nil (empty); should be")
} }
}) })
t.Run("redirect", func(t *testing.T) { t.Run("redirect", func(t *testing.T) {
@ -268,9 +300,15 @@ func TestDiscover(t *testing.T) {
} }
d := New() d := New()
discovered := d.Discover(host) discovered, err := d.Discover(host)
if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}
gotURL := discovered.ServiceURL("thingy.v1") gotURL, err := discovered.ServiceURL("thingy.v1")
if err != nil {
t.Fatalf("unexpected service URL error: %s", err)
}
if gotURL == nil { if gotURL == nil {
t.Fatalf("found no URL for thingy.v1") t.Fatalf("found no URL for thingy.v1")
} }

View File

@ -1,51 +1,95 @@
package disco package disco
import ( import (
"fmt"
"net/url" "net/url"
"strings"
) )
// Host represents a service discovered host.
type Host struct { type Host struct {
discoURL *url.URL discoURL *url.URL
hostname string
services map[string]interface{} services map[string]interface{}
} }
// ErrServiceNotProvided is returned when the service is not provided.
type ErrServiceNotProvided struct {
hostname string
service string
}
// Error returns a customized error message.
func (e *ErrServiceNotProvided) Error() string {
return fmt.Sprintf("host %s does not provide a %s service", e.hostname, e.service)
}
// ErrVersionNotSupported is returned when the version is not supported.
type ErrVersionNotSupported struct {
hostname string
service string
version string
}
// Error returns a customized error message.
func (e *ErrVersionNotSupported) Error() string {
return fmt.Sprintf("host %s does not support %s version %s", e.hostname, e.service, e.version)
}
// ServiceURL returns the URL associated with the given service identifier, // ServiceURL returns the URL associated with the given service identifier,
// which should be of the form "servicename.vN". // which should be of the form "servicename.vN".
// //
// A non-nil result is always an absolute URL with a scheme of either https // A non-nil result is always an absolute URL with a scheme of either HTTPS
// or http. // or HTTP.
// func (h *Host) ServiceURL(id string) (*url.URL, error) {
// If the requested service is not supported by the host, this method returns parts := strings.SplitN(id, ".", 2)
// a nil URL. if len(parts) != 2 {
// return nil, fmt.Errorf("Invalid service ID format (i.e. service.vN): %s", id)
// If the discovery document entry for the given service is invalid (not a URL), }
// it is treated as absent, also returning a nil URL. service, version := parts[0], parts[1]
func (h Host) ServiceURL(id string) *url.URL {
if h.services == nil { // No services supported for an empty Host.
return nil // no services supported for an empty Host if h == nil || h.services == nil {
return nil, &ErrServiceNotProvided{hostname: "<unknown>", service: service}
} }
urlStr, ok := h.services[id].(string) urlStr, ok := h.services[id].(string)
if !ok { if !ok {
return nil // See if we have a matching service as that would indicate
// the service is supported, but not the requested version.
for serviceID := range h.services {
if strings.HasPrefix(serviceID, service) {
return nil, &ErrVersionNotSupported{
hostname: h.hostname,
service: service,
version: version,
}
}
} }
ret, err := url.Parse(urlStr) // No discovered services match the requested service ID.
return nil, &ErrServiceNotProvided{hostname: h.hostname, service: service}
}
u, err := url.Parse(urlStr)
if err != nil { if err != nil {
return nil return nil, fmt.Errorf("Failed to parse service URL: %v", err)
} }
if !ret.IsAbs() {
ret = h.discoURL.ResolveReference(ret) // make absolute using our discovery doc URL
}
if ret.Scheme != "https" && ret.Scheme != "http" {
return nil
}
if ret.User != nil {
// embedded username/password information is not permitted; credentials
// are handled out of band.
return nil
}
ret.Fragment = "" // fragment part is irrelevant, since we're not a browser
return h.discoURL.ResolveReference(ret) // Make relative URLs absolute using our discovery URL.
if !u.IsAbs() {
u = h.discoURL.ResolveReference(u)
}
if u.Scheme != "https" && u.Scheme != "http" {
return nil, fmt.Errorf("Service URL is using an unsupported scheme: %s", u.Scheme)
}
if u.User != nil {
return nil, fmt.Errorf("Embedded username/password information is not permitted")
}
// Fragment part is irrelevant, since we're not a browser.
u.Fragment = ""
return h.discoURL.ResolveReference(u), nil
} }

View File

@ -2,6 +2,7 @@ package disco
import ( import (
"net/url" "net/url"
"strings"
"testing" "testing"
) )
@ -9,6 +10,7 @@ func TestHostServiceURL(t *testing.T) {
baseURL, _ := url.Parse("https://example.com/disco/foo.json") baseURL, _ := url.Parse("https://example.com/disco/foo.json")
host := Host{ host := Host{
discoURL: baseURL, discoURL: baseURL,
hostname: "test-server",
services: map[string]interface{}{ services: map[string]interface{}{
"absolute.v1": "http://example.net/foo/bar", "absolute.v1": "http://example.net/foo/bar",
"absolutewithport.v1": "http://example.net:8080/foo/bar", "absolutewithport.v1": "http://example.net:8080/foo/bar",
@ -24,22 +26,28 @@ func TestHostServiceURL(t *testing.T) {
tests := []struct { tests := []struct {
ID string ID string
Want string want string
err string
}{ }{
{"absolute.v1", "http://example.net/foo/bar"}, {"absolute.v1", "http://example.net/foo/bar", ""},
{"absolutewithport.v1", "http://example.net:8080/foo/bar"}, {"absolutewithport.v1", "http://example.net:8080/foo/bar", ""},
{"relative.v1", "https://example.com/disco/stu/"}, {"relative.v1", "https://example.com/disco/stu/", ""},
{"rootrelative.v1", "https://example.com/baz"}, {"rootrelative.v1", "https://example.com/baz", ""},
{"protorelative.v1", "https://example.net/"}, {"protorelative.v1", "https://example.net/", ""},
{"withfragment.v1", "http://example.org/"}, {"withfragment.v1", "http://example.org/", ""},
{"querystring.v1", "https://example.net/baz?foo=bar"}, // most callers will disregard query string {"querystring.v1", "https://example.net/baz?foo=bar", ""},
{"nothttp.v1", "<nil>"}, {"nothttp.v1", "<nil>", "unsupported scheme"},
{"invalid.v1", "<nil>"}, {"invalid.v1", "<nil>", "Failed to parse service URL"},
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.ID, func(t *testing.T) { t.Run(test.ID, func(t *testing.T) {
url := host.ServiceURL(test.ID) url, err := host.ServiceURL(test.ID)
if (err != nil || test.err != "") &&
(err == nil || !strings.Contains(err.Error(), test.err)) {
t.Fatalf("unexpected service URL error: %s", err)
}
var got string var got string
if url != nil { if url != nil {
got = url.String() got = url.String()
@ -47,8 +55,8 @@ func TestHostServiceURL(t *testing.T) {
got = "<nil>" got = "<nil>"
} }
if got != test.Want { if got != test.want {
t.Errorf("wrong result\ngot: %s\nwant: %s", got, test.Want) t.Errorf("wrong result\ngot: %s\nwant: %s", got, test.want)
} }
}) })
} }