446 lines
14 KiB
Go
446 lines
14 KiB
Go
package getproviders
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/apparentlymart/go-versions/versions"
|
|
"github.com/google/go-cmp/cmp"
|
|
svchost "github.com/hashicorp/terraform-svchost"
|
|
disco "github.com/hashicorp/terraform-svchost/disco"
|
|
"github.com/hashicorp/terraform/addrs"
|
|
)
|
|
|
|
func TestConfigureDiscoveryRetry(t *testing.T) {
|
|
t.Run("default retry", func(t *testing.T) {
|
|
if discoveryRetry != defaultRetry {
|
|
t.Fatalf("expected retry %q, got %q", defaultRetry, discoveryRetry)
|
|
}
|
|
|
|
rc := newRegistryClient(nil, nil)
|
|
if rc.httpClient.RetryMax != defaultRetry {
|
|
t.Fatalf("expected client retry %q, got %q",
|
|
defaultRetry, rc.httpClient.RetryMax)
|
|
}
|
|
})
|
|
|
|
t.Run("configured retry", func(t *testing.T) {
|
|
defer func(retryEnv string) {
|
|
os.Setenv(registryDiscoveryRetryEnvName, retryEnv)
|
|
discoveryRetry = defaultRetry
|
|
}(os.Getenv(registryDiscoveryRetryEnvName))
|
|
os.Setenv(registryDiscoveryRetryEnvName, "2")
|
|
|
|
configureDiscoveryRetry()
|
|
expected := 2
|
|
if discoveryRetry != expected {
|
|
t.Fatalf("expected retry %q, got %q",
|
|
expected, discoveryRetry)
|
|
}
|
|
|
|
rc := newRegistryClient(nil, nil)
|
|
if rc.httpClient.RetryMax != expected {
|
|
t.Fatalf("expected client retry %q, got %q",
|
|
expected, rc.httpClient.RetryMax)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestConfigureRegistryClientTimeout(t *testing.T) {
|
|
t.Run("default timeout", func(t *testing.T) {
|
|
if requestTimeout != defaultRequestTimeout {
|
|
t.Fatalf("expected timeout %q, got %q",
|
|
defaultRequestTimeout.String(), requestTimeout.String())
|
|
}
|
|
|
|
rc := newRegistryClient(nil, nil)
|
|
if rc.httpClient.HTTPClient.Timeout != defaultRequestTimeout {
|
|
t.Fatalf("expected client timeout %q, got %q",
|
|
defaultRequestTimeout.String(), rc.httpClient.HTTPClient.Timeout.String())
|
|
}
|
|
})
|
|
|
|
t.Run("configured timeout", func(t *testing.T) {
|
|
defer func(timeoutEnv string) {
|
|
os.Setenv(registryClientTimeoutEnvName, timeoutEnv)
|
|
requestTimeout = defaultRequestTimeout
|
|
}(os.Getenv(registryClientTimeoutEnvName))
|
|
os.Setenv(registryClientTimeoutEnvName, "20")
|
|
|
|
configureRequestTimeout()
|
|
expected := 20 * time.Second
|
|
if requestTimeout != expected {
|
|
t.Fatalf("expected timeout %q, got %q",
|
|
expected, requestTimeout.String())
|
|
}
|
|
|
|
rc := newRegistryClient(nil, nil)
|
|
if rc.httpClient.HTTPClient.Timeout != expected {
|
|
t.Fatalf("expected client timeout %q, got %q",
|
|
expected, rc.httpClient.HTTPClient.Timeout.String())
|
|
}
|
|
})
|
|
}
|
|
|
|
// testRegistryServices starts up a local HTTP server running a fake provider registry
|
|
// service and returns a service discovery object pre-configured to consider
|
|
// the host "example.com" to be served by the fake registry service.
|
|
//
|
|
// The returned discovery object also knows the hostname "not.example.com"
|
|
// which does not have a provider registry at all and "too-new.example.com"
|
|
// which has a "providers.v99" service that is inoperable but could be useful
|
|
// to test the error reporting for detecting an unsupported protocol version.
|
|
// It also knows fails.example.com but it refers to an endpoint that doesn't
|
|
// correctly speak HTTP, to simulate a protocol error.
|
|
//
|
|
// The second return value is a function to call at the end of a test function
|
|
// to shut down the test server. After you call that function, the discovery
|
|
// object becomes useless.
|
|
func testRegistryServices(t *testing.T) (services *disco.Disco, baseURL string, cleanup func()) {
|
|
server := httptest.NewServer(http.HandlerFunc(fakeRegistryHandler))
|
|
|
|
services = disco.New()
|
|
services.ForceHostServices(svchost.Hostname("example.com"), map[string]interface{}{
|
|
"providers.v1": server.URL + "/providers/v1/",
|
|
})
|
|
services.ForceHostServices(svchost.Hostname("not.example.com"), map[string]interface{}{})
|
|
services.ForceHostServices(svchost.Hostname("too-new.example.com"), map[string]interface{}{
|
|
// This service doesn't actually work; it's here only to be
|
|
// detected as "too new" by the discovery logic.
|
|
"providers.v99": server.URL + "/providers/v99/",
|
|
})
|
|
services.ForceHostServices(svchost.Hostname("fails.example.com"), map[string]interface{}{
|
|
"providers.v1": server.URL + "/fails-immediately/",
|
|
})
|
|
|
|
// We'll also permit registry.terraform.io here just because it's our
|
|
// default and has some unique features that are not allowed on any other
|
|
// hostname. It behaves the same as example.com, which should be preferred
|
|
// if you're not testing something specific to the default registry in order
|
|
// to ensure that most things are hostname-agnostic.
|
|
services.ForceHostServices(svchost.Hostname("registry.terraform.io"), map[string]interface{}{
|
|
"providers.v1": server.URL + "/providers/v1/",
|
|
})
|
|
|
|
return services, server.URL, func() {
|
|
server.Close()
|
|
}
|
|
}
|
|
|
|
// testRegistrySource is a wrapper around testServices that uses the created
|
|
// discovery object to produce a Source instance that is ready to use with the
|
|
// fake registry services.
|
|
//
|
|
// As with testServices, the second return value is a function to call at the end
|
|
// of your test in order to shut down the test server.
|
|
func testRegistrySource(t *testing.T) (source *RegistrySource, baseURL string, cleanup func()) {
|
|
services, baseURL, close := testRegistryServices(t)
|
|
source = NewRegistrySource(services)
|
|
return source, baseURL, close
|
|
}
|
|
|
|
func fakeRegistryHandler(resp http.ResponseWriter, req *http.Request) {
|
|
path := req.URL.EscapedPath()
|
|
if strings.HasPrefix(path, "/fails-immediately/") {
|
|
// Here we take over the socket and just close it immediately, to
|
|
// simulate one possible way a server might not be an HTTP server.
|
|
hijacker, ok := resp.(http.Hijacker)
|
|
if !ok {
|
|
// Not hijackable, so we'll just fail normally.
|
|
// If this happens, tests relying on this will fail.
|
|
resp.WriteHeader(500)
|
|
resp.Write([]byte(`cannot hijack`))
|
|
return
|
|
}
|
|
conn, _, err := hijacker.Hijack()
|
|
if err != nil {
|
|
resp.WriteHeader(500)
|
|
resp.Write([]byte(`hijack failed`))
|
|
return
|
|
}
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
if strings.HasPrefix(path, "/pkg/") {
|
|
switch path {
|
|
case "/pkg/awesomesauce/happycloud_1.2.0.zip":
|
|
resp.Write([]byte("some zip file"))
|
|
case "/pkg/awesomesauce/happycloud_1.2.0_SHA256SUMS":
|
|
resp.Write([]byte("000000000000000000000000000000000000000000000000000000000000f00d happycloud_1.2.0.zip\n000000000000000000000000000000000000000000000000000000000000face happycloud_1.2.0_face.zip\n"))
|
|
case "/pkg/awesomesauce/happycloud_1.2.0_SHA256SUMS.sig":
|
|
resp.Write([]byte("GPG signature"))
|
|
default:
|
|
resp.WriteHeader(404)
|
|
resp.Write([]byte("unknown package file download"))
|
|
}
|
|
return
|
|
}
|
|
|
|
if !strings.HasPrefix(path, "/providers/v1/") {
|
|
resp.WriteHeader(404)
|
|
resp.Write([]byte(`not a provider registry endpoint`))
|
|
return
|
|
}
|
|
|
|
pathParts := strings.Split(path, "/")[3:]
|
|
if len(pathParts) < 3 {
|
|
resp.WriteHeader(404)
|
|
resp.Write([]byte(`unexpected number of path parts`))
|
|
return
|
|
}
|
|
log.Printf("[TRACE] fake provider registry request for %#v", pathParts)
|
|
|
|
if pathParts[2] == "versions" {
|
|
if len(pathParts) != 3 {
|
|
resp.WriteHeader(404)
|
|
resp.Write([]byte(`extraneous path parts`))
|
|
return
|
|
}
|
|
|
|
switch pathParts[0] + "/" + pathParts[1] {
|
|
case "awesomesauce/happycloud":
|
|
resp.Header().Set("Content-Type", "application/json")
|
|
resp.WriteHeader(200)
|
|
// Note that these version numbers are intentionally misordered
|
|
// so we can test that the client-side code places them in the
|
|
// correct order (lowest precedence first).
|
|
resp.Write([]byte(`{"versions":[{"version":"0.1.0","protocols":["1.0"]},{"version":"2.0.0","protocols":["99.0"]},{"version":"1.2.0","protocols":["5.0"]}, {"version":"1.0.0","protocols":["5.0"]}]}`))
|
|
case "weaksauce/unsupported-protocol":
|
|
resp.Header().Set("Content-Type", "application/json")
|
|
resp.WriteHeader(200)
|
|
resp.Write([]byte(`{"versions":[{"version":"1.0.0","protocols":["0.1"]}]}`))
|
|
case "weaksauce/no-versions":
|
|
resp.Header().Set("Content-Type", "application/json")
|
|
resp.WriteHeader(200)
|
|
resp.Write([]byte(`{"versions":[],"warnings":["this provider is weaksauce"]}`))
|
|
case "-/legacy":
|
|
resp.Header().Set("Content-Type", "application/json")
|
|
resp.WriteHeader(200)
|
|
// This response is used for testing LookupLegacyProvider
|
|
resp.Write([]byte(`{"id":"legacycorp/legacy"}`))
|
|
case "-/moved":
|
|
resp.Header().Set("Content-Type", "application/json")
|
|
resp.WriteHeader(200)
|
|
// This response is used for testing LookupLegacyProvider
|
|
resp.Write([]byte(`{"id":"hashicorp/moved","moved_to":"acme/moved"}`))
|
|
case "-/changetype":
|
|
resp.Header().Set("Content-Type", "application/json")
|
|
resp.WriteHeader(200)
|
|
// This (unrealistic) response is used for error handling code coverage
|
|
resp.Write([]byte(`{"id":"legacycorp/newtype"}`))
|
|
case "-/invalid":
|
|
resp.Header().Set("Content-Type", "application/json")
|
|
resp.WriteHeader(200)
|
|
// This (unrealistic) response is used for error handling code coverage
|
|
resp.Write([]byte(`{"id":"some/invalid/id/string"}`))
|
|
default:
|
|
resp.WriteHeader(404)
|
|
resp.Write([]byte(`unknown namespace or provider type`))
|
|
}
|
|
return
|
|
}
|
|
|
|
if len(pathParts) == 6 && pathParts[3] == "download" {
|
|
switch pathParts[0] + "/" + pathParts[1] {
|
|
case "awesomesauce/happycloud":
|
|
if pathParts[4] == "nonexist" {
|
|
resp.WriteHeader(404)
|
|
resp.Write([]byte(`unsupported OS`))
|
|
return
|
|
}
|
|
var protocols []string
|
|
version := pathParts[2]
|
|
switch version {
|
|
case "0.1.0":
|
|
protocols = []string{"1.0"}
|
|
case "2.0.0":
|
|
protocols = []string{"99.0"}
|
|
default:
|
|
protocols = []string{"5.0"}
|
|
}
|
|
|
|
body := map[string]interface{}{
|
|
"protocols": protocols,
|
|
"os": pathParts[4],
|
|
"arch": pathParts[5],
|
|
"filename": "happycloud_" + version + ".zip",
|
|
"shasum": "000000000000000000000000000000000000000000000000000000000000f00d",
|
|
"download_url": "/pkg/awesomesauce/happycloud_" + version + ".zip",
|
|
"shasums_url": "/pkg/awesomesauce/happycloud_" + version + "_SHA256SUMS",
|
|
"shasums_signature_url": "/pkg/awesomesauce/happycloud_" + version + "_SHA256SUMS.sig",
|
|
"signing_keys": map[string]interface{}{
|
|
"gpg_public_keys": []map[string]interface{}{
|
|
{
|
|
"ascii_armor": HashicorpPublicKey,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
enc, err := json.Marshal(body)
|
|
if err != nil {
|
|
resp.WriteHeader(500)
|
|
resp.Write([]byte("failed to encode body"))
|
|
}
|
|
resp.Header().Set("Content-Type", "application/json")
|
|
resp.WriteHeader(200)
|
|
resp.Write(enc)
|
|
default:
|
|
resp.WriteHeader(404)
|
|
resp.Write([]byte(`unknown namespace/provider/version/architecture`))
|
|
}
|
|
return
|
|
}
|
|
|
|
resp.WriteHeader(404)
|
|
resp.Write([]byte(`unrecognized path scheme`))
|
|
}
|
|
|
|
func TestProviderVersions(t *testing.T) {
|
|
source, _, close := testRegistrySource(t)
|
|
defer close()
|
|
|
|
tests := []struct {
|
|
provider addrs.Provider
|
|
wantVersions map[string][]string
|
|
wantErr string
|
|
}{
|
|
{
|
|
addrs.MustParseProviderSourceString("example.com/awesomesauce/happycloud"),
|
|
map[string][]string{
|
|
"0.1.0": {"1.0"},
|
|
"1.0.0": {"5.0"},
|
|
"1.2.0": {"5.0"},
|
|
"2.0.0": {"99.0"},
|
|
},
|
|
``,
|
|
},
|
|
{
|
|
addrs.MustParseProviderSourceString("example.com/weaksauce/no-versions"),
|
|
nil,
|
|
``,
|
|
},
|
|
{
|
|
addrs.MustParseProviderSourceString("example.com/nonexist/nonexist"),
|
|
nil,
|
|
`provider registry example.com does not have a provider named example.com/nonexist/nonexist`,
|
|
},
|
|
}
|
|
for _, test := range tests {
|
|
t.Run(test.provider.String(), func(t *testing.T) {
|
|
client, err := source.registryClient(test.provider.Hostname)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
gotVersions, _, err := client.ProviderVersions(test.provider)
|
|
|
|
if err != nil {
|
|
if test.wantErr == "" {
|
|
t.Fatalf("wrong error\ngot: %s\nwant: <nil>", err.Error())
|
|
}
|
|
if got, want := err.Error(), test.wantErr; got != want {
|
|
t.Fatalf("wrong error\ngot: %s\nwant: %s", got, want)
|
|
}
|
|
return
|
|
}
|
|
|
|
if test.wantErr != "" {
|
|
t.Fatalf("wrong error\ngot: <nil>\nwant: %s", test.wantErr)
|
|
}
|
|
|
|
if diff := cmp.Diff(test.wantVersions, gotVersions); diff != "" {
|
|
t.Errorf("wrong result\n%s", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFindClosestProtocolCompatibleVersion(t *testing.T) {
|
|
source, _, close := testRegistrySource(t)
|
|
defer close()
|
|
|
|
tests := map[string]struct {
|
|
provider addrs.Provider
|
|
version Version
|
|
wantSuggestion Version
|
|
wantErr string
|
|
}{
|
|
"pinned version too old": {
|
|
addrs.MustParseProviderSourceString("example.com/awesomesauce/happycloud"),
|
|
MustParseVersion("0.1.0"),
|
|
MustParseVersion("1.2.0"),
|
|
``,
|
|
},
|
|
"pinned version too new": {
|
|
addrs.MustParseProviderSourceString("example.com/awesomesauce/happycloud"),
|
|
MustParseVersion("2.0.0"),
|
|
MustParseVersion("1.2.0"),
|
|
``,
|
|
},
|
|
// This should not actually happen, the function is only meant to be
|
|
// called when the requested provider version is not supported
|
|
"pinned version just right": {
|
|
addrs.MustParseProviderSourceString("example.com/awesomesauce/happycloud"),
|
|
MustParseVersion("1.2.0"),
|
|
MustParseVersion("1.2.0"),
|
|
``,
|
|
},
|
|
"nonexisting provider": {
|
|
addrs.MustParseProviderSourceString("example.com/nonexist/nonexist"),
|
|
MustParseVersion("1.2.0"),
|
|
versions.Unspecified,
|
|
`provider registry example.com does not have a provider named example.com/nonexist/nonexist`,
|
|
},
|
|
"versionless provider": {
|
|
addrs.MustParseProviderSourceString("example.com/weaksauce/no-versions"),
|
|
MustParseVersion("1.2.0"),
|
|
versions.Unspecified,
|
|
``,
|
|
},
|
|
"unsupported provider protocol": {
|
|
addrs.MustParseProviderSourceString("example.com/weaksauce/unsupported-protocol"),
|
|
MustParseVersion("1.0.0"),
|
|
versions.Unspecified,
|
|
``,
|
|
},
|
|
}
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
client, err := source.registryClient(test.provider.Hostname)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
got, err := client.findClosestProtocolCompatibleVersion(test.provider, test.version)
|
|
|
|
if err != nil {
|
|
if test.wantErr == "" {
|
|
t.Fatalf("wrong error\ngot: %s\nwant: <nil>", err.Error())
|
|
}
|
|
if got, want := err.Error(), test.wantErr; got != want {
|
|
t.Fatalf("wrong error\ngot: %s\nwant: %s", got, want)
|
|
}
|
|
return
|
|
}
|
|
|
|
if test.wantErr != "" {
|
|
t.Fatalf("wrong error\ngot: <nil>\nwant: %s", test.wantErr)
|
|
}
|
|
|
|
fmt.Printf("Got: %s, Want: %s\n", got, test.wantSuggestion)
|
|
|
|
if !got.Same(test.wantSuggestion) {
|
|
t.Fatalf("wrong result\ngot: %s\nwant: %s", got.String(), test.wantSuggestion.String())
|
|
}
|
|
})
|
|
}
|
|
}
|