svchost/disco: allow overriding discovery for a particular hostname

The default network-based discovery is not desirable for all situations,
so this mechanism allows callers to provide a services map for a given
hostname that was obtained some other way (caller-defined) which will then
cause network-based discovery to be skipped and the given map to be
returned verbatim.
This commit is contained in:
Martin Atkins 2017-10-25 15:57:03 -07:00
parent ddff8bbc00
commit 74180229d0
2 changed files with 62 additions and 0 deletions

View File

@ -58,6 +58,33 @@ func (d *Disco) SetCredentialsSource(src auth.CredentialsSource) {
d.credsSrc = src
}
// ForceHostServices provides a pre-defined set of services for a given
// host, which prevents the receiver from attempting network-based discovery
// for the given host. Instead, the given services map will be returned
// verbatim.
//
// When providing "forced" services, any relative URLs are resolved against
// the initial discovery URL that would have been used for network-based
// discovery, yielding the same results as if the given map were published
// at the host's default discovery URL, though using absolute URLs is strongly
// recommended to make the configured behavior more explicit.
func (d *Disco) ForceHostServices(host svchost.Hostname, services map[string]interface{}) {
if d.hostCache == nil {
d.hostCache = map[svchost.Hostname]Host{}
}
if services == nil {
services = map[string]interface{}{}
}
d.hostCache[host] = Host{
discoURL: &url.URL{
Scheme: "https",
Host: string(host),
Path: discoPath,
},
services: services,
}
}
// Discover runs the discovery protocol against the given hostname (which must
// already have been validated and prepared with svchost.ForComparison) and
// returns an object describing the services available at that host.

View File

@ -118,6 +118,41 @@ func TestDiscover(t *testing.T) {
t.Fatalf("wrong Authorization header\ngot: %s\nwant: %s", got, want)
}
})
t.Run("forced services override", func(t *testing.T) {
forced := map[string]interface{}{
"thingy.v1": "http://example.net/foo",
"wotsit.v2": "/foo",
}
d := NewDisco()
d.ForceHostServices(svchost.Hostname("example.com"), forced)
givenHost := "example.com"
host, err := svchost.ForComparison(givenHost)
if err != nil {
t.Fatalf("test server hostname is invalid: %s", err)
}
discovered := d.Discover(host)
{
gotURL := discovered.ServiceURL("thingy.v1")
if gotURL == nil {
t.Fatalf("found no URL for thingy.v1")
}
if got, want := gotURL.String(), "http://example.net/foo"; got != want {
t.Fatalf("wrong result %q; want %q", got, want)
}
}
{
gotURL := discovered.ServiceURL("wotsit.v2")
if gotURL == nil {
t.Fatalf("found no URL for wotsit.v2")
}
if got, want := gotURL.String(), "https://example.com/foo"; got != want {
t.Fatalf("wrong result %q; want %q", got, want)
}
}
})
t.Run("not JSON", func(t *testing.T) {
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)