check protocol version on plugins

Verify that the plugin we're requesting has a compatible protocol
version.
This commit is contained in:
James Bardin 2017-05-31 17:38:55 -04:00 committed by Martin Atkins
parent fdeb3d929c
commit 8ad67991a5
2 changed files with 85 additions and 16 deletions

View File

@ -7,15 +7,19 @@ import (
"log"
"net/http"
"runtime"
"strconv"
"strings"
"golang.org/x/net/html"
cleanhttp "github.com/hashicorp/go-cleanhttp"
getter "github.com/hashicorp/go-getter"
)
const releasesURL = "https://releases.hashicorp.com/"
var httpClient = cleanhttp.DefaultClient()
// pluginURL generates URLs to lookup the versions of a plugin, or the file path.
//
// The URL for releases follows the pattern:
@ -69,6 +73,16 @@ func GetProvider(dst, provider string, req Constraints) error {
return err
}
if len(versions) == 0 {
return fmt.Errorf("no plugins found for provider %q", provider)
}
versions = filterProtocolVersions(provider, versions)
if len(versions) == 0 {
return fmt.Errorf("no versions of %q compatible with the plugin ProtocolVersion", provider)
}
version, err := newestVersion(versions, req)
if err != nil {
return fmt.Errorf("no version of %q available that fulfills constraints %s", provider, req)
@ -80,6 +94,48 @@ func GetProvider(dst, provider string, req Constraints) error {
return getter.Get(dst, url)
}
// Remove available versions that don't have the correct plugin protocol version.
// TODO: stop checking older versions if the protocol version is too low
func filterProtocolVersions(provider string, versions []Version) []Version {
var compatible []Version
for _, v := range versions {
log.Printf("[DEBUG] fetching provider info for %s version %s", provider, v)
url := providersURL.fileURL(provider, v.String())
resp, err := httpClient.Head(url)
if err != nil {
log.Printf("[ERROR] error fetching plugin headers: %s", err)
continue
}
if resp.StatusCode != http.StatusOK {
log.Println("[ERROR] non-200 status fetching plugin headers:", resp.Status)
continue
}
proto := resp.Header.Get("X-TERRAFORM_PROTOCOL_VERSION")
if proto == "" {
log.Println("[WARNING] missing X-TERRAFORM_PROTOCOL_VERSION from:", url)
continue
}
protoVersion, err := strconv.Atoi(proto)
if err != nil {
log.Println("[ERROR] invalid ProtocolVersion: %s", proto)
continue
}
// FIXME: this shouldn't be hardcoded
if protoVersion != 4 {
log.Printf("[INFO] incompatible ProtocolVersion %d from %s version %s", protoVersion, provider, v)
continue
}
compatible = append(compatible, v)
}
return compatible
}
var errVersionNotFound = errors.New("version not found")
// take the list of available versions for a plugin, and the required
@ -128,11 +184,8 @@ func listProvisionerVersions(name string) ([]Version, error) {
}
// return a list of the plugin versions at the given URL
// TODO: This doesn't yet take into account plugin protocol version.
// That may have to be checked via an http header via a separate request
// to each plugin file.
func listPluginVersions(url string) ([]Version, error) {
resp, err := http.Get(url)
resp, err := httpClient.Get(url)
if err != nil {
return nil, err
}
@ -167,8 +220,12 @@ func listPluginVersions(url string) ([]Version, error) {
}
f(body)
var versions []Version
return versionsFromNames(names), nil
}
// parse the list of directory names into a sorted list of available versions
func versionsFromNames(names []string) []Version {
var versions []Version
for _, name := range names {
parts := strings.SplitN(name, "_", 2)
if len(parts) == 2 && parts[1] != "" {
@ -183,5 +240,6 @@ func listPluginVersions(url string) ([]Version, error) {
}
}
return versions, nil
Versions(versions).Sort()
return versions
}

View File

@ -6,13 +6,19 @@ import (
"testing"
)
func TestVersionListing(t *testing.T) {
// lists a constant set of providers, and always returns a protocol version
// equal to the Patch number.
func testReleaseServer() *httptest.Server {
handler := http.NewServeMux()
handler.HandleFunc("/terraform-providers/terraform-provider-test/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(versionList))
})
server := httptest.NewServer(handler)
return httptest.NewServer(handler)
}
func TestVersionListing(t *testing.T) {
server := testReleaseServer()
defer server.Close()
providersURL.releases = server.URL + "/"
@ -22,17 +28,22 @@ func TestVersionListing(t *testing.T) {
t.Fatal(err)
}
expectedSet := map[string]bool{
"1.2.4": true,
"1.2.3": true,
"1.2.1": true,
Versions(versions).Sort()
expected := []string{
"1.2.4",
"1.2.3",
"1.2.1",
}
for _, v := range versions {
if !expectedSet[v.String()] {
t.Fatalf("didn't get version %s in listing", v)
if len(versions) != len(expected) {
t.Fatalf("Received wrong number of versions. expected: %q, got: %q", expected, versions)
}
for i, v := range versions {
if v.String() != expected[i] {
t.Fatalf("incorrect version: %q, expected %q", v, expected[i])
}
delete(expectedSet, v.String())
}
}