From 179b32d426025f9cb94f57488634ee4e3d9699a3 Mon Sep 17 00:00:00 2001 From: Sander van Harmelen Date: Thu, 5 Jul 2018 21:28:29 +0200 Subject: [PATCH] Add a `CredentialsForHost` method to disco.Disco By adding this method you now only have to pass a `*disco.Disco` object around in order to do discovery and use any configured credentials for the discovered hosts. Of course you can also still pass around both a `*disco.Disco` and a `auth.CredentialsSource` object if there is a need or a reason for that! --- command/command_test.go | 2 +- command/init.go | 2 +- command/meta.go | 7 +----- commands.go | 8 ++----- config/module/module_test.go | 2 +- config/module/storage.go | 13 ++++------- config/module/storage_test.go | 4 ++-- configs/configload/loader.go | 8 +------ configs/configload/module_mgr.go | 4 ---- main.go | 5 ++++- registry/client.go | 18 +++------------ registry/client_test.go | 37 ++++++++++++++++--------------- registry/test/mock_registry.go | 6 ++--- svchost/auth/credentials.go | 3 +++ svchost/auth/token_credentials.go | 5 +++++ svchost/disco/disco.go | 34 ++++++++++++++++++---------- svchost/disco/disco_test.go | 18 +++++++-------- 17 files changed, 81 insertions(+), 95 deletions(-) diff --git a/command/command_test.go b/command/command_test.go index 12d48761b..c0a8529c6 100644 --- a/command/command_test.go +++ b/command/command_test.go @@ -117,7 +117,7 @@ func testModule(t *testing.T, name string) *module.Tree { t.Fatalf("err: %s", err) } - s := module.NewStorage(tempDir(t), nil, nil) + s := module.NewStorage(tempDir(t), nil) s.Mode = module.GetModeGet if err := mod.Load(s); err != nil { t.Fatalf("err: %s", err) diff --git a/command/init.go b/command/init.go index efa4b5724..b96cdc6ec 100644 --- a/command/init.go +++ b/command/init.go @@ -129,7 +129,7 @@ func (c *InitCommand) Run(args []string) int { ))) header = true - s := module.NewStorage("", c.Services, c.Credentials) + s := module.NewStorage("", c.Services) if err := s.GetModule(path, src); err != nil { c.Ui.Error(fmt.Sprintf("Error copying source module: %s", err)) return 1 diff --git a/command/meta.go b/command/meta.go index 91f1008fe..f154f2d6c 100644 --- a/command/meta.go +++ b/command/meta.go @@ -25,7 +25,6 @@ import ( "github.com/hashicorp/terraform/helper/experiment" "github.com/hashicorp/terraform/helper/variables" "github.com/hashicorp/terraform/helper/wrappedstreams" - "github.com/hashicorp/terraform/svchost/auth" "github.com/hashicorp/terraform/svchost/disco" "github.com/hashicorp/terraform/terraform" "github.com/hashicorp/terraform/tfdiags" @@ -51,10 +50,6 @@ type Meta struct { // "terraform-native' services running at a specific user-facing hostname. Services *disco.Disco - // Credentials provides access to credentials for "terraform-native" - // services, which are accessed by a service hostname. - Credentials auth.CredentialsSource - // RunningInAutomation indicates that commands are being run by an // automated system rather than directly at a command prompt. // @@ -410,7 +405,7 @@ func (m *Meta) flagSet(n string) *flag.FlagSet { // moduleStorage returns the module.Storage implementation used to store // modules for commands. func (m *Meta) moduleStorage(root string, mode module.GetMode) *module.Storage { - s := module.NewStorage(filepath.Join(root, "modules"), m.Services, m.Credentials) + s := module.NewStorage(filepath.Join(root, "modules"), m.Services) s.Ui = m.Ui s.Mode = mode return s diff --git a/commands.go b/commands.go index 3335d2cdb..113c771eb 100644 --- a/commands.go +++ b/commands.go @@ -30,15 +30,12 @@ const ( OutputPrefix = "o:" ) -func initCommands(config *Config) { +func initCommands(config *Config, services *disco.Disco) { var inAutomation bool if v := os.Getenv(runningInAutomationEnvName); v != "" { inAutomation = true } - credsSrc := credentialsSource(config) - services := disco.NewDisco() - services.SetCredentialsSource(credsSrc) for userHost, hostConfig := range config.Hosts { host, err := svchost.ForComparison(userHost) if err != nil { @@ -57,8 +54,7 @@ func initCommands(config *Config) { PluginOverrides: &PluginOverrides, Ui: Ui, - Services: services, - Credentials: credsSrc, + Services: services, RunningInAutomation: inAutomation, PluginCacheDir: config.PluginCacheDir, diff --git a/config/module/module_test.go b/config/module/module_test.go index 62e7ed2a7..80e931e0b 100644 --- a/config/module/module_test.go +++ b/config/module/module_test.go @@ -44,5 +44,5 @@ func testConfig(t *testing.T, n string) *config.Config { func testStorage(t *testing.T, d *disco.Disco) *Storage { t.Helper() - return NewStorage(tempDir(t), d, nil) + return NewStorage(tempDir(t), d) } diff --git a/config/module/storage.go b/config/module/storage.go index fa5e1c621..4b828dcb0 100644 --- a/config/module/storage.go +++ b/config/module/storage.go @@ -11,7 +11,6 @@ import ( getter "github.com/hashicorp/go-getter" "github.com/hashicorp/terraform/registry" "github.com/hashicorp/terraform/registry/regsrc" - "github.com/hashicorp/terraform/svchost/auth" "github.com/hashicorp/terraform/svchost/disco" "github.com/mitchellh/cli" ) @@ -64,14 +63,10 @@ type Storage struct { // StorageDir is the full path to the directory where all modules will be // stored. StorageDir string - // Services is a required *disco.Disco, which may have services and - // credentials pre-loaded. - Services *disco.Disco - // Creds optionally provides credentials for communicating with service - // providers. - Creds auth.CredentialsSource + // Ui is an optional cli.Ui for user output Ui cli.Ui + // Mode is the GetMode that will be used for various operations. Mode GetMode @@ -79,8 +74,8 @@ type Storage struct { } // NewStorage returns a new initialized Storage object. -func NewStorage(dir string, services *disco.Disco, creds auth.CredentialsSource) *Storage { - regClient := registry.NewClient(services, creds, nil) +func NewStorage(dir string, services *disco.Disco) *Storage { + regClient := registry.NewClient(services, nil) return &Storage{ StorageDir: dir, diff --git a/config/module/storage_test.go b/config/module/storage_test.go index 10811190e..cb41f6d65 100644 --- a/config/module/storage_test.go +++ b/config/module/storage_test.go @@ -22,7 +22,7 @@ func TestGetModule(t *testing.T) { t.Fatal(err) } defer os.RemoveAll(td) - storage := NewStorage(td, disco, nil) + storage := NewStorage(td, disco) // this module exists in a test fixture, and is known by the test.Registry // relative to our cwd. @@ -139,7 +139,7 @@ func TestAccRegistryDiscover(t *testing.T) { t.Fatal(err) } - s := NewStorage("/tmp", nil, nil) + s := NewStorage("/tmp", nil) loc, err := s.registry.Location(module, "") if err != nil { t.Fatal(err) diff --git a/configs/configload/loader.go b/configs/configload/loader.go index 06ff27400..45e60f77c 100644 --- a/configs/configload/loader.go +++ b/configs/configload/loader.go @@ -5,7 +5,6 @@ import ( "github.com/hashicorp/terraform/configs" "github.com/hashicorp/terraform/registry" - "github.com/hashicorp/terraform/svchost/auth" "github.com/hashicorp/terraform/svchost/disco" "github.com/spf13/afero" ) @@ -39,10 +38,6 @@ type Config struct { // not supported, which should be true only in specialized circumstances // such as in tests. Services *disco.Disco - - // Creds is a credentials store for communicating with remote module - // registry endpoints. If this is nil then no credentials will be used. - Creds auth.CredentialsSource } // NewLoader creates and returns a loader that reads configuration from the @@ -54,7 +49,7 @@ type Config struct { func NewLoader(config *Config) (*Loader, error) { fs := afero.NewOsFs() parser := configs.NewParser(fs) - reg := registry.NewClient(config.Services, config.Creds, nil) + reg := registry.NewClient(config.Services, nil) ret := &Loader{ parser: parser, @@ -63,7 +58,6 @@ func NewLoader(config *Config) (*Loader, error) { CanInstall: true, Dir: config.ModulesDir, Services: config.Services, - Creds: config.Creds, Registry: reg, }, } diff --git a/configs/configload/module_mgr.go b/configs/configload/module_mgr.go index ef17fda7a..6b2a5199f 100644 --- a/configs/configload/module_mgr.go +++ b/configs/configload/module_mgr.go @@ -2,7 +2,6 @@ package configload import ( "github.com/hashicorp/terraform/registry" - "github.com/hashicorp/terraform/svchost/auth" "github.com/hashicorp/terraform/svchost/disco" "github.com/spf13/afero" ) @@ -25,9 +24,6 @@ type moduleMgr struct { // cached discovery information. Services *disco.Disco - // Creds provides optional credentials for communicating with service hosts. - Creds auth.CredentialsSource - // Registry is a client for the module registry protocol, which is used // when a module is requested from a registry source. Registry *registry.Client diff --git a/main.go b/main.go index 1818a91c4..523863e7b 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,7 @@ import ( "github.com/hashicorp/go-plugin" "github.com/hashicorp/terraform/command/format" "github.com/hashicorp/terraform/helper/logging" + "github.com/hashicorp/terraform/svchost/disco" "github.com/hashicorp/terraform/terraform" "github.com/mattn/go-colorable" "github.com/mattn/go-shellwords" @@ -144,7 +145,9 @@ func wrappedMain() int { // In tests, Commands may already be set to provide mock commands if Commands == nil { - initCommands(config) + credsSrc := credentialsSource(config) + services := disco.NewWithCredentialsSource(credsSrc) + initCommands(config, services) } // Run checkpoint diff --git a/registry/client.go b/registry/client.go index fba59ec87..8e31a6a3e 100644 --- a/registry/client.go +++ b/registry/client.go @@ -15,7 +15,6 @@ import ( "github.com/hashicorp/terraform/registry/regsrc" "github.com/hashicorp/terraform/registry/response" "github.com/hashicorp/terraform/svchost" - "github.com/hashicorp/terraform/svchost/auth" "github.com/hashicorp/terraform/svchost/disco" "github.com/hashicorp/terraform/version" ) @@ -37,20 +36,14 @@ type Client struct { // services is a required *disco.Disco, which may have services and // credentials pre-loaded. services *disco.Disco - - // Creds optionally provides credentials for communicating with service - // providers. - creds auth.CredentialsSource } // NewClient returns a new initialized registry client. -func NewClient(services *disco.Disco, creds auth.CredentialsSource, client *http.Client) *Client { +func NewClient(services *disco.Disco, client *http.Client) *Client { if services == nil { - services = disco.NewDisco() + services = disco.New() } - services.SetCredentialsSource(creds) - if client == nil { client = httpclient.New() client.Timeout = requestTimeout @@ -61,7 +54,6 @@ func NewClient(services *disco.Disco, creds auth.CredentialsSource, client *http return &Client{ client: client, services: services, - creds: creds, } } @@ -138,11 +130,7 @@ func (c *Client) Versions(module *regsrc.Module) (*response.ModuleVersions, erro } func (c *Client) addRequestCreds(host svchost.Hostname, req *http.Request) { - if c.creds == nil { - return - } - - creds, err := c.creds.ForHost(host) + creds, err := c.services.CredentialsForHost(host) if err != nil { log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", host, err) return diff --git a/registry/client_test.go b/registry/client_test.go index 279c5a483..5ee712f7f 100644 --- a/registry/client_test.go +++ b/registry/client_test.go @@ -15,7 +15,7 @@ func TestLookupModuleVersions(t *testing.T) { server := test.Registry() defer server.Close() - client := NewClient(test.Disco(server), nil, nil) + client := NewClient(test.Disco(server), nil) // test with and without a hostname for _, src := range []string{ @@ -59,7 +59,7 @@ func TestInvalidRegistry(t *testing.T) { server := test.Registry() defer server.Close() - client := NewClient(test.Disco(server), nil, nil) + client := NewClient(test.Disco(server), nil) src := "non-existent.localhost.localdomain/test-versions/name/provider" modsrc, err := regsrc.ParseModuleSource(src) @@ -76,7 +76,7 @@ func TestRegistryAuth(t *testing.T) { server := test.Registry() defer server.Close() - client := NewClient(test.Disco(server), nil, nil) + client := NewClient(test.Disco(server), nil) src := "private/name/provider" mod, err := regsrc.ParseModuleSource(src) @@ -84,6 +84,18 @@ func TestRegistryAuth(t *testing.T) { t.Fatal(err) } + _, err = client.Versions(mod) + if err != nil { + t.Fatal(err) + } + _, err = client.Location(mod, "1.0.0") + if err != nil { + t.Fatal(err) + } + + // Also test without a credentials source + client.services.SetCredentialsSource(nil) + // both should fail without auth _, err = client.Versions(mod) if err == nil { @@ -93,24 +105,13 @@ func TestRegistryAuth(t *testing.T) { if err == nil { t.Fatal("expected error") } - - client = NewClient(test.Disco(server), test.Credentials, nil) - - _, err = client.Versions(mod) - if err != nil { - t.Fatal(err) - } - _, err = client.Location(mod, "1.0.0") - if err != nil { - t.Fatal(err) - } } func TestLookupModuleLocationRelative(t *testing.T) { server := test.Registry() defer server.Close() - client := NewClient(test.Disco(server), nil, nil) + client := NewClient(test.Disco(server), nil) src := "relative/foo/bar" mod, err := regsrc.ParseModuleSource(src) @@ -133,7 +134,7 @@ func TestAccLookupModuleVersions(t *testing.T) { if os.Getenv("TF_ACC") == "" { t.Skip() } - regDisco := disco.NewDisco() + regDisco := disco.New() // test with and without a hostname for _, src := range []string{ @@ -145,7 +146,7 @@ func TestAccLookupModuleVersions(t *testing.T) { t.Fatal(err) } - s := NewClient(regDisco, nil, nil) + s := NewClient(regDisco, nil) resp, err := s.Versions(modsrc) if err != nil { t.Fatal(err) @@ -179,7 +180,7 @@ func TestLookupLookupModuleError(t *testing.T) { server := test.Registry() defer server.Close() - client := NewClient(test.Disco(server), nil, nil) + client := NewClient(test.Disco(server), nil) // this should not be found in teh registry src := "bad/local/path" diff --git a/registry/test/mock_registry.go b/registry/test/mock_registry.go index c1fabbc25..bd3d80b7f 100644 --- a/registry/test/mock_registry.go +++ b/registry/test/mock_registry.go @@ -27,7 +27,7 @@ func Disco(s *httptest.Server) *disco.Disco { // TODO: add specific tests to enumerate both possibilities. "modules.v1": fmt.Sprintf("%s/v1/modules", s.URL), } - d := disco.NewDisco() + d := disco.NewWithCredentialsSource(credsSrc) d.ForceHostServices(svchost.Hostname("registry.terraform.io"), services) d.ForceHostServices(svchost.Hostname("localhost"), services) @@ -48,8 +48,8 @@ const ( ) var ( - regHost = svchost.Hostname(regsrc.PublicRegistryHost.Normalized()) - Credentials = auth.StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{ + regHost = svchost.Hostname(regsrc.PublicRegistryHost.Normalized()) + credsSrc = auth.StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{ regHost: {"token": testCred}, }) ) diff --git a/svchost/auth/credentials.go b/svchost/auth/credentials.go index 0bc6db4f1..0372c1609 100644 --- a/svchost/auth/credentials.go +++ b/svchost/auth/credentials.go @@ -42,6 +42,9 @@ type HostCredentials interface { // receiving credentials. The usual behavior of this method is to // add some sort of Authorization header to the request. PrepareRequest(req *http.Request) + + // Token returns the authentication token. + Token() string } // ForHost iterates over the contained CredentialsSource objects and diff --git a/svchost/auth/token_credentials.go b/svchost/auth/token_credentials.go index 8f771b0d9..9358bcb64 100644 --- a/svchost/auth/token_credentials.go +++ b/svchost/auth/token_credentials.go @@ -18,3 +18,8 @@ func (tc HostCredentialsToken) PrepareRequest(req *http.Request) { } req.Header.Set("Authorization", "Bearer "+string(tc)) } + +// Token returns the authentication token. +func (tc HostCredentialsToken) Token() string { + return string(tc) +} diff --git a/svchost/disco/disco.go b/svchost/disco/disco.go index 76a1b3b0d..7fc49da9c 100644 --- a/svchost/disco/disco.go +++ b/svchost/disco/disco.go @@ -42,9 +42,15 @@ type Disco struct { Transport http.RoundTripper } -// NewDisco returns a new initialized Disco object. -func NewDisco() *Disco { - return &Disco{} +// New returns a new initialized discovery object. +func New() *Disco { + return NewWithCredentialsSource(nil) +} + +// NewWithCredentialsSource returns a new discovery object initialized with +// the given credentials source. +func NewWithCredentialsSource(credsSrc auth.CredentialsSource) *Disco { + return &Disco{credsSrc: credsSrc} } // SetCredentialsSource provides a credentials source that will be used to @@ -56,6 +62,15 @@ func (d *Disco) SetCredentialsSource(src auth.CredentialsSource) { d.credsSrc = src } +// CredentialsForHost returns a non-nil HostCredentials if the embedded source has +// credentials available for the host, and a nil HostCredentials if it does not. +func (d *Disco) CredentialsForHost(host svchost.Hostname) (auth.HostCredentials, error) { + if d.credsSrc == nil { + return nil, nil + } + return d.credsSrc.ForHost(host) +} + // ForceHostServices provides a pre-defined set of services for a given // host, which prevents the receiver from attempting network-based discovery // for the given host. Instead, the given services map will be returned @@ -145,15 +160,10 @@ func (d *Disco) discover(host svchost.Hostname) Host { URL: discoURL, } - if d.credsSrc != nil { - creds, err := d.credsSrc.ForHost(host) - if err == nil { - if creds != nil { - creds.PrepareRequest(req) // alters req to include credentials - } - } else { - log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", host, err) - } + if creds, err := d.CredentialsForHost(host); err != nil { + log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", host, err) + } else if creds != nil { + creds.PrepareRequest(req) // alters req to include credentials } log.Printf("[DEBUG] Service discovery for %s at %s", host, discoURL) diff --git a/svchost/disco/disco_test.go b/svchost/disco/disco_test.go index 94d2a220f..c8bc16c45 100644 --- a/svchost/disco/disco_test.go +++ b/svchost/disco/disco_test.go @@ -45,7 +45,7 @@ func TestDiscover(t *testing.T) { t.Fatalf("test server hostname is invalid: %s", err) } - d := NewDisco() + d := New() discovered := d.Discover(host) gotURL := discovered.ServiceURL("thingy.v1") if gotURL == nil { @@ -80,7 +80,7 @@ func TestDiscover(t *testing.T) { t.Fatalf("test server hostname is invalid: %s", err) } - d := NewDisco() + d := New() discovered := d.Discover(host) gotURL := discovered.ServiceURL("wotsit.v2") if gotURL == nil { @@ -107,7 +107,7 @@ func TestDiscover(t *testing.T) { t.Fatalf("test server hostname is invalid: %s", err) } - d := NewDisco() + d := New() d.SetCredentialsSource(auth.StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{ host: map[string]interface{}{ "token": "abc123", @@ -124,7 +124,7 @@ func TestDiscover(t *testing.T) { "wotsit.v2": "/foo", } - d := NewDisco() + d := New() d.ForceHostServices(svchost.Hostname("example.com"), forced) givenHost := "example.com" @@ -167,7 +167,7 @@ func TestDiscover(t *testing.T) { t.Fatalf("test server hostname is invalid: %s", err) } - d := NewDisco() + d := New() discovered := d.Discover(host) // result should be empty, which we can verify only by reaching into @@ -190,7 +190,7 @@ func TestDiscover(t *testing.T) { t.Fatalf("test server hostname is invalid: %s", err) } - d := NewDisco() + d := New() discovered := d.Discover(host) // result should be empty, which we can verify only by reaching into @@ -217,7 +217,7 @@ func TestDiscover(t *testing.T) { t.Fatalf("test server hostname is invalid: %s", err) } - d := NewDisco() + d := New() discovered := d.Discover(host) if discovered.services == nil { @@ -236,7 +236,7 @@ func TestDiscover(t *testing.T) { t.Fatalf("test server hostname is invalid: %s", err) } - d := NewDisco() + d := New() discovered := d.Discover(host) // result should be empty, which we can verify only by reaching into @@ -267,7 +267,7 @@ func TestDiscover(t *testing.T) { t.Fatalf("test server hostname is invalid: %s", err) } - d := NewDisco() + d := New() discovered := d.Discover(host) gotURL := discovered.ServiceURL("thingy.v1")