diff --git a/config/module/get.go b/config/module/get.go index 96b4a63c3..e6eb1afbd 100644 --- a/config/module/get.go +++ b/config/module/get.go @@ -1,10 +1,17 @@ package module import ( + "fmt" "io/ioutil" + "log" + "net/http" "os" + "regexp" + "strings" "github.com/hashicorp/go-getter" + + cleanhttp "github.com/hashicorp/go-cleanhttp" ) // GetMode is an enum that describes how modules are loaded. @@ -69,3 +76,96 @@ func getStorage(s getter.Storage, key string, src string, mode GetMode) (string, // Get the directory where the module is. return s.Dir(key) } + +const ( + registryAPI = "https://registry.terraform.io/v1/modules/" + xTerraformGet = "X-Terraform-Get" +) + +var detectors = []getter.Detector{ + new(getter.GitHubDetector), + new(getter.BitBucketDetector), + new(getter.S3Detector), + new(registryDetector), + new(getter.FileDetector), +} + +// these prefixes can't be registry IDs +// "http", "./", "/", "getter::" +var skipRegistry = regexp.MustCompile(`^(http|\./|/|[A-Za-z0-9]+::)`).MatchString + +// registryDetector implements getter.Detector to detect Terraform Registry modules. +// If a path looks like a registry module identifier, attempt to locate it in +// the registry. If it's not found, pass it on in case it can be found by +// other means. +type registryDetector struct { + // override the default registry URL + api string + + client *http.Client +} + +func (d registryDetector) Detect(src, _ string) (string, bool, error) { + // the namespace can't start with "http", a relative or absolute path, or + // contain a go-getter "forced getter" + if skipRegistry(src) { + return "", false, nil + } + + // there are 3 parts to a registry ID + if len(strings.Split(src, "/")) != 3 { + return "", false, nil + } + + return d.lookupModule(src) +} + +// Lookup the module in the registry. +// Since existing module sources may match a registry ID format, we only log +// registry errors and continue discovery. +func (d registryDetector) lookupModule(src string) (string, bool, error) { + if d.api == "" { + d.api = registryAPI + } + + if d.client == nil { + d.client = cleanhttp.DefaultClient() + } + + // src is already partially validated in Detect. We know it's a path, and + // if it can be parsed as a URL we will hand it off to the registry to + // determine if it's truly valid. + resp, err := d.client.Get(fmt.Sprintf("%s/%s/download", d.api, src)) + if err != nil { + log.Println("[WARN] error looking up module %q: %s", src, err) + return "", false, nil + } + defer resp.Body.Close() + + // there should be no body, but save it for logging + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + fmt.Println("[WARN] error reading response body from registry: %s", err) + return "", false, nil + } + + switch resp.StatusCode { + case http.StatusOK, http.StatusNoContent: + // OK + case http.StatusNotFound: + log.Printf("[INFO] module %q not found in registry", src) + return "", false, nil + default: + // anything else is an error: + log.Printf("[WARN] error getting download location for %q: %s resp:%s", src, resp.Status, body) + return "", false, nil + } + + // the download location is in the X-Terraform-Get header + location := resp.Header.Get(xTerraformGet) + if location == "" { + return "", false, fmt.Errorf("failed to get download URL for %q: %s resp:%s", src, resp.Status, body) + } + + return location, true, nil +} diff --git a/config/module/get_test.go b/config/module/get_test.go new file mode 100644 index 000000000..039c8b6b8 --- /dev/null +++ b/config/module/get_test.go @@ -0,0 +1,143 @@ +package module + +import ( + "fmt" + "net/http" + "net/http/httptest" + "regexp" + "sort" + "strings" + "testing" + + version "github.com/hashicorp/go-version" +) + +// map of module names and version for test module. +// only one version for now, as we only lookup latest from the registry +var testMods = map[string]string{ + "registry/foo/bar": "0.2.3", + "registry/foo/baz": "1.10.0", +} + +func latestVersion(versions []string) string { + var col version.Collection + for _, v := range versions { + ver, err := version.NewVersion(v) + if err != nil { + panic(err) + } + col = append(col, ver) + } + + sort.Sort(col) + return col[len(col)-1].String() +} + +// Just enough like a registry to exercise our code. +// Returns the location of the latest version +func mockRegistry() *httptest.Server { + mux := http.NewServeMux() + server := httptest.NewServer(mux) + + mux.Handle("/v1/modules/", + http.StripPrefix("/v1/modules/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p := strings.TrimLeft(r.URL.Path, "/") + // handle download request + download := regexp.MustCompile(`^(\w+/\w+/\w+)/download$`) + + // download lookup + matches := download.FindStringSubmatch(p) + if len(matches) != 2 { + w.WriteHeader(http.StatusBadRequest) + return + } + + version, ok := testMods[matches[1]] + if !ok { + w.WriteHeader(http.StatusNotFound) + return + } + + location := fmt.Sprintf("%s/download/%s/%s", server.URL, matches[1], version) + w.Header().Set(xTerraformGet, location) + w.WriteHeader(http.StatusNoContent) + // no body + return + })), + ) + + return server +} + +func TestDetectRegistry(t *testing.T) { + server := mockRegistry() + defer server.Close() + + detector := registryDetector{ + api: server.URL + "/v1/modules/", + client: server.Client(), + } + + for _, tc := range []struct { + module string + location string + found bool + err bool + }{ + { + module: "registry/foo/bar", + location: "download/registry/foo/bar/0.2.3", + found: true, + }, + { + module: "registry/foo/baz", + location: "download/registry/foo/baz/1.10.0", + found: true, + }, + // this should not be found, but not stop detection + { + module: "registry/foo/notfound", + found: false, + }, + + // a full url should not be detected + { + module: "http://example.com/registry/foo/notfound", + found: false, + }, + + // paths should not be detected + { + module: "./local/foo/notfound", + found: false, + }, + { + module: "/local/foo/notfound", + found: false, + }, + + // wrong number of parts can't be regisry IDs + { + module: "something/registry/foo/notfound", + found: false, + }, + } { + + t.Run(tc.module, func(t *testing.T) { + loc, ok, err := detector.Detect(tc.module, "") + if (err == nil) == tc.err { + t.Fatalf("expected error? %t; got error :%v", tc.err, err) + } + + if ok != tc.found { + t.Fatalf("expected OK == %t", tc.found) + } + + loc = strings.TrimPrefix(loc, server.URL+"/") + if strings.TrimPrefix(loc, server.URL) != tc.location { + t.Fatalf("expected location: %q, got %q", tc.location, loc) + } + }) + + } +} diff --git a/config/module/tree.go b/config/module/tree.go index 4b0b153f7..4ac0b11fb 100644 --- a/config/module/tree.go +++ b/config/module/tree.go @@ -180,7 +180,7 @@ func (t *Tree) Load(s getter.Storage, mode GetMode) error { // Split out the subdir if we have one source, subDir := getter.SourceDirSubdir(m.Source) - source, err := getter.Detect(source, t.config.Dir, getter.Detectors) + source, err := getter.Detect(source, t.config.Dir, detectors) if err != nil { return fmt.Errorf("module %s: %s", m.Name, err) }