From a5a2156584dd867ba47a5a34349fa9168ba81910 Mon Sep 17 00:00:00 2001 From: Sander van Harmelen Date: Mon, 10 Dec 2018 11:06:05 +0100 Subject: [PATCH] core: enhance service discovery This PR improves the error handling so we can provide better feedback about any service discovery errors that occured. Additionally it adds logic to test for specific versions when discovering a service using `service.vN`. This will enable more informational errors which can indicate any version incompatibilities. --- backend/remote/backend.go | 6 +- backend/remote/backend_test.go | 6 +- registry/client.go | 34 +++---- registry/errors.go | 12 +-- svchost/disco/disco.go | 156 ++++++++++++++++++--------------- svchost/disco/disco_test.go | 86 +++++++++++++----- svchost/disco/host.go | 98 +++++++++++++++------ svchost/disco/host_test.go | 34 ++++--- 8 files changed, 265 insertions(+), 167 deletions(-) diff --git a/backend/remote/backend.go b/backend/remote/backend.go index 17bb4c912..97b2a148e 100644 --- a/backend/remote/backend.go +++ b/backend/remote/backend.go @@ -302,9 +302,9 @@ func (b *Remote) discover(hostname string) (*url.URL, error) { if err != nil { return nil, err } - service := b.services.DiscoverServiceURL(host, serviceID) - if service == nil { - return nil, fmt.Errorf("host %s does not provide a remote backend API", host) + service, err := b.services.DiscoverServiceURL(host, serviceID) + if err != nil { + return nil, err } return service, nil } diff --git a/backend/remote/backend_test.go b/backend/remote/backend_test.go index db476a231..8aa888cc0 100644 --- a/backend/remote/backend_test.go +++ b/backend/remote/backend_test.go @@ -56,7 +56,7 @@ func TestRemote_config(t *testing.T) { "prefix": cty.NullVal(cty.String), }), }), - confErr: "Host nonexisting.local does not provide a remote backend API", + confErr: "Failed to request discovery document", }, "with_a_name": { config: cty.ObjectVal(map[string]cty.Value{ @@ -112,8 +112,8 @@ func TestRemote_config(t *testing.T) { // Validate valDiags := b.ValidateConfig(tc.config) - if (valDiags.Err() == nil && tc.valErr != "") || - (valDiags.Err() != nil && !strings.Contains(valDiags.Err().Error(), tc.valErr)) { + if (valDiags.Err() != nil || tc.valErr != "") && + (valDiags.Err() == nil || !strings.Contains(valDiags.Err().Error(), tc.valErr)) { t.Fatalf("%s: unexpected validation result: %v", name, valDiags.Err()) } diff --git a/registry/client.go b/registry/client.go index cdd33dc9e..0b90790d4 100644 --- a/registry/client.go +++ b/registry/client.go @@ -59,15 +59,15 @@ func NewClient(services *disco.Disco, client *http.Client) *Client { } // Discover queries the host, and returns the url for the registry. -func (c *Client) Discover(host svchost.Hostname, serviceID string) *url.URL { - service := c.services.DiscoverServiceURL(host, serviceID) - if service == nil { - return nil +func (c *Client) Discover(host svchost.Hostname, serviceID string) (*url.URL, error) { + service, err := c.services.DiscoverServiceURL(host, serviceID) + if err != nil { + return nil, err } if !strings.HasSuffix(service.Path, "/") { service.Path += "/" } - return service + return service, nil } // ModuleVersions queries the registry for a module, and returns the available versions. @@ -77,9 +77,9 @@ func (c *Client) ModuleVersions(module *regsrc.Module) (*response.ModuleVersions return nil, err } - service := c.Discover(host, modulesServiceID) - if service == nil { - return nil, &errServiceNotProvided{host: host.ForDisplay(), service: "modules"} + service, err := c.Discover(host, modulesServiceID) + if err != nil { + return nil, err } p, err := url.Parse(path.Join(module.Module(), "versions")) @@ -150,9 +150,9 @@ func (c *Client) ModuleLocation(module *regsrc.Module, version string) (string, return "", err } - service := c.Discover(host, modulesServiceID) - if service == nil { - return "", &errServiceNotProvided{host: host.ForDisplay(), service: "modules"} + service, err := c.Discover(host, modulesServiceID) + if err != nil { + return "", err } var p *url.URL @@ -234,9 +234,9 @@ func (c *Client) TerraformProviderVersions(provider *regsrc.TerraformProvider) ( return nil, err } - service := c.Discover(host, providersServiceID) - if service == nil { - return nil, &errServiceNotProvided{host: host.ForDisplay(), service: "providers"} + service, err := c.Discover(host, providersServiceID) + if err != nil { + return nil, err } p, err := url.Parse(path.Join(provider.TerraformProvider(), "versions")) @@ -288,9 +288,9 @@ func (c *Client) TerraformProviderLocation(provider *regsrc.TerraformProvider, v return nil, err } - service := c.Discover(host, providersServiceID) - if service == nil { - return nil, &errServiceNotProvided{host: host.ForDisplay(), service: "providers"} + service, err := c.Discover(host, providersServiceID) + if err != nil { + return nil, err } p, err := url.Parse(path.Join( diff --git a/registry/errors.go b/registry/errors.go index 6d6dc95d4..cdde48221 100644 --- a/registry/errors.go +++ b/registry/errors.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/hashicorp/terraform/registry/regsrc" + "github.com/hashicorp/terraform/svchost/disco" ) type errModuleNotFound struct { @@ -42,15 +43,6 @@ func IsProviderNotFound(err error) bool { // error. This allows callers to recognize this particular error condition // as distinct from operational errors such as poor network connectivity. func IsServiceNotProvided(err error) bool { - _, ok := err.(*errServiceNotProvided) + _, ok := err.(*disco.ErrServiceNotProvided) return ok } - -type errServiceNotProvided struct { - host string - service string -} - -func (e *errServiceNotProvided) Error() string { - return fmt.Sprintf("host %s does not provide %s", e.host, e.service) -} diff --git a/svchost/disco/disco.go b/svchost/disco/disco.go index 7fc49da9c..42a2dc4cd 100644 --- a/svchost/disco/disco.go +++ b/svchost/disco/disco.go @@ -8,6 +8,7 @@ package disco import ( "encoding/json" "errors" + "fmt" "io" "io/ioutil" "log" @@ -22,19 +23,27 @@ import ( ) const ( - discoPath = "/.well-known/terraform.json" - maxRedirects = 3 // arbitrary-but-small number to prevent runaway redirect loops - discoTimeout = 11 * time.Second // arbitrary-but-small time limit to prevent UI "hangs" during discovery - maxDiscoDocBytes = 1 * 1024 * 1024 // 1MB - to prevent abusive services from using loads of our memory + // Fixed path to the discovery manifest. + discoPath = "/.well-known/terraform.json" + + // Arbitrary-but-small number to prevent runaway redirect loops. + maxRedirects = 3 + + // Arbitrary-but-small time limit to prevent UI "hangs" during discovery. + discoTimeout = 11 * time.Second + + // 1MB - to prevent abusive services from using loads of our memory. + maxDiscoDocBytes = 1 * 1024 * 1024 ) -var httpTransport = cleanhttp.DefaultPooledTransport() // overridden during tests, to skip TLS verification +// httpTransport is overridden during tests, to skip TLS verification. +var httpTransport = cleanhttp.DefaultPooledTransport() // Disco is the main type in this package, which allows discovery on given // hostnames and caches the results by hostname to avoid repeated requests // for the same information. type Disco struct { - hostCache map[svchost.Hostname]Host + hostCache map[svchost.Hostname]*Host credsSrc auth.CredentialsSource // Transport is a custom http.RoundTripper to use. @@ -50,7 +59,10 @@ func New() *Disco { // NewWithCredentialsSource returns a new discovery object initialized with // the given credentials source. func NewWithCredentialsSource(credsSrc auth.CredentialsSource) *Disco { - return &Disco{credsSrc: credsSrc} + return &Disco{ + hostCache: make(map[svchost.Hostname]*Host), + credsSrc: credsSrc, + } } // SetCredentialsSource provides a credentials source that will be used to @@ -64,11 +76,11 @@ func (d *Disco) SetCredentialsSource(src auth.CredentialsSource) { // CredentialsForHost returns a non-nil HostCredentials if the embedded source has // credentials available for the host, and a nil HostCredentials if it does not. -func (d *Disco) CredentialsForHost(host svchost.Hostname) (auth.HostCredentials, error) { +func (d *Disco) CredentialsForHost(hostname svchost.Hostname) (auth.HostCredentials, error) { if d.credsSrc == nil { return nil, nil } - return d.credsSrc.ForHost(host) + return d.credsSrc.ForHost(hostname) } // ForceHostServices provides a pre-defined set of services for a given @@ -81,19 +93,17 @@ func (d *Disco) CredentialsForHost(host svchost.Hostname) (auth.HostCredentials, // discovery, yielding the same results as if the given map were published // at the host's default discovery URL, though using absolute URLs is strongly // recommended to make the configured behavior more explicit. -func (d *Disco) ForceHostServices(host svchost.Hostname, services map[string]interface{}) { - if d.hostCache == nil { - d.hostCache = map[svchost.Hostname]Host{} - } +func (d *Disco) ForceHostServices(hostname svchost.Hostname, services map[string]interface{}) { if services == nil { services = map[string]interface{}{} } - d.hostCache[host] = Host{ + d.hostCache[hostname] = &Host{ discoURL: &url.URL{ Scheme: "https", - Host: string(host), + Host: string(hostname), Path: discoPath, }, + hostname: hostname.ForDisplay(), services: services, } } @@ -104,36 +114,40 @@ func (d *Disco) ForceHostServices(host svchost.Hostname, services map[string]int // // If a given hostname supports no Terraform services at all, a non-nil but // empty Host object is returned. When giving feedback to the end user about -// such situations, we say e.g. "the host doesn't provide a module -// registry", regardless of whether that is due to that service specifically -// being absent or due to the host not providing Terraform services at all, -// since we don't wish to expose the detail of whole-host discovery to an -// end-user. -func (d *Disco) Discover(host svchost.Hostname) Host { - if d.hostCache == nil { - d.hostCache = map[svchost.Hostname]Host{} - } - if cache, cached := d.hostCache[host]; cached { - return cache +// such situations, we say "host does not provide a service", +// regardless of whether that is due to that service specifically being absent +// or due to the host not providing Terraform services at all, since we don't +// wish to expose the detail of whole-host discovery to an end-user. +func (d *Disco) Discover(hostname svchost.Hostname) (*Host, error) { + if host, cached := d.hostCache[hostname]; cached { + return host, nil } - ret := d.discover(host) - d.hostCache[host] = ret - return ret + host, err := d.discover(hostname) + if err != nil { + return nil, err + } + d.hostCache[hostname] = host + + return host, nil } // DiscoverServiceURL is a convenience wrapper for discovery on a given // hostname and then looking up a particular service in the result. -func (d *Disco) DiscoverServiceURL(host svchost.Hostname, serviceID string) *url.URL { - return d.Discover(host).ServiceURL(serviceID) +func (d *Disco) DiscoverServiceURL(hostname svchost.Hostname, serviceID string) (*url.URL, error) { + host, err := d.Discover(hostname) + if err != nil { + return nil, err + } + return host.ServiceURL(serviceID) } // discover implements the actual discovery process, with its result cached // by the public-facing Discover method. -func (d *Disco) discover(host svchost.Hostname) Host { +func (d *Disco) discover(hostname svchost.Hostname) (*Host, error) { discoURL := &url.URL{ Scheme: "https", - Host: host.String(), + Host: hostname.String(), Path: discoPath, } @@ -149,7 +163,7 @@ func (d *Disco) discover(host svchost.Hostname) Host { CheckRedirect: func(req *http.Request, via []*http.Request) error { log.Printf("[DEBUG] Service discovery redirected to %s", req.URL) if len(via) > maxRedirects { - return errors.New("too many redirects") // (this error message will never actually be seen) + return errors.New("too many redirects") // this error will never actually be seen } return nil }, @@ -160,82 +174,84 @@ func (d *Disco) discover(host svchost.Hostname) Host { URL: discoURL, } - if creds, err := d.CredentialsForHost(host); err != nil { - log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", host, err) - } else if creds != nil { - creds.PrepareRequest(req) // alters req to include credentials + creds, err := d.CredentialsForHost(hostname) + if err != nil { + log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", hostname, err) + } + if creds != nil { + // Update the request to include credentials. + creds.PrepareRequest(req) } - log.Printf("[DEBUG] Service discovery for %s at %s", host, discoURL) - - ret := Host{ - discoURL: discoURL, - } + log.Printf("[DEBUG] Service discovery for %s at %s", hostname, discoURL) resp, err := client.Do(req) if err != nil { - log.Printf("[WARN] Failed to request discovery document: %s", err) - return ret // empty + return nil, fmt.Errorf("Failed to request discovery document: %v", err) } defer resp.Body.Close() - if resp.StatusCode != 200 { - log.Printf("[WARN] Failed to request discovery document: %s", resp.Status) - return ret // empty + host := &Host{ + // Use the discovery URL from resp.Request in + // case the client followed any redirects. + discoURL: resp.Request.URL, + hostname: hostname.ForDisplay(), } - // If the client followed any redirects, we will have a new URL to use - // as our base for relative resolution. - ret.discoURL = resp.Request.URL + // Return the host without any services. + if resp.StatusCode == 404 { + return host, nil + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("Failed to request discovery document: %s", resp.Status) + } contentType := resp.Header.Get("Content-Type") mediaType, _, err := mime.ParseMediaType(contentType) if err != nil { - log.Printf("[WARN] Discovery URL has malformed Content-Type %q", contentType) - return ret // empty + return nil, fmt.Errorf("Discovery URL has a malformed Content-Type %q", contentType) } if mediaType != "application/json" { - log.Printf("[DEBUG] Discovery URL returned Content-Type %q, rather than application/json", mediaType) - return ret // empty + return nil, fmt.Errorf("Discovery URL returned an unsupported Content-Type %q", mediaType) } - // (this doesn't catch chunked encoding, because ContentLength is -1 in that case...) + // This doesn't catch chunked encoding, because ContentLength is -1 in that case. if resp.ContentLength > maxDiscoDocBytes { // Size limit here is not a contractual requirement and so we may // adjust it over time if we find a different limit is warranted. - log.Printf("[WARN] Discovery doc response is too large (got %d bytes; limit %d)", resp.ContentLength, maxDiscoDocBytes) - return ret // empty + return nil, fmt.Errorf( + "Discovery doc response is too large (got %d bytes; limit %d)", + resp.ContentLength, maxDiscoDocBytes, + ) } - // If the response is using chunked encoding then we can't predict - // its size, but we'll at least prevent reading the entire thing into - // memory. + // If the response is using chunked encoding then we can't predict its + // size, but we'll at least prevent reading the entire thing into memory. lr := io.LimitReader(resp.Body, maxDiscoDocBytes) servicesBytes, err := ioutil.ReadAll(lr) if err != nil { - log.Printf("[WARN] Error reading discovery document body: %s", err) - return ret // empty + return nil, fmt.Errorf("Error reading discovery document body: %v", err) } var services map[string]interface{} err = json.Unmarshal(servicesBytes, &services) if err != nil { - log.Printf("[WARN] Failed to decode discovery document as a JSON object: %s", err) - return ret // empty + return nil, fmt.Errorf("Failed to decode discovery document as a JSON object: %v", err) } + host.services = services - ret.services = services - return ret + return host, nil } // Forget invalidates any cached record of the given hostname. If the host // has no cache entry then this is a no-op. -func (d *Disco) Forget(host svchost.Hostname) { - delete(d.hostCache, host) +func (d *Disco) Forget(hostname svchost.Hostname) { + delete(d.hostCache, hostname) } // ForgetAll is like Forget, but for all of the hostnames that have cache entries. func (d *Disco) ForgetAll() { - d.hostCache = nil + d.hostCache = make(map[svchost.Hostname]*Host) } diff --git a/svchost/disco/disco_test.go b/svchost/disco/disco_test.go index c8bc16c45..95204e6f7 100644 --- a/svchost/disco/disco_test.go +++ b/svchost/disco/disco_test.go @@ -46,8 +46,15 @@ func TestDiscover(t *testing.T) { } d := New() - discovered := d.Discover(host) - gotURL := discovered.ServiceURL("thingy.v1") + discovered, err := d.Discover(host) + if err != nil { + t.Fatalf("unexpected discovery error: %s", err) + } + + gotURL, err := discovered.ServiceURL("thingy.v1") + if err != nil { + t.Fatalf("unexpected service URL error: %s", err) + } if gotURL == nil { t.Fatalf("found no URL for thingy.v1") } @@ -81,8 +88,15 @@ func TestDiscover(t *testing.T) { } d := New() - discovered := d.Discover(host) - gotURL := discovered.ServiceURL("wotsit.v2") + discovered, err := d.Discover(host) + if err != nil { + t.Fatalf("unexpected discovery error: %s", err) + } + + gotURL, err := discovered.ServiceURL("wotsit.v2") + if err != nil { + t.Fatalf("unexpected service URL error: %s", err) + } if gotURL == nil { t.Fatalf("found no URL for wotsit.v2") } @@ -133,9 +147,15 @@ func TestDiscover(t *testing.T) { t.Fatalf("test server hostname is invalid: %s", err) } - discovered := d.Discover(host) + discovered, err := d.Discover(host) + if err != nil { + t.Fatalf("unexpected discovery error: %s", err) + } { - gotURL := discovered.ServiceURL("thingy.v1") + gotURL, err := discovered.ServiceURL("thingy.v1") + if err != nil { + t.Fatalf("unexpected service URL error: %s", err) + } if gotURL == nil { t.Fatalf("found no URL for thingy.v1") } @@ -144,7 +164,10 @@ func TestDiscover(t *testing.T) { } } { - gotURL := discovered.ServiceURL("wotsit.v2") + gotURL, err := discovered.ServiceURL("wotsit.v2") + if err != nil { + t.Fatalf("unexpected service URL error: %s", err) + } if gotURL == nil { t.Fatalf("found no URL for wotsit.v2") } @@ -168,12 +191,14 @@ func TestDiscover(t *testing.T) { } d := New() - discovered := d.Discover(host) + discovered, err := d.Discover(host) + if err == nil { + t.Fatalf("expected a discovery error") + } - // result should be empty, which we can verify only by reaching into - // its internals. - if discovered.services != nil { - t.Errorf("response not empty; should be") + // Returned discovered should be nil. + if discovered != nil { + t.Errorf("discovered not nil; should be") } }) t.Run("malformed JSON", func(t *testing.T) { @@ -191,12 +216,14 @@ func TestDiscover(t *testing.T) { } d := New() - discovered := d.Discover(host) + discovered, err := d.Discover(host) + if err == nil { + t.Fatalf("expected a discovery error") + } - // result should be empty, which we can verify only by reaching into - // its internals. - if discovered.services != nil { - t.Errorf("response not empty; should be") + // Returned discovered should be nil. + if discovered != nil { + t.Errorf("discovered not nil; should be") } }) t.Run("JSON with redundant charset", func(t *testing.T) { @@ -218,7 +245,10 @@ func TestDiscover(t *testing.T) { } d := New() - discovered := d.Discover(host) + discovered, err := d.Discover(host) + if err != nil { + t.Fatalf("unexpected discovery error: %s", err) + } if discovered.services == nil { t.Errorf("response is empty; shouldn't be") @@ -237,12 +267,14 @@ func TestDiscover(t *testing.T) { } d := New() - discovered := d.Discover(host) + discovered, err := d.Discover(host) + if err != nil { + t.Fatalf("unexpected discovery error: %s", err) + } - // result should be empty, which we can verify only by reaching into - // its internals. + // Returned discovered.services should be nil (empty). if discovered.services != nil { - t.Errorf("response not empty; should be") + t.Errorf("discovered.services not nil (empty); should be") } }) t.Run("redirect", func(t *testing.T) { @@ -268,9 +300,15 @@ func TestDiscover(t *testing.T) { } d := New() - discovered := d.Discover(host) + discovered, err := d.Discover(host) + if err != nil { + t.Fatalf("unexpected discovery error: %s", err) + } - gotURL := discovered.ServiceURL("thingy.v1") + gotURL, err := discovered.ServiceURL("thingy.v1") + if err != nil { + t.Fatalf("unexpected service URL error: %s", err) + } if gotURL == nil { t.Fatalf("found no URL for thingy.v1") } diff --git a/svchost/disco/host.go b/svchost/disco/host.go index faf58220a..55cc10813 100644 --- a/svchost/disco/host.go +++ b/svchost/disco/host.go @@ -1,51 +1,95 @@ package disco import ( + "fmt" "net/url" + "strings" ) +// Host represents a service discovered host. type Host struct { discoURL *url.URL + hostname string services map[string]interface{} } +// ErrServiceNotProvided is returned when the service is not provided. +type ErrServiceNotProvided struct { + hostname string + service string +} + +// Error returns a customized error message. +func (e *ErrServiceNotProvided) Error() string { + return fmt.Sprintf("host %s does not provide a %s service", e.hostname, e.service) +} + +// ErrVersionNotSupported is returned when the version is not supported. +type ErrVersionNotSupported struct { + hostname string + service string + version string +} + +// Error returns a customized error message. +func (e *ErrVersionNotSupported) Error() string { + return fmt.Sprintf("host %s does not support %s version %s", e.hostname, e.service, e.version) +} + // ServiceURL returns the URL associated with the given service identifier, // which should be of the form "servicename.vN". // -// A non-nil result is always an absolute URL with a scheme of either https -// or http. -// -// If the requested service is not supported by the host, this method returns -// a nil URL. -// -// If the discovery document entry for the given service is invalid (not a URL), -// it is treated as absent, also returning a nil URL. -func (h Host) ServiceURL(id string) *url.URL { - if h.services == nil { - return nil // no services supported for an empty Host +// A non-nil result is always an absolute URL with a scheme of either HTTPS +// or HTTP. +func (h *Host) ServiceURL(id string) (*url.URL, error) { + parts := strings.SplitN(id, ".", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("Invalid service ID format (i.e. service.vN): %s", id) + } + service, version := parts[0], parts[1] + + // No services supported for an empty Host. + if h == nil || h.services == nil { + return nil, &ErrServiceNotProvided{hostname: "", service: service} } urlStr, ok := h.services[id].(string) if !ok { - return nil + // See if we have a matching service as that would indicate + // the service is supported, but not the requested version. + for serviceID := range h.services { + if strings.HasPrefix(serviceID, service) { + return nil, &ErrVersionNotSupported{ + hostname: h.hostname, + service: service, + version: version, + } + } + } + + // No discovered services match the requested service ID. + return nil, &ErrServiceNotProvided{hostname: h.hostname, service: service} } - ret, err := url.Parse(urlStr) + u, err := url.Parse(urlStr) if err != nil { - return nil + return nil, fmt.Errorf("Failed to parse service URL: %v", err) } - if !ret.IsAbs() { - ret = h.discoURL.ResolveReference(ret) // make absolute using our discovery doc URL - } - if ret.Scheme != "https" && ret.Scheme != "http" { - return nil - } - if ret.User != nil { - // embedded username/password information is not permitted; credentials - // are handled out of band. - return nil - } - ret.Fragment = "" // fragment part is irrelevant, since we're not a browser - return h.discoURL.ResolveReference(ret) + // Make relative URLs absolute using our discovery URL. + if !u.IsAbs() { + u = h.discoURL.ResolveReference(u) + } + + if u.Scheme != "https" && u.Scheme != "http" { + return nil, fmt.Errorf("Service URL is using an unsupported scheme: %s", u.Scheme) + } + if u.User != nil { + return nil, fmt.Errorf("Embedded username/password information is not permitted") + } + + // Fragment part is irrelevant, since we're not a browser. + u.Fragment = "" + + return h.discoURL.ResolveReference(u), nil } diff --git a/svchost/disco/host_test.go b/svchost/disco/host_test.go index 8a9fe4c76..c6a1d8eaf 100644 --- a/svchost/disco/host_test.go +++ b/svchost/disco/host_test.go @@ -2,6 +2,7 @@ package disco import ( "net/url" + "strings" "testing" ) @@ -9,6 +10,7 @@ func TestHostServiceURL(t *testing.T) { baseURL, _ := url.Parse("https://example.com/disco/foo.json") host := Host{ discoURL: baseURL, + hostname: "test-server", services: map[string]interface{}{ "absolute.v1": "http://example.net/foo/bar", "absolutewithport.v1": "http://example.net:8080/foo/bar", @@ -24,22 +26,28 @@ func TestHostServiceURL(t *testing.T) { tests := []struct { ID string - Want string + want string + err string }{ - {"absolute.v1", "http://example.net/foo/bar"}, - {"absolutewithport.v1", "http://example.net:8080/foo/bar"}, - {"relative.v1", "https://example.com/disco/stu/"}, - {"rootrelative.v1", "https://example.com/baz"}, - {"protorelative.v1", "https://example.net/"}, - {"withfragment.v1", "http://example.org/"}, - {"querystring.v1", "https://example.net/baz?foo=bar"}, // most callers will disregard query string - {"nothttp.v1", ""}, - {"invalid.v1", ""}, + {"absolute.v1", "http://example.net/foo/bar", ""}, + {"absolutewithport.v1", "http://example.net:8080/foo/bar", ""}, + {"relative.v1", "https://example.com/disco/stu/", ""}, + {"rootrelative.v1", "https://example.com/baz", ""}, + {"protorelative.v1", "https://example.net/", ""}, + {"withfragment.v1", "http://example.org/", ""}, + {"querystring.v1", "https://example.net/baz?foo=bar", ""}, + {"nothttp.v1", "", "unsupported scheme"}, + {"invalid.v1", "", "Failed to parse service URL"}, } for _, test := range tests { t.Run(test.ID, func(t *testing.T) { - url := host.ServiceURL(test.ID) + url, err := host.ServiceURL(test.ID) + if (err != nil || test.err != "") && + (err == nil || !strings.Contains(err.Error(), test.err)) { + t.Fatalf("unexpected service URL error: %s", err) + } + var got string if url != nil { got = url.String() @@ -47,8 +55,8 @@ func TestHostServiceURL(t *testing.T) { got = "" } - if got != test.Want { - t.Errorf("wrong result\ngot: %s\nwant: %s", got, test.Want) + if got != test.want { + t.Errorf("wrong result\ngot: %s\nwant: %s", got, test.want) } }) }