diff --git a/internal/getproviders/registry_client.go b/internal/getproviders/registry_client.go index fd96c65de..dbaf7e75d 100644 --- a/internal/getproviders/registry_client.go +++ b/internal/getproviders/registry_client.go @@ -201,6 +201,15 @@ func (c *registryClient) PackageMeta(provider addrs.Provider, version Version, t } protoVersions.Sort() + downloadURL, err := url.Parse(body.DownloadURL) + if err != nil { + return PackageMeta{}, fmt.Errorf("registry response includes invalid download URL: %s", err) + } + downloadURL = resp.Request.URL.ResolveReference(downloadURL) + if downloadURL.Scheme != "http" && downloadURL.Scheme != "https" { + return PackageMeta{}, fmt.Errorf("registry response includes invalid download URL: must use http or https scheme") + } + ret := PackageMeta{ ProtocolVersions: protoVersions, TargetPlatform: Platform{ @@ -208,7 +217,7 @@ func (c *registryClient) PackageMeta(provider addrs.Provider, version Version, t Arch: body.Arch, }, Filename: body.Filename, - Location: PackageHTTPURL(body.DownloadURL), + Location: PackageHTTPURL(downloadURL.String()), // SHA256Sum is populated below } diff --git a/internal/getproviders/registry_client_test.go b/internal/getproviders/registry_client_test.go index 65ddfdb00..841b551d7 100644 --- a/internal/getproviders/registry_client_test.go +++ b/internal/getproviders/registry_client_test.go @@ -25,10 +25,10 @@ import ( // The second return value is a function to call at the end of a test function // to shut down the test server. After you call that function, the discovery // object becomes useless. -func testServices(t *testing.T) (*disco.Disco, func()) { +func testServices(t *testing.T) (services *disco.Disco, baseURL string, cleanup func()) { server := httptest.NewServer(http.HandlerFunc(fakeRegistryHandler)) - services := disco.New() + services = disco.New() services.ForceHostServices(svchost.Hostname("example.com"), map[string]interface{}{ "providers.v1": server.URL + "/providers/v1/", }) @@ -42,7 +42,7 @@ func testServices(t *testing.T) (*disco.Disco, func()) { "providers.v1": server.URL + "/fails-immediately/", }) - return services, func() { + return services, server.URL, func() { server.Close() } } @@ -53,10 +53,10 @@ func testServices(t *testing.T) (*disco.Disco, func()) { // // As with testServices, the second return value is a function to call at the end // of your test in order to shut down the test server. -func testRegistrySource(t *testing.T) (*RegistrySource, func()) { - services, close := testServices(t) - source := NewRegistrySource(services) - return source, close +func testRegistrySource(t *testing.T) (source *RegistrySource, baseURL string, cleanup func()) { + services, baseURL, close := testServices(t) + source = NewRegistrySource(services) + return source, baseURL, close } func fakeRegistryHandler(resp http.ResponseWriter, req *http.Request) { diff --git a/internal/getproviders/registry_source_test.go b/internal/getproviders/registry_source_test.go index 943112086..0d5efab48 100644 --- a/internal/getproviders/registry_source_test.go +++ b/internal/getproviders/registry_source_test.go @@ -14,7 +14,7 @@ import ( ) func TestSourceAvailableVersions(t *testing.T) { - source, close := testRegistrySource(t) + source, baseURL, close := testRegistrySource(t) defer close() tests := []struct { @@ -52,15 +52,10 @@ func TestSourceAvailableVersions(t *testing.T) { { "fails.example.com/foo/bar", nil, - `could not query provider registry for fails.example.com/foo/bar: Get http://placeholder-origin/fails-immediately/foo/bar/versions: EOF`, + `could not query provider registry for fails.example.com/foo/bar: Get ` + baseURL + `/fails-immediately/foo/bar/versions: EOF`, }, } - // Sometimes error messages contain specific HTTP endpoint URLs, but - // since our test server is on a random port we'd not be able to - // consistently match those. Instead, we'll normalize the URLs. - urlPattern := regexp.MustCompile(`http://[^/]+/`) - for _, test := range tests { t.Run(test.provider, func(t *testing.T) { // TEMP: We don't yet have a function for parsing provider @@ -78,8 +73,7 @@ func TestSourceAvailableVersions(t *testing.T) { if test.wantErr == "" { t.Fatalf("wrong error\ngot: %s\nwant: ", err.Error()) } - gotErr := urlPattern.ReplaceAllLiteralString(err.Error(), "http://placeholder-origin/") - if got, want := gotErr, test.wantErr; got != want { + if got, want := err.Error(), test.wantErr; got != want { t.Fatalf("wrong error\ngot: %s\nwant: %s", got, want) } return @@ -106,7 +100,7 @@ func TestSourceAvailableVersions(t *testing.T) { } func TestSourcePackageMeta(t *testing.T) { - source, close := testRegistrySource(t) + source, baseURL, close := testRegistrySource(t) defer close() tests := []struct { @@ -126,7 +120,7 @@ func TestSourcePackageMeta(t *testing.T) { ProtocolVersions: VersionList{versions.MustParseVersion("5.0.0")}, TargetPlatform: Platform{"linux", "amd64"}, Filename: "happycloud_1.2.0.zip", - Location: PackageHTTPURL("/pkg/happycloud_1.2.0.zip"), + Location: PackageHTTPURL(baseURL + "/pkg/happycloud_1.2.0.zip"), SHA256Sum: [32]uint8{30: 0xf0, 31: 0x0d}, // fake registry uses a memorable sum }, ``,