diff --git a/internal/getproviders/registry_client.go b/internal/getproviders/registry_client.go index 907ddbd9c..e302db812 100644 --- a/internal/getproviders/registry_client.go +++ b/internal/getproviders/registry_client.go @@ -7,20 +7,52 @@ import ( "errors" "fmt" "io/ioutil" + "log" "net/http" "net/url" + "os" "path" + "strconv" "time" + "github.com/hashicorp/go-retryablehttp" svchost "github.com/hashicorp/terraform-svchost" svcauth "github.com/hashicorp/terraform-svchost/auth" "github.com/hashicorp/terraform/addrs" + "github.com/hashicorp/terraform/helper/logging" "github.com/hashicorp/terraform/httpclient" "github.com/hashicorp/terraform/version" ) -const terraformVersionHeader = "X-Terraform-Version" +const ( + terraformVersionHeader = "X-Terraform-Version" + + // registryDiscoveryRetryEnvName is the name of the environment variable that + // can be configured to customize number of retries for module and provider + // discovery requests with the remote registry. + registryDiscoveryRetryEnvName = "TF_REGISTRY_DISCOVERY_RETRY" + defaultRetry = 1 + + // registryClientTimeoutEnvName is the name of the environment variable that + // can be configured to customize the timeout duration (seconds) for module + // and provider discovery with the remote registry. + registryClientTimeoutEnvName = "TF_REGISTRY_CLIENT_TIMEOUT" + + // defaultRequestTimeout is the default timeout duration for requests to the + // remote registry. 
+ defaultRequestTimeout = 10 * time.Second +) + +var ( + discoveryRetry int + requestTimeout time.Duration +) + +func init() { + configureDiscoveryRetry() + configureRequestTimeout() +} var SupportedPluginProtocols = MustParseVersionConstraints("~> 5") @@ -31,17 +63,30 @@ type registryClient struct { baseURL *url.URL creds svcauth.HostCredentials - httpClient *http.Client + httpClient *retryablehttp.Client } func newRegistryClient(baseURL *url.URL, creds svcauth.HostCredentials) *registryClient { httpClient := httpclient.New() - httpClient.Timeout = 10 * time.Second + httpClient.Timeout = requestTimeout + + retryableClient := retryablehttp.NewClient() + retryableClient.HTTPClient = httpClient + retryableClient.RetryMax = discoveryRetry + retryableClient.RequestLogHook = requestLogHook + retryableClient.ErrorHandler = maxRetryErrorHandler + + logOutput, err := logging.LogOutput() + if err != nil { + log.Printf("[WARN] Failed to set up registry client logger, "+ + "continuing without client logging: %s", err) + } + retryableClient.Logger = log.New(logOutput, "", log.Flags()) return &registryClient{ baseURL: baseURL, creds: creds, - httpClient: httpClient, + httpClient: retryableClient, } } @@ -61,11 +106,11 @@ func (c *registryClient) ProviderVersions(addr addrs.Provider) (map[string][]str } endpointURL := c.baseURL.ResolveReference(endpointPath) - req, err := http.NewRequest("GET", endpointURL.String(), nil) + req, err := retryablehttp.NewRequest("GET", endpointURL.String(), nil) if err != nil { return nil, err } - c.addHeadersToRequest(req) + c.addHeadersToRequest(req.Request) resp, err := c.httpClient.Do(req) if err != nil { @@ -140,11 +185,11 @@ func (c *registryClient) PackageMeta(provider addrs.Provider, version Version, t } endpointURL := c.baseURL.ResolveReference(endpointPath) - req, err := http.NewRequest("GET", endpointURL.String(), nil) + req, err := retryablehttp.NewRequest("GET", endpointURL.String(), nil) if err != nil { return PackageMeta{}, err } - 
c.addHeadersToRequest(req) + c.addHeadersToRequest(req.Request) resp, err := c.httpClient.Do(req) if err != nil { @@ -372,11 +417,11 @@ func (c *registryClient) LegacyProviderDefaultNamespace(typeName string) (string } endpointURL := c.baseURL.ResolveReference(endpointPath) - req, err := http.NewRequest("GET", endpointURL.String(), nil) + req, err := retryablehttp.NewRequest("GET", endpointURL.String(), nil) if err != nil { return "", err } - c.addHeadersToRequest(req) + c.addHeadersToRequest(req.Request) // This is just to give us something to return in error messages. It's // not a proper provider address. @@ -462,3 +507,60 @@ func (c *registryClient) getFile(url *url.URL) ([]byte, error) { return data, nil } + +// configureDiscoveryRetry configures the number of retries the registry client +// will attempt for requests with retryable errors, like 502 status codes +func configureDiscoveryRetry() { + discoveryRetry = defaultRetry + + if v := os.Getenv(registryDiscoveryRetryEnvName); v != "" { + retry, err := strconv.Atoi(v) + if err == nil && retry > 0 { + discoveryRetry = retry + } + } +} + +func requestLogHook(logger retryablehttp.Logger, req *http.Request, i int) { + if i > 0 { + logger.Printf("[INFO] Previous request to the remote registry failed, attempting retry.") + } +} + +func maxRetryErrorHandler(resp *http.Response, err error, numTries int) (*http.Response, error) { + // Close the body per library instructions + if resp != nil { + resp.Body.Close() + } + + // Additional error detail: if we have a response, use the status code; + // if we have an error, use that; otherwise nothing. We will never have + // both response and error. + var errMsg string + if resp != nil { + errMsg = fmt.Sprintf(": %d", resp.StatusCode) + } else if err != nil { + errMsg = fmt.Sprintf(": %s", err) + } + + // This function is always called with numTries=RetryMax+1. If we made any + // retry attempts, include that in the error message. 
+ if numTries > 1 { + return resp, fmt.Errorf("the request failed after %d attempts, please try again later%s", + numTries, errMsg) + } + return resp, fmt.Errorf("the request failed, please try again later%s", errMsg) +} + +// configureRequestTimeout configures the registry client request timeout from +// environment variables +func configureRequestTimeout() { + requestTimeout = defaultRequestTimeout + + if v := os.Getenv(registryClientTimeoutEnvName); v != "" { + timeout, err := strconv.Atoi(v) + if err == nil && timeout > 0 { + requestTimeout = time.Duration(timeout) * time.Second + } + } +} diff --git a/internal/getproviders/registry_client_test.go b/internal/getproviders/registry_client_test.go index dbbc74ff8..8e1bf271d 100644 --- a/internal/getproviders/registry_client_test.go +++ b/internal/getproviders/registry_client_test.go @@ -6,8 +6,10 @@ import ( "log" "net/http" "net/http/httptest" + "os" "strings" "testing" + "time" "github.com/apparentlymart/go-versions/versions" "github.com/google/go-cmp/cmp" @@ -16,6 +18,77 @@ import ( "github.com/hashicorp/terraform/addrs" ) +func TestConfigureDiscoveryRetry(t *testing.T) { + t.Run("default retry", func(t *testing.T) { + if discoveryRetry != defaultRetry { + t.Fatalf("expected retry %q, got %q", defaultRetry, discoveryRetry) + } + + rc := newRegistryClient(nil, nil) + if rc.httpClient.RetryMax != defaultRetry { + t.Fatalf("expected client retry %q, got %q", + defaultRetry, rc.httpClient.RetryMax) + } + }) + + t.Run("configured retry", func(t *testing.T) { + defer func(retryEnv string) { + os.Setenv(registryDiscoveryRetryEnvName, retryEnv) + discoveryRetry = defaultRetry + }(os.Getenv(registryDiscoveryRetryEnvName)) + os.Setenv(registryDiscoveryRetryEnvName, "2") + + configureDiscoveryRetry() + expected := 2 + if discoveryRetry != expected { + t.Fatalf("expected retry %q, got %q", + expected, discoveryRetry) + } + + rc := newRegistryClient(nil, nil) + if rc.httpClient.RetryMax != expected { + t.Fatalf("expected 
client retry %q, got %q", + expected, rc.httpClient.RetryMax) + } + }) +} + +func TestConfigureRegistryClientTimeout(t *testing.T) { + t.Run("default timeout", func(t *testing.T) { + if requestTimeout != defaultRequestTimeout { + t.Fatalf("expected timeout %q, got %q", + defaultRequestTimeout.String(), requestTimeout.String()) + } + + rc := newRegistryClient(nil, nil) + if rc.httpClient.HTTPClient.Timeout != defaultRequestTimeout { + t.Fatalf("expected client timeout %q, got %q", + defaultRequestTimeout.String(), rc.httpClient.HTTPClient.Timeout.String()) + } + }) + + t.Run("configured timeout", func(t *testing.T) { + defer func(timeoutEnv string) { + os.Setenv(registryClientTimeoutEnvName, timeoutEnv) + requestTimeout = defaultRequestTimeout + }(os.Getenv(registryClientTimeoutEnvName)) + os.Setenv(registryClientTimeoutEnvName, "20") + + configureRequestTimeout() + expected := 20 * time.Second + if requestTimeout != expected { + t.Fatalf("expected timeout %q, got %q", + expected, requestTimeout.String()) + } + + rc := newRegistryClient(nil, nil) + if rc.httpClient.HTTPClient.Timeout != expected { + t.Fatalf("expected client timeout %q, got %q", + expected, rc.httpClient.HTTPClient.Timeout.String()) + } + }) +} + // testServices starts up a local HTTP server running a fake provider registry // service and returns a service discovery object pre-configured to consider // the host "example.com" to be served by the fake registry service. 
diff --git a/internal/getproviders/registry_source_test.go b/internal/getproviders/registry_source_test.go index 258985d10..19848f9bb 100644 --- a/internal/getproviders/registry_source_test.go +++ b/internal/getproviders/registry_source_test.go @@ -52,7 +52,7 @@ func TestSourceAvailableVersions(t *testing.T) { { "fails.example.com/foo/bar", nil, - `could not query provider registry for fails.example.com/foo/bar: Get "` + baseURL + `/fails-immediately/foo/bar/versions": EOF`, + `could not query provider registry for fails.example.com/foo/bar: the request failed after 2 attempts, please try again later: Get "` + baseURL + `/fails-immediately/foo/bar/versions": EOF`, }, } @@ -169,7 +169,7 @@ func TestSourcePackageMeta(t *testing.T) { "1.2.0", "linux", "amd64", PackageMeta{}, - `could not query provider registry for fails.example.com/awesomesauce/happycloud: Get "http://placeholder-origin/fails-immediately/awesomesauce/happycloud/1.2.0/download/linux/amd64": EOF`, + `could not query provider registry for fails.example.com/awesomesauce/happycloud: the request failed after 2 attempts, please try again later: Get "http://placeholder-origin/fails-immediately/awesomesauce/happycloud/1.2.0/download/linux/amd64": EOF`, }, }