svchost/disco: allow overriding discovery for a particular hostname
The default network-based discovery is not desirable for all situations, so this mechanism allows callers to provide a services map for a given hostname that was obtained some other way (caller-defined) which will then cause network-based discovery to be skipped and the given map to be returned verbatim.
This commit is contained in:
parent
ddff8bbc00
commit
74180229d0
|
@ -58,6 +58,33 @@ func (d *Disco) SetCredentialsSource(src auth.CredentialsSource) {
|
|||
d.credsSrc = src
|
||||
}
|
||||
|
||||
// ForceHostServices provides a pre-defined set of services for a given
|
||||
// host, which prevents the receiver from attempting network-based discovery
|
||||
// for the given host. Instead, the given services map will be returned
|
||||
// verbatim.
|
||||
//
|
||||
// When providing "forced" services, any relative URLs are resolved against
|
||||
// the initial discovery URL that would have been used for network-based
|
||||
// discovery, yielding the same results as if the given map were published
|
||||
// at the host's default discovery URL, though using absolute URLs is strongly
|
||||
// recommended to make the configured behavior more explicit.
|
||||
func (d *Disco) ForceHostServices(host svchost.Hostname, services map[string]interface{}) {
|
||||
if d.hostCache == nil {
|
||||
d.hostCache = map[svchost.Hostname]Host{}
|
||||
}
|
||||
if services == nil {
|
||||
services = map[string]interface{}{}
|
||||
}
|
||||
d.hostCache[host] = Host{
|
||||
discoURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: string(host),
|
||||
Path: discoPath,
|
||||
},
|
||||
services: services,
|
||||
}
|
||||
}
|
||||
|
||||
// Discover runs the discovery protocol against the given hostname (which must
|
||||
// already have been validated and prepared with svchost.ForComparison) and
|
||||
// returns an object describing the services available at that host.
|
||||
|
|
|
@ -118,6 +118,41 @@ func TestDiscover(t *testing.T) {
|
|||
t.Fatalf("wrong Authorization header\ngot: %s\nwant: %s", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("forced services override", func(t *testing.T) {
|
||||
forced := map[string]interface{}{
|
||||
"thingy.v1": "http://example.net/foo",
|
||||
"wotsit.v2": "/foo",
|
||||
}
|
||||
|
||||
d := NewDisco()
|
||||
d.ForceHostServices(svchost.Hostname("example.com"), forced)
|
||||
|
||||
givenHost := "example.com"
|
||||
host, err := svchost.ForComparison(givenHost)
|
||||
if err != nil {
|
||||
t.Fatalf("test server hostname is invalid: %s", err)
|
||||
}
|
||||
|
||||
discovered := d.Discover(host)
|
||||
{
|
||||
gotURL := discovered.ServiceURL("thingy.v1")
|
||||
if gotURL == nil {
|
||||
t.Fatalf("found no URL for thingy.v1")
|
||||
}
|
||||
if got, want := gotURL.String(), "http://example.net/foo"; got != want {
|
||||
t.Fatalf("wrong result %q; want %q", got, want)
|
||||
}
|
||||
}
|
||||
{
|
||||
gotURL := discovered.ServiceURL("wotsit.v2")
|
||||
if gotURL == nil {
|
||||
t.Fatalf("found no URL for wotsit.v2")
|
||||
}
|
||||
if got, want := gotURL.String(), "https://example.com/foo"; got != want {
|
||||
t.Fatalf("wrong result %q; want %q", got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
t.Run("not JSON", func(t *testing.T) {
|
||||
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
|
||||
|
|
Loading…
Reference in New Issue