Merge pull request #19589 from hashicorp/svh/f-service-discovery
core: enhance service discovery
This commit is contained in:
commit
ecc93c9889
|
@ -302,9 +302,9 @@ func (b *Remote) discover(hostname string) (*url.URL, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
service := b.services.DiscoverServiceURL(host, serviceID)
|
service, err := b.services.DiscoverServiceURL(host, serviceID)
|
||||||
if service == nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("host %s does not provide a remote backend API", host)
|
return nil, err
|
||||||
}
|
}
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,7 +56,7 @@ func TestRemote_config(t *testing.T) {
|
||||||
"prefix": cty.NullVal(cty.String),
|
"prefix": cty.NullVal(cty.String),
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
confErr: "Host nonexisting.local does not provide a remote backend API",
|
confErr: "Failed to request discovery document",
|
||||||
},
|
},
|
||||||
"with_a_name": {
|
"with_a_name": {
|
||||||
config: cty.ObjectVal(map[string]cty.Value{
|
config: cty.ObjectVal(map[string]cty.Value{
|
||||||
|
@ -112,8 +112,8 @@ func TestRemote_config(t *testing.T) {
|
||||||
|
|
||||||
// Validate
|
// Validate
|
||||||
valDiags := b.ValidateConfig(tc.config)
|
valDiags := b.ValidateConfig(tc.config)
|
||||||
if (valDiags.Err() == nil && tc.valErr != "") ||
|
if (valDiags.Err() != nil || tc.valErr != "") &&
|
||||||
(valDiags.Err() != nil && !strings.Contains(valDiags.Err().Error(), tc.valErr)) {
|
(valDiags.Err() == nil || !strings.Contains(valDiags.Err().Error(), tc.valErr)) {
|
||||||
t.Fatalf("%s: unexpected validation result: %v", name, valDiags.Err())
|
t.Fatalf("%s: unexpected validation result: %v", name, valDiags.Err())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -59,15 +59,15 @@ func NewClient(services *disco.Disco, client *http.Client) *Client {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Discover queries the host, and returns the url for the registry.
|
// Discover queries the host, and returns the url for the registry.
|
||||||
func (c *Client) Discover(host svchost.Hostname, serviceID string) *url.URL {
|
func (c *Client) Discover(host svchost.Hostname, serviceID string) (*url.URL, error) {
|
||||||
service := c.services.DiscoverServiceURL(host, serviceID)
|
service, err := c.services.DiscoverServiceURL(host, serviceID)
|
||||||
if service == nil {
|
if err != nil {
|
||||||
return nil
|
return nil, err
|
||||||
}
|
}
|
||||||
if !strings.HasSuffix(service.Path, "/") {
|
if !strings.HasSuffix(service.Path, "/") {
|
||||||
service.Path += "/"
|
service.Path += "/"
|
||||||
}
|
}
|
||||||
return service
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModuleVersions queries the registry for a module, and returns the available versions.
|
// ModuleVersions queries the registry for a module, and returns the available versions.
|
||||||
|
@ -77,9 +77,9 @@ func (c *Client) ModuleVersions(module *regsrc.Module) (*response.ModuleVersions
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
service := c.Discover(host, modulesServiceID)
|
service, err := c.Discover(host, modulesServiceID)
|
||||||
if service == nil {
|
if err != nil {
|
||||||
return nil, &errServiceNotProvided{host: host.ForDisplay(), service: "modules"}
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := url.Parse(path.Join(module.Module(), "versions"))
|
p, err := url.Parse(path.Join(module.Module(), "versions"))
|
||||||
|
@ -150,9 +150,9 @@ func (c *Client) ModuleLocation(module *regsrc.Module, version string) (string,
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
service := c.Discover(host, modulesServiceID)
|
service, err := c.Discover(host, modulesServiceID)
|
||||||
if service == nil {
|
if err != nil {
|
||||||
return "", &errServiceNotProvided{host: host.ForDisplay(), service: "modules"}
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
var p *url.URL
|
var p *url.URL
|
||||||
|
@ -234,9 +234,9 @@ func (c *Client) TerraformProviderVersions(provider *regsrc.TerraformProvider) (
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
service := c.Discover(host, providersServiceID)
|
service, err := c.Discover(host, providersServiceID)
|
||||||
if service == nil {
|
if err != nil {
|
||||||
return nil, &errServiceNotProvided{host: host.ForDisplay(), service: "providers"}
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := url.Parse(path.Join(provider.TerraformProvider(), "versions"))
|
p, err := url.Parse(path.Join(provider.TerraformProvider(), "versions"))
|
||||||
|
@ -288,9 +288,9 @@ func (c *Client) TerraformProviderLocation(provider *regsrc.TerraformProvider, v
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
service := c.Discover(host, providersServiceID)
|
service, err := c.Discover(host, providersServiceID)
|
||||||
if service == nil {
|
if err != nil {
|
||||||
return nil, &errServiceNotProvided{host: host.ForDisplay(), service: "providers"}
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := url.Parse(path.Join(
|
p, err := url.Parse(path.Join(
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/hashicorp/terraform/registry/regsrc"
|
"github.com/hashicorp/terraform/registry/regsrc"
|
||||||
|
"github.com/hashicorp/terraform/svchost/disco"
|
||||||
)
|
)
|
||||||
|
|
||||||
type errModuleNotFound struct {
|
type errModuleNotFound struct {
|
||||||
|
@ -42,15 +43,6 @@ func IsProviderNotFound(err error) bool {
|
||||||
// error. This allows callers to recognize this particular error condition
|
// error. This allows callers to recognize this particular error condition
|
||||||
// as distinct from operational errors such as poor network connectivity.
|
// as distinct from operational errors such as poor network connectivity.
|
||||||
func IsServiceNotProvided(err error) bool {
|
func IsServiceNotProvided(err error) bool {
|
||||||
_, ok := err.(*errServiceNotProvided)
|
_, ok := err.(*disco.ErrServiceNotProvided)
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
type errServiceNotProvided struct {
|
|
||||||
host string
|
|
||||||
service string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *errServiceNotProvided) Error() string {
|
|
||||||
return fmt.Sprintf("host %s does not provide %s", e.host, e.service)
|
|
||||||
}
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ package disco
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
|
@ -22,19 +23,27 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// Fixed path to the discovery manifest.
|
||||||
discoPath = "/.well-known/terraform.json"
|
discoPath = "/.well-known/terraform.json"
|
||||||
maxRedirects = 3 // arbitrary-but-small number to prevent runaway redirect loops
|
|
||||||
discoTimeout = 11 * time.Second // arbitrary-but-small time limit to prevent UI "hangs" during discovery
|
// Arbitrary-but-small number to prevent runaway redirect loops.
|
||||||
maxDiscoDocBytes = 1 * 1024 * 1024 // 1MB - to prevent abusive services from using loads of our memory
|
maxRedirects = 3
|
||||||
|
|
||||||
|
// Arbitrary-but-small time limit to prevent UI "hangs" during discovery.
|
||||||
|
discoTimeout = 11 * time.Second
|
||||||
|
|
||||||
|
// 1MB - to prevent abusive services from using loads of our memory.
|
||||||
|
maxDiscoDocBytes = 1 * 1024 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
var httpTransport = cleanhttp.DefaultPooledTransport() // overridden during tests, to skip TLS verification
|
// httpTransport is overridden during tests, to skip TLS verification.
|
||||||
|
var httpTransport = cleanhttp.DefaultPooledTransport()
|
||||||
|
|
||||||
// Disco is the main type in this package, which allows discovery on given
|
// Disco is the main type in this package, which allows discovery on given
|
||||||
// hostnames and caches the results by hostname to avoid repeated requests
|
// hostnames and caches the results by hostname to avoid repeated requests
|
||||||
// for the same information.
|
// for the same information.
|
||||||
type Disco struct {
|
type Disco struct {
|
||||||
hostCache map[svchost.Hostname]Host
|
hostCache map[svchost.Hostname]*Host
|
||||||
credsSrc auth.CredentialsSource
|
credsSrc auth.CredentialsSource
|
||||||
|
|
||||||
// Transport is a custom http.RoundTripper to use.
|
// Transport is a custom http.RoundTripper to use.
|
||||||
|
@ -50,7 +59,10 @@ func New() *Disco {
|
||||||
// NewWithCredentialsSource returns a new discovery object initialized with
|
// NewWithCredentialsSource returns a new discovery object initialized with
|
||||||
// the given credentials source.
|
// the given credentials source.
|
||||||
func NewWithCredentialsSource(credsSrc auth.CredentialsSource) *Disco {
|
func NewWithCredentialsSource(credsSrc auth.CredentialsSource) *Disco {
|
||||||
return &Disco{credsSrc: credsSrc}
|
return &Disco{
|
||||||
|
hostCache: make(map[svchost.Hostname]*Host),
|
||||||
|
credsSrc: credsSrc,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetCredentialsSource provides a credentials source that will be used to
|
// SetCredentialsSource provides a credentials source that will be used to
|
||||||
|
@ -64,11 +76,11 @@ func (d *Disco) SetCredentialsSource(src auth.CredentialsSource) {
|
||||||
|
|
||||||
// CredentialsForHost returns a non-nil HostCredentials if the embedded source has
|
// CredentialsForHost returns a non-nil HostCredentials if the embedded source has
|
||||||
// credentials available for the host, and a nil HostCredentials if it does not.
|
// credentials available for the host, and a nil HostCredentials if it does not.
|
||||||
func (d *Disco) CredentialsForHost(host svchost.Hostname) (auth.HostCredentials, error) {
|
func (d *Disco) CredentialsForHost(hostname svchost.Hostname) (auth.HostCredentials, error) {
|
||||||
if d.credsSrc == nil {
|
if d.credsSrc == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return d.credsSrc.ForHost(host)
|
return d.credsSrc.ForHost(hostname)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForceHostServices provides a pre-defined set of services for a given
|
// ForceHostServices provides a pre-defined set of services for a given
|
||||||
|
@ -81,19 +93,17 @@ func (d *Disco) CredentialsForHost(host svchost.Hostname) (auth.HostCredentials,
|
||||||
// discovery, yielding the same results as if the given map were published
|
// discovery, yielding the same results as if the given map were published
|
||||||
// at the host's default discovery URL, though using absolute URLs is strongly
|
// at the host's default discovery URL, though using absolute URLs is strongly
|
||||||
// recommended to make the configured behavior more explicit.
|
// recommended to make the configured behavior more explicit.
|
||||||
func (d *Disco) ForceHostServices(host svchost.Hostname, services map[string]interface{}) {
|
func (d *Disco) ForceHostServices(hostname svchost.Hostname, services map[string]interface{}) {
|
||||||
if d.hostCache == nil {
|
|
||||||
d.hostCache = map[svchost.Hostname]Host{}
|
|
||||||
}
|
|
||||||
if services == nil {
|
if services == nil {
|
||||||
services = map[string]interface{}{}
|
services = map[string]interface{}{}
|
||||||
}
|
}
|
||||||
d.hostCache[host] = Host{
|
d.hostCache[hostname] = &Host{
|
||||||
discoURL: &url.URL{
|
discoURL: &url.URL{
|
||||||
Scheme: "https",
|
Scheme: "https",
|
||||||
Host: string(host),
|
Host: string(hostname),
|
||||||
Path: discoPath,
|
Path: discoPath,
|
||||||
},
|
},
|
||||||
|
hostname: hostname.ForDisplay(),
|
||||||
services: services,
|
services: services,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -104,36 +114,40 @@ func (d *Disco) ForceHostServices(host svchost.Hostname, services map[string]int
|
||||||
//
|
//
|
||||||
// If a given hostname supports no Terraform services at all, a non-nil but
|
// If a given hostname supports no Terraform services at all, a non-nil but
|
||||||
// empty Host object is returned. When giving feedback to the end user about
|
// empty Host object is returned. When giving feedback to the end user about
|
||||||
// such situations, we say e.g. "the host <name> doesn't provide a module
|
// such situations, we say "host <name> does not provide a <service> service",
|
||||||
// registry", regardless of whether that is due to that service specifically
|
// regardless of whether that is due to that service specifically being absent
|
||||||
// being absent or due to the host not providing Terraform services at all,
|
// or due to the host not providing Terraform services at all, since we don't
|
||||||
// since we don't wish to expose the detail of whole-host discovery to an
|
// wish to expose the detail of whole-host discovery to an end-user.
|
||||||
// end-user.
|
func (d *Disco) Discover(hostname svchost.Hostname) (*Host, error) {
|
||||||
func (d *Disco) Discover(host svchost.Hostname) Host {
|
if host, cached := d.hostCache[hostname]; cached {
|
||||||
if d.hostCache == nil {
|
return host, nil
|
||||||
d.hostCache = map[svchost.Hostname]Host{}
|
|
||||||
}
|
|
||||||
if cache, cached := d.hostCache[host]; cached {
|
|
||||||
return cache
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ret := d.discover(host)
|
host, err := d.discover(hostname)
|
||||||
d.hostCache[host] = ret
|
if err != nil {
|
||||||
return ret
|
return nil, err
|
||||||
|
}
|
||||||
|
d.hostCache[hostname] = host
|
||||||
|
|
||||||
|
return host, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DiscoverServiceURL is a convenience wrapper for discovery on a given
|
// DiscoverServiceURL is a convenience wrapper for discovery on a given
|
||||||
// hostname and then looking up a particular service in the result.
|
// hostname and then looking up a particular service in the result.
|
||||||
func (d *Disco) DiscoverServiceURL(host svchost.Hostname, serviceID string) *url.URL {
|
func (d *Disco) DiscoverServiceURL(hostname svchost.Hostname, serviceID string) (*url.URL, error) {
|
||||||
return d.Discover(host).ServiceURL(serviceID)
|
host, err := d.Discover(hostname)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return host.ServiceURL(serviceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// discover implements the actual discovery process, with its result cached
|
// discover implements the actual discovery process, with its result cached
|
||||||
// by the public-facing Discover method.
|
// by the public-facing Discover method.
|
||||||
func (d *Disco) discover(host svchost.Hostname) Host {
|
func (d *Disco) discover(hostname svchost.Hostname) (*Host, error) {
|
||||||
discoURL := &url.URL{
|
discoURL := &url.URL{
|
||||||
Scheme: "https",
|
Scheme: "https",
|
||||||
Host: host.String(),
|
Host: hostname.String(),
|
||||||
Path: discoPath,
|
Path: discoPath,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -149,7 +163,7 @@ func (d *Disco) discover(host svchost.Hostname) Host {
|
||||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
log.Printf("[DEBUG] Service discovery redirected to %s", req.URL)
|
log.Printf("[DEBUG] Service discovery redirected to %s", req.URL)
|
||||||
if len(via) > maxRedirects {
|
if len(via) > maxRedirects {
|
||||||
return errors.New("too many redirects") // (this error message will never actually be seen)
|
return errors.New("too many redirects") // this error will never actually be seen
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
@ -160,82 +174,84 @@ func (d *Disco) discover(host svchost.Hostname) Host {
|
||||||
URL: discoURL,
|
URL: discoURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
if creds, err := d.CredentialsForHost(host); err != nil {
|
creds, err := d.CredentialsForHost(hostname)
|
||||||
log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", host, err)
|
if err != nil {
|
||||||
} else if creds != nil {
|
log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", hostname, err)
|
||||||
creds.PrepareRequest(req) // alters req to include credentials
|
}
|
||||||
|
if creds != nil {
|
||||||
|
// Update the request to include credentials.
|
||||||
|
creds.PrepareRequest(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("[DEBUG] Service discovery for %s at %s", host, discoURL)
|
log.Printf("[DEBUG] Service discovery for %s at %s", hostname, discoURL)
|
||||||
|
|
||||||
ret := Host{
|
|
||||||
discoURL: discoURL,
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[WARN] Failed to request discovery document: %s", err)
|
return nil, fmt.Errorf("Failed to request discovery document: %v", err)
|
||||||
return ret // empty
|
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
host := &Host{
|
||||||
log.Printf("[WARN] Failed to request discovery document: %s", resp.Status)
|
// Use the discovery URL from resp.Request in
|
||||||
return ret // empty
|
// case the client followed any redirects.
|
||||||
|
discoURL: resp.Request.URL,
|
||||||
|
hostname: hostname.ForDisplay(),
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the client followed any redirects, we will have a new URL to use
|
// Return the host without any services.
|
||||||
// as our base for relative resolution.
|
if resp.StatusCode == 404 {
|
||||||
ret.discoURL = resp.Request.URL
|
return host, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
return nil, fmt.Errorf("Failed to request discovery document: %s", resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
contentType := resp.Header.Get("Content-Type")
|
contentType := resp.Header.Get("Content-Type")
|
||||||
mediaType, _, err := mime.ParseMediaType(contentType)
|
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[WARN] Discovery URL has malformed Content-Type %q", contentType)
|
return nil, fmt.Errorf("Discovery URL has a malformed Content-Type %q", contentType)
|
||||||
return ret // empty
|
|
||||||
}
|
}
|
||||||
if mediaType != "application/json" {
|
if mediaType != "application/json" {
|
||||||
log.Printf("[DEBUG] Discovery URL returned Content-Type %q, rather than application/json", mediaType)
|
return nil, fmt.Errorf("Discovery URL returned an unsupported Content-Type %q", mediaType)
|
||||||
return ret // empty
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// (this doesn't catch chunked encoding, because ContentLength is -1 in that case...)
|
// This doesn't catch chunked encoding, because ContentLength is -1 in that case.
|
||||||
if resp.ContentLength > maxDiscoDocBytes {
|
if resp.ContentLength > maxDiscoDocBytes {
|
||||||
// Size limit here is not a contractual requirement and so we may
|
// Size limit here is not a contractual requirement and so we may
|
||||||
// adjust it over time if we find a different limit is warranted.
|
// adjust it over time if we find a different limit is warranted.
|
||||||
log.Printf("[WARN] Discovery doc response is too large (got %d bytes; limit %d)", resp.ContentLength, maxDiscoDocBytes)
|
return nil, fmt.Errorf(
|
||||||
return ret // empty
|
"Discovery doc response is too large (got %d bytes; limit %d)",
|
||||||
|
resp.ContentLength, maxDiscoDocBytes,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the response is using chunked encoding then we can't predict
|
// If the response is using chunked encoding then we can't predict its
|
||||||
// its size, but we'll at least prevent reading the entire thing into
|
// size, but we'll at least prevent reading the entire thing into memory.
|
||||||
// memory.
|
|
||||||
lr := io.LimitReader(resp.Body, maxDiscoDocBytes)
|
lr := io.LimitReader(resp.Body, maxDiscoDocBytes)
|
||||||
|
|
||||||
servicesBytes, err := ioutil.ReadAll(lr)
|
servicesBytes, err := ioutil.ReadAll(lr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[WARN] Error reading discovery document body: %s", err)
|
return nil, fmt.Errorf("Error reading discovery document body: %v", err)
|
||||||
return ret // empty
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var services map[string]interface{}
|
var services map[string]interface{}
|
||||||
err = json.Unmarshal(servicesBytes, &services)
|
err = json.Unmarshal(servicesBytes, &services)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[WARN] Failed to decode discovery document as a JSON object: %s", err)
|
return nil, fmt.Errorf("Failed to decode discovery document as a JSON object: %v", err)
|
||||||
return ret // empty
|
|
||||||
}
|
}
|
||||||
|
host.services = services
|
||||||
|
|
||||||
ret.services = services
|
return host, nil
|
||||||
return ret
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forget invalidates any cached record of the given hostname. If the host
|
// Forget invalidates any cached record of the given hostname. If the host
|
||||||
// has no cache entry then this is a no-op.
|
// has no cache entry then this is a no-op.
|
||||||
func (d *Disco) Forget(host svchost.Hostname) {
|
func (d *Disco) Forget(hostname svchost.Hostname) {
|
||||||
delete(d.hostCache, host)
|
delete(d.hostCache, hostname)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForgetAll is like Forget, but for all of the hostnames that have cache entries.
|
// ForgetAll is like Forget, but for all of the hostnames that have cache entries.
|
||||||
func (d *Disco) ForgetAll() {
|
func (d *Disco) ForgetAll() {
|
||||||
d.hostCache = nil
|
d.hostCache = make(map[svchost.Hostname]*Host)
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,8 +46,15 @@ func TestDiscover(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
d := New()
|
d := New()
|
||||||
discovered := d.Discover(host)
|
discovered, err := d.Discover(host)
|
||||||
gotURL := discovered.ServiceURL("thingy.v1")
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected discovery error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotURL, err := discovered.ServiceURL("thingy.v1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected service URL error: %s", err)
|
||||||
|
}
|
||||||
if gotURL == nil {
|
if gotURL == nil {
|
||||||
t.Fatalf("found no URL for thingy.v1")
|
t.Fatalf("found no URL for thingy.v1")
|
||||||
}
|
}
|
||||||
|
@ -81,8 +88,15 @@ func TestDiscover(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
d := New()
|
d := New()
|
||||||
discovered := d.Discover(host)
|
discovered, err := d.Discover(host)
|
||||||
gotURL := discovered.ServiceURL("wotsit.v2")
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected discovery error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotURL, err := discovered.ServiceURL("wotsit.v2")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected service URL error: %s", err)
|
||||||
|
}
|
||||||
if gotURL == nil {
|
if gotURL == nil {
|
||||||
t.Fatalf("found no URL for wotsit.v2")
|
t.Fatalf("found no URL for wotsit.v2")
|
||||||
}
|
}
|
||||||
|
@ -133,9 +147,15 @@ func TestDiscover(t *testing.T) {
|
||||||
t.Fatalf("test server hostname is invalid: %s", err)
|
t.Fatalf("test server hostname is invalid: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
discovered := d.Discover(host)
|
discovered, err := d.Discover(host)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected discovery error: %s", err)
|
||||||
|
}
|
||||||
{
|
{
|
||||||
gotURL := discovered.ServiceURL("thingy.v1")
|
gotURL, err := discovered.ServiceURL("thingy.v1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected service URL error: %s", err)
|
||||||
|
}
|
||||||
if gotURL == nil {
|
if gotURL == nil {
|
||||||
t.Fatalf("found no URL for thingy.v1")
|
t.Fatalf("found no URL for thingy.v1")
|
||||||
}
|
}
|
||||||
|
@ -144,7 +164,10 @@ func TestDiscover(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
gotURL := discovered.ServiceURL("wotsit.v2")
|
gotURL, err := discovered.ServiceURL("wotsit.v2")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected service URL error: %s", err)
|
||||||
|
}
|
||||||
if gotURL == nil {
|
if gotURL == nil {
|
||||||
t.Fatalf("found no URL for wotsit.v2")
|
t.Fatalf("found no URL for wotsit.v2")
|
||||||
}
|
}
|
||||||
|
@ -168,12 +191,14 @@ func TestDiscover(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
d := New()
|
d := New()
|
||||||
discovered := d.Discover(host)
|
discovered, err := d.Discover(host)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected a discovery error")
|
||||||
|
}
|
||||||
|
|
||||||
// result should be empty, which we can verify only by reaching into
|
// Returned discovered should be nil.
|
||||||
// its internals.
|
if discovered != nil {
|
||||||
if discovered.services != nil {
|
t.Errorf("discovered not nil; should be")
|
||||||
t.Errorf("response not empty; should be")
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("malformed JSON", func(t *testing.T) {
|
t.Run("malformed JSON", func(t *testing.T) {
|
||||||
|
@ -191,12 +216,14 @@ func TestDiscover(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
d := New()
|
d := New()
|
||||||
discovered := d.Discover(host)
|
discovered, err := d.Discover(host)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected a discovery error")
|
||||||
|
}
|
||||||
|
|
||||||
// result should be empty, which we can verify only by reaching into
|
// Returned discovered should be nil.
|
||||||
// its internals.
|
if discovered != nil {
|
||||||
if discovered.services != nil {
|
t.Errorf("discovered not nil; should be")
|
||||||
t.Errorf("response not empty; should be")
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("JSON with redundant charset", func(t *testing.T) {
|
t.Run("JSON with redundant charset", func(t *testing.T) {
|
||||||
|
@ -218,7 +245,10 @@ func TestDiscover(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
d := New()
|
d := New()
|
||||||
discovered := d.Discover(host)
|
discovered, err := d.Discover(host)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected discovery error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
if discovered.services == nil {
|
if discovered.services == nil {
|
||||||
t.Errorf("response is empty; shouldn't be")
|
t.Errorf("response is empty; shouldn't be")
|
||||||
|
@ -237,12 +267,14 @@ func TestDiscover(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
d := New()
|
d := New()
|
||||||
discovered := d.Discover(host)
|
discovered, err := d.Discover(host)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected discovery error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
// result should be empty, which we can verify only by reaching into
|
// Returned discovered.services should be nil (empty).
|
||||||
// its internals.
|
|
||||||
if discovered.services != nil {
|
if discovered.services != nil {
|
||||||
t.Errorf("response not empty; should be")
|
t.Errorf("discovered.services not nil (empty); should be")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("redirect", func(t *testing.T) {
|
t.Run("redirect", func(t *testing.T) {
|
||||||
|
@ -268,9 +300,15 @@ func TestDiscover(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
d := New()
|
d := New()
|
||||||
discovered := d.Discover(host)
|
discovered, err := d.Discover(host)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected discovery error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
gotURL := discovered.ServiceURL("thingy.v1")
|
gotURL, err := discovered.ServiceURL("thingy.v1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected service URL error: %s", err)
|
||||||
|
}
|
||||||
if gotURL == nil {
|
if gotURL == nil {
|
||||||
t.Fatalf("found no URL for thingy.v1")
|
t.Fatalf("found no URL for thingy.v1")
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,51 +1,95 @@
|
||||||
package disco
|
package disco
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Host represents a service discovered host.
|
||||||
type Host struct {
|
type Host struct {
|
||||||
discoURL *url.URL
|
discoURL *url.URL
|
||||||
|
hostname string
|
||||||
services map[string]interface{}
|
services map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrServiceNotProvided is returned when the service is not provided.
|
||||||
|
type ErrServiceNotProvided struct {
|
||||||
|
hostname string
|
||||||
|
service string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error returns a customized error message.
|
||||||
|
func (e *ErrServiceNotProvided) Error() string {
|
||||||
|
return fmt.Sprintf("host %s does not provide a %s service", e.hostname, e.service)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrVersionNotSupported is returned when the version is not supported.
|
||||||
|
type ErrVersionNotSupported struct {
|
||||||
|
hostname string
|
||||||
|
service string
|
||||||
|
version string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error returns a customized error message.
|
||||||
|
func (e *ErrVersionNotSupported) Error() string {
|
||||||
|
return fmt.Sprintf("host %s does not support %s version %s", e.hostname, e.service, e.version)
|
||||||
|
}
|
||||||
|
|
||||||
// ServiceURL returns the URL associated with the given service identifier,
|
// ServiceURL returns the URL associated with the given service identifier,
|
||||||
// which should be of the form "servicename.vN".
|
// which should be of the form "servicename.vN".
|
||||||
//
|
//
|
||||||
// A non-nil result is always an absolute URL with a scheme of either https
|
// A non-nil result is always an absolute URL with a scheme of either HTTPS
|
||||||
// or http.
|
// or HTTP.
|
||||||
//
|
func (h *Host) ServiceURL(id string) (*url.URL, error) {
|
||||||
// If the requested service is not supported by the host, this method returns
|
parts := strings.SplitN(id, ".", 2)
|
||||||
// a nil URL.
|
if len(parts) != 2 {
|
||||||
//
|
return nil, fmt.Errorf("Invalid service ID format (i.e. service.vN): %s", id)
|
||||||
// If the discovery document entry for the given service is invalid (not a URL),
|
}
|
||||||
// it is treated as absent, also returning a nil URL.
|
service, version := parts[0], parts[1]
|
||||||
func (h Host) ServiceURL(id string) *url.URL {
|
|
||||||
if h.services == nil {
|
// No services supported for an empty Host.
|
||||||
return nil // no services supported for an empty Host
|
if h == nil || h.services == nil {
|
||||||
|
return nil, &ErrServiceNotProvided{hostname: "<unknown>", service: service}
|
||||||
}
|
}
|
||||||
|
|
||||||
urlStr, ok := h.services[id].(string)
|
urlStr, ok := h.services[id].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
// See if we have a matching service as that would indicate
|
||||||
|
// the service is supported, but not the requested version.
|
||||||
|
for serviceID := range h.services {
|
||||||
|
if strings.HasPrefix(serviceID, service) {
|
||||||
|
return nil, &ErrVersionNotSupported{
|
||||||
|
hostname: h.hostname,
|
||||||
|
service: service,
|
||||||
|
version: version,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ret, err := url.Parse(urlStr)
|
// No discovered services match the requested service ID.
|
||||||
|
return nil, &ErrServiceNotProvided{hostname: h.hostname, service: service}
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(urlStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil, fmt.Errorf("Failed to parse service URL: %v", err)
|
||||||
}
|
}
|
||||||
if !ret.IsAbs() {
|
|
||||||
ret = h.discoURL.ResolveReference(ret) // make absolute using our discovery doc URL
|
|
||||||
}
|
|
||||||
if ret.Scheme != "https" && ret.Scheme != "http" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if ret.User != nil {
|
|
||||||
// embedded username/password information is not permitted; credentials
|
|
||||||
// are handled out of band.
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
ret.Fragment = "" // fragment part is irrelevant, since we're not a browser
|
|
||||||
|
|
||||||
return h.discoURL.ResolveReference(ret)
|
// Make relative URLs absolute using our discovery URL.
|
||||||
|
if !u.IsAbs() {
|
||||||
|
u = h.discoURL.ResolveReference(u)
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.Scheme != "https" && u.Scheme != "http" {
|
||||||
|
return nil, fmt.Errorf("Service URL is using an unsupported scheme: %s", u.Scheme)
|
||||||
|
}
|
||||||
|
if u.User != nil {
|
||||||
|
return nil, fmt.Errorf("Embedded username/password information is not permitted")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fragment part is irrelevant, since we're not a browser.
|
||||||
|
u.Fragment = ""
|
||||||
|
|
||||||
|
return h.discoURL.ResolveReference(u), nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package disco
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -9,6 +10,7 @@ func TestHostServiceURL(t *testing.T) {
|
||||||
baseURL, _ := url.Parse("https://example.com/disco/foo.json")
|
baseURL, _ := url.Parse("https://example.com/disco/foo.json")
|
||||||
host := Host{
|
host := Host{
|
||||||
discoURL: baseURL,
|
discoURL: baseURL,
|
||||||
|
hostname: "test-server",
|
||||||
services: map[string]interface{}{
|
services: map[string]interface{}{
|
||||||
"absolute.v1": "http://example.net/foo/bar",
|
"absolute.v1": "http://example.net/foo/bar",
|
||||||
"absolutewithport.v1": "http://example.net:8080/foo/bar",
|
"absolutewithport.v1": "http://example.net:8080/foo/bar",
|
||||||
|
@ -24,22 +26,28 @@ func TestHostServiceURL(t *testing.T) {
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
ID string
|
ID string
|
||||||
Want string
|
want string
|
||||||
|
err string
|
||||||
}{
|
}{
|
||||||
{"absolute.v1", "http://example.net/foo/bar"},
|
{"absolute.v1", "http://example.net/foo/bar", ""},
|
||||||
{"absolutewithport.v1", "http://example.net:8080/foo/bar"},
|
{"absolutewithport.v1", "http://example.net:8080/foo/bar", ""},
|
||||||
{"relative.v1", "https://example.com/disco/stu/"},
|
{"relative.v1", "https://example.com/disco/stu/", ""},
|
||||||
{"rootrelative.v1", "https://example.com/baz"},
|
{"rootrelative.v1", "https://example.com/baz", ""},
|
||||||
{"protorelative.v1", "https://example.net/"},
|
{"protorelative.v1", "https://example.net/", ""},
|
||||||
{"withfragment.v1", "http://example.org/"},
|
{"withfragment.v1", "http://example.org/", ""},
|
||||||
{"querystring.v1", "https://example.net/baz?foo=bar"}, // most callers will disregard query string
|
{"querystring.v1", "https://example.net/baz?foo=bar", ""},
|
||||||
{"nothttp.v1", "<nil>"},
|
{"nothttp.v1", "<nil>", "unsupported scheme"},
|
||||||
{"invalid.v1", "<nil>"},
|
{"invalid.v1", "<nil>", "Failed to parse service URL"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.ID, func(t *testing.T) {
|
t.Run(test.ID, func(t *testing.T) {
|
||||||
url := host.ServiceURL(test.ID)
|
url, err := host.ServiceURL(test.ID)
|
||||||
|
if (err != nil || test.err != "") &&
|
||||||
|
(err == nil || !strings.Contains(err.Error(), test.err)) {
|
||||||
|
t.Fatalf("unexpected service URL error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
var got string
|
var got string
|
||||||
if url != nil {
|
if url != nil {
|
||||||
got = url.String()
|
got = url.String()
|
||||||
|
@ -47,8 +55,8 @@ func TestHostServiceURL(t *testing.T) {
|
||||||
got = "<nil>"
|
got = "<nil>"
|
||||||
}
|
}
|
||||||
|
|
||||||
if got != test.Want {
|
if got != test.want {
|
||||||
t.Errorf("wrong result\ngot: %s\nwant: %s", got, test.Want)
|
t.Errorf("wrong result\ngot: %s\nwant: %s", got, test.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue