diff --git a/plugin/discovery/get.go b/plugin/discovery/get.go index c84e7b1bd..2eff52d81 100644 --- a/plugin/discovery/get.go +++ b/plugin/discovery/get.go @@ -69,91 +69,71 @@ func GetProvider(dst, provider string, req Constraints, pluginProtocolVersion ui return fmt.Errorf("no plugins found for provider %q", provider) } - versions = filterProtocolVersions(provider, versions, pluginProtocolVersion) - + versions = allowedVersions(versions, req) if len(versions) == 0 { - return fmt.Errorf("no versions of %q compatible with the plugin ProtocolVersion", provider) - } - - version, err := newestVersion(versions, req) - if err != nil { return fmt.Errorf("no version of %q available that fulfills constraints %s", provider, req) } - url := providerURL(provider, version.String()) + // sort them newest to oldest + Versions(versions).Sort() - log.Printf("[DEBUG] getting provider %q version %q at %s", provider, version, url) - return getter.Get(dst, url) -} - -// Remove available versions that don't have the correct plugin protocol version. -// TODO: stop checking older versions if the protocol version is too low -func filterProtocolVersions(provider string, versions []Version, pluginProtocolVersion uint) []Version { - var compatible []Version + // take the first matching plugin we find for _, v := range versions { - log.Printf("[DEBUG] fetching provider info for %s version %s", provider, v) url := providerURL(provider, v.String()) - resp, err := httpClient.Head(url) - if err != nil { - log.Printf("[ERROR] error fetching plugin headers: %s", err) - continue + log.Printf("[DEBUG] fetching provider info for %s version %s", provider, v) + if checkPlugin(url, pluginProtocolVersion) { + log.Printf("[DEBUG] getting provider %q version %q at %s", provider, v, url) + return getter.Get(dst, url) } - if resp.StatusCode != http.StatusOK { - log.Println("[ERROR] non-200 status fetching plugin headers:", resp.Status) - continue - } - - proto := resp.Header.Get(protocolVersionHeader) - if proto == "" { - log.Printf("[WARNING] missing %s from: %s", protocolVersionHeader, url) - continue - } - - protoVersion, err := strconv.Atoi(proto) - if err != nil { - log.Printf("[ERROR] invalid ProtocolVersion: %s", proto) - continue - } - - if protoVersion != int(pluginProtocolVersion) { - log.Printf("[INFO] incompatible ProtocolVersion %d from %s version %s", protoVersion, provider, v) - continue - } - - compatible = append(compatible, v) + log.Printf("[INFO] incompatible ProtocolVersion for %s version %s", provider, v) } - return compatible + return fmt.Errorf("no versions of %q compatible with the plugin ProtocolVersion", provider) +} + +// Return the plugin version by making a HEAD request to the provided url +func checkPlugin(url string, pluginProtocolVersion uint) bool { + resp, err := httpClient.Head(url) + if err != nil { + log.Printf("[ERROR] error fetching plugin headers: %s", err) + return false + } + + if resp.StatusCode != http.StatusOK { + log.Println("[ERROR] non-200 status fetching plugin headers:", resp.Status) + return false + } + + proto := resp.Header.Get(protocolVersionHeader) + if proto == "" { + log.Printf("[WARNING] missing %s from: %s", protocolVersionHeader, url) + return false + } + + protoVersion, err := strconv.Atoi(proto) + if err != nil { + log.Printf("[ERROR] invalid ProtocolVersion: %s", proto) + return false + } + + return protoVersion == int(pluginProtocolVersion) } var errVersionNotFound = errors.New("version not found") -// take the list of available versions for a plugin, and the required -// Constraints, and return the latest available version that satisfies the -// constraints. -func newestVersion(available []Version, required Constraints) (Version, error) { - var latest Version - found := false +// take the list of available versions for a plugin, and filter out those that +// don't fit the constraints. +func allowedVersions(available []Version, required Constraints) []Version { + var allowed []Version for _, v := range available { if required.Allows(v) { - if !found { - latest = v - found = true - continue - } - - if v.NewerThan(latest) { - latest = v - } + allowed = append(allowed, v) } } - if !found { - return latest, errVersionNotFound - } - return latest, nil + return allowed } // list the version available for the named plugin @@ -222,6 +202,5 @@ func versionsFromNames(names []string) []Version { } } - Versions(versions).Sort() return versions } diff --git a/plugin/discovery/get_test.go b/plugin/discovery/get_test.go index 10c4d57d7..879be9ab2 100644 --- a/plugin/discovery/get_test.go +++ b/plugin/discovery/get_test.go @@ -97,57 +97,13 @@ func TestVersionListing(t *testing.T) { } } -func TestNewestVersion(t *testing.T) { - var available []Version - for _, v := range []string{"1.2.3", "1.2.1", "1.2.4"} { - version, err := VersionStr(v).Parse() - if err != nil { - t.Fatal(err) - } - available = append(available, version) +func TestCheckProtocolVersions(t *testing.T) { + if checkPlugin(providerURL("test", VersionStr("1.2.3").MustParse().String()), 4) { + t.Fatal("protocol version 4 is not compatible") } - reqd, err := ConstraintStr(">1.2.1").Parse() - if err != nil { - t.Fatal(err) - } - - found, err := newestVersion(available, reqd) - if err != nil { - t.Fatal(err) - } - - if found.String() != "1.2.4" { - t.Fatalf("expected newest version 1.2.4, got: %s", found) - } - - reqd, err = ConstraintStr("> 1.2.4").Parse() - if err != nil { - t.Fatal(err) - } - - found, err = newestVersion(available, reqd) - if err == nil { - t.Fatalf("expceted error, got version %s", found) - } -} - -func TestFilterProtocolVersions(t *testing.T) { - versions, err := listProviderVersions("test") - if err != nil { - t.Fatal(err) - } - - // use plugin protocl version 3, which should only return version 1.2.3 - compat := filterProtocolVersions("test", versions, 3) - - if len(compat) != 1 || compat[0].String() != "1.2.3" { - t.Fatal("found wrong versions: %q", compat) - } - - compat = filterProtocolVersions("test", versions, 6) - if len(compat) != 0 { - t.Fatal("should be no compatible versions, got: %q", compat) + if !checkPlugin(providerURL("test", VersionStr("1.2.3").MustParse().String()), 3) { + t.Fatal("protocol version 3 should be compatible") } } @@ -159,13 +115,19 @@ func TestGetProvider(t *testing.T) { defer os.RemoveAll(tmpDir) - fileName := fmt.Sprintf("terraform-provider-test_1.2.3_%s_%s_X3", runtime.GOOS, runtime.GOARCH) + // attempt to use an incompatible protocol version + err = GetProvider(tmpDir, "test", AllVersions, 5) + if err == nil { + t.Fatal("protocol version is incompatible") + } err = GetProvider(tmpDir, "test", AllVersions, 3) if err != nil { t.Fatal(err) } + // we should have version 1.2.3 + fileName := fmt.Sprintf("terraform-provider-test_1.2.3_%s_%s_X3", runtime.GOOS, runtime.GOARCH) dest := filepath.Join(tmpDir, fileName) f, err := ioutil.ReadFile(dest) if err != nil {