From 9162213b0122e68068c2322b2ce49f8631ab5986 Mon Sep 17 00:00:00 2001 From: James Bardin Date: Mon, 20 Nov 2017 15:12:10 -0500 Subject: [PATCH 1/8] reimport the registry regsrc module --- registry/regsrc/friendly_host.go | 28 +++++++++++++++++++++------ registry/regsrc/friendly_host_test.go | 24 +++++++---------------- registry/regsrc/module.go | 25 +++++++++--------------- registry/regsrc/module_test.go | 10 ---------- 4 files changed, 38 insertions(+), 49 deletions(-) diff --git a/registry/regsrc/friendly_host.go b/registry/regsrc/friendly_host.go index 648e2a193..28ca9b0aa 100644 --- a/registry/regsrc/friendly_host.go +++ b/registry/regsrc/friendly_host.go @@ -4,7 +4,7 @@ import ( "regexp" "strings" - "github.com/hashicorp/terraform/svchost" + "golang.org/x/net/idna" ) var ( @@ -95,26 +95,42 @@ func ParseFriendlyHost(source string) (host *FriendlyHost, rest string) { // name specifications. Not that IDN prefixes containing punycode are not valid // input which we expect to always be in user-input or normalised display form. func (h *FriendlyHost) Valid() bool { - return svchost.IsValid(h.Raw) + if h.Display() == InvalidHostString { + return false + } + if h.Normalized() == InvalidHostString { + return false + } + if containsPuny(h.Raw) { + return false + } + return true } // Display returns the host formatted for display to the user in CLI or web // output. func (h *FriendlyHost) Display() string { - hostname, err := svchost.ForComparison(h.Raw) + parts := strings.SplitN(h.Raw, ":", 2) + var err error + parts[0], err = idna.Display.ToUnicode(parts[0]) if err != nil { return InvalidHostString } - return hostname.ForDisplay() + return strings.Join(parts, ":") } // Normalized returns the host formatted for internal reference or comparison. func (h *FriendlyHost) Normalized() string { - hostname, err := svchost.ForComparison(h.Raw) + // For now IDNA does all the normalisation we need including case-folding + // pure ASCII to lower. But breaks if a custom port is included while we + // want to allow that and normalize comparison including it, + parts := strings.SplitN(h.Raw, ":", 2) + var err error + parts[0], err = idna.Lookup.ToASCII(parts[0]) if err != nil { return InvalidHostString } - return hostname.String() + return strings.Join(parts, ":") } // String returns the host formatted as the user originally typed it assuming it diff --git a/registry/regsrc/friendly_host_test.go b/registry/regsrc/friendly_host_test.go index 740395bf6..e87774cfe 100644 --- a/registry/regsrc/friendly_host_test.go +++ b/registry/regsrc/friendly_host_test.go @@ -95,36 +95,26 @@ func TestFriendlyHost(t *testing.T) { if v := gotHost.String(); v != tt.wantHost { t.Fatalf("String() = %v, want %v", v, tt.wantHost) } - if v := gotHost.Valid(); v != tt.wantValid { - t.Fatalf("Valid() = %v, want %v", v, tt.wantValid) - } - - // FIXME: should we allow punycode as input - if !tt.wantValid { - return - } - if v := gotHost.Display(); v != tt.wantDisplay { t.Fatalf("Display() = %v, want %v", v, tt.wantDisplay) } if v := gotHost.Normalized(); v != tt.wantNorm { t.Fatalf("Normalized() = %v, want %v", v, tt.wantNorm) } + if v := gotHost.Valid(); v != tt.wantValid { + t.Fatalf("Valid() = %v, want %v", v, tt.wantValid) + } if gotRest != strings.TrimLeft(sfx, "/") { t.Fatalf("ParseFriendlyHost() rest = %v, want %v", gotRest, strings.TrimLeft(sfx, "/")) } // Also verify that host compares equal with all the variants. if !gotHost.Equal(&FriendlyHost{Raw: tt.wantDisplay}) { - t.Fatalf("Equal() should be true for %s and %t", tt.wantHost, tt.wantValid) + t.Fatalf("Equal() should be true for %s and %s", tt.wantHost, tt.wantValid) + } + if !gotHost.Equal(&FriendlyHost{Raw: tt.wantNorm}) { + t.Fatalf("Equal() should be true for %s and %s", tt.wantHost, tt.wantNorm) } - - // FIXME: Do we need to accept normalized input? - //if !gotHost.Equal(&FriendlyHost{Raw: tt.wantNorm}) { - // fmt.Println(gotHost.Normalized(), tt.wantNorm) - // fmt.Println(" ", (&FriendlyHost{Raw: tt.wantNorm}).Normalized()) - // t.Fatalf("Equal() should be true for %s and %s", tt.wantHost, tt.wantNorm) - //} }) } diff --git a/registry/regsrc/module.go b/registry/regsrc/module.go index b6671c8a4..3080cddb6 100644 --- a/registry/regsrc/module.go +++ b/registry/regsrc/module.go @@ -33,12 +33,13 @@ var ( fmt.Sprintf("^(%s)\\/(%s)\\/(%s)(?:\\/\\/(.*))?$", nameSubRe, nameSubRe, providerSubRe)) - // disallowed is a set of hostnames that have special usage in modules and - // can't be registry hosts - disallowed = map[string]bool{ - "github.com": true, - "bitbucket.org": true, - } + // NameRe is a regular expression defining the format allowed for namespace + // or name fields in module registry implementations. + NameRe = regexp.MustCompile("^" + nameSubRe + "$") + + // ProviderRe is a regular expression defining the format allowed for + // provider fields in module registry implementations. + ProviderRe = regexp.MustCompile("^" + providerSubRe + "$") ) // Module describes a Terraform Registry Module source. @@ -84,10 +85,8 @@ func NewModule(host, namespace, name, provider, submodule string) *Module { func ParseModuleSource(source string) (*Module, error) { // See if there is a friendly host prefix. host, rest := ParseFriendlyHost(source) - if host != nil { - if !host.Valid() || disallowed[host.Display()] { - return nil, ErrInvalidModuleSource - } + if host != nil && !host.Valid() { + return nil, ErrInvalidModuleSource } matches := moduleSourceRe.FindStringSubmatch(rest) @@ -132,12 +131,6 @@ func (m *Module) String() string { return m.formatWithPrefix(hostPrefix, true) } -// Module returns just the registry ID of the module, without a hostname or -// suffix. -func (m *Module) Module() string { - return fmt.Sprintf("%s/%s/%s", m.RawNamespace, m.RawName, m.RawProvider) -} - // Equal compares the module source against another instance taking // normalization into account. func (m *Module) Equal(other *Module) bool { diff --git a/registry/regsrc/module_test.go b/registry/regsrc/module_test.go index bae502b0d..19d9dfa19 100644 --- a/registry/regsrc/module_test.go +++ b/registry/regsrc/module_test.go @@ -96,16 +96,6 @@ func TestModule(t *testing.T) { source: "foo.com/var/baz?otherthing", wantErr: true, }, - { - name: "disallow github", - source: "github.com/HashiCorp/Consul/aws", - wantErr: true, - }, - { - name: "disallow bitbucket", - source: "bitbucket.org/HashiCorp/Consul/aws", - wantErr: true, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From 1ac5871a091f0ce26669d9d61614e595f1b2bfaa Mon Sep 17 00:00:00 2001 From: James Bardin Date: Mon, 20 Nov 2017 16:09:27 -0500 Subject: [PATCH 2/8] use svchost.Hostname for FriendlyHost validation Use the svchost.Hostname for FriendlyHost normalization and validation. --- registry/regsrc/friendly_host.go | 42 +++++++------------------------- 1 file changed, 9 insertions(+), 33 deletions(-) diff --git a/registry/regsrc/friendly_host.go b/registry/regsrc/friendly_host.go index 28ca9b0aa..ff105bb75 100644 --- a/registry/regsrc/friendly_host.go +++ b/registry/regsrc/friendly_host.go @@ -4,7 +4,7 @@ import ( "regexp" "strings" - "golang.org/x/net/idna" + "github.com/hashicorp/terraform/svchost" ) var ( @@ -95,42 +95,22 @@ func ParseFriendlyHost(source string) (host *FriendlyHost, rest string) { // name specifications. Not that IDN prefixes containing punycode are not valid // input which we expect to always be in user-input or normalised display form. func (h *FriendlyHost) Valid() bool { - if h.Display() == InvalidHostString { - return false - } - if h.Normalized() == InvalidHostString { - return false - } - if containsPuny(h.Raw) { - return false - } - return true + return svchost.IsValid(h.Raw) } // Display returns the host formatted for display to the user in CLI or web // output. func (h *FriendlyHost) Display() string { - parts := strings.SplitN(h.Raw, ":", 2) - var err error - parts[0], err = idna.Display.ToUnicode(parts[0]) - if err != nil { - return InvalidHostString - } - return strings.Join(parts, ":") + return svchost.ForDisplay(h.Raw) } // Normalized returns the host formatted for internal reference or comparison. func (h *FriendlyHost) Normalized() string { - // For now IDNA does all the normalisation we need including case-folding - // pure ASCII to lower. But breaks if a custom port is included while we - // want to allow that and normalize comparison including it, - parts := strings.SplitN(h.Raw, ":", 2) - var err error - parts[0], err = idna.Lookup.ToASCII(parts[0]) + host, err := svchost.ForComparison(h.Raw) if err != nil { return InvalidHostString } - return strings.Join(parts, ":") + return string(host) } // String returns the host formatted as the user originally typed it assuming it @@ -140,19 +120,15 @@ func (h *FriendlyHost) String() string { } // Equal compares the FriendlyHost against another instance taking normalization -// into account. +// into account. Invalid hosts cannot be compared and will always return false. func (h *FriendlyHost) Equal(other *FriendlyHost) bool { if other == nil { return false } + return h.Normalized() == other.Normalized() } -func containsPuny(host string) bool { - for _, lbl := range strings.Split(host, ".") { - if strings.HasPrefix(strings.ToLower(lbl), "xn--") { - return true - } - } - return false +func (h *FriendlyHost) SvcHost() (svchost.Hostname, error) { + return svchost.ForComparison(h.Raw) } From 98d0d15ddcd90086037b85c5e8dac8ebd0d7247d Mon Sep 17 00:00:00 2001 From: James Bardin Date: Mon, 20 Nov 2017 16:11:39 -0500 Subject: [PATCH 3/8] Update the FriendlyHost tests for svchost.Hostname This no longer allows normalization of punycode hostnames. This shouldn't be a problem, as they were not valid in the first place. --- registry/regsrc/friendly_host_test.go | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/registry/regsrc/friendly_host_test.go b/registry/regsrc/friendly_host_test.go index e87774cfe..53cec8f84 100644 --- a/registry/regsrc/friendly_host_test.go +++ b/registry/regsrc/friendly_host_test.go @@ -59,7 +59,7 @@ func TestFriendlyHost(t *testing.T) { source: "xn--s-fka0wmm0zea7g8b.xn--o-8ta85a3b1dwcda1k.io", wantHost: "xn--s-fka0wmm0zea7g8b.xn--o-8ta85a3b1dwcda1k.io", wantDisplay: "ʎɹʇsıƃǝɹ.ɯɹoɟɐɹɹǝʇ.io", - wantNorm: "xn--s-fka0wmm0zea7g8b.xn--o-8ta85a3b1dwcda1k.io", + wantNorm: InvalidHostString, wantValid: false, }, { @@ -109,13 +109,9 @@ func TestFriendlyHost(t *testing.T) { } // Also verify that host compares equal with all the variants. - if !gotHost.Equal(&FriendlyHost{Raw: tt.wantDisplay}) { - t.Fatalf("Equal() should be true for %s and %s", tt.wantHost, tt.wantValid) + if gotHost.Valid() && !gotHost.Equal(&FriendlyHost{Raw: tt.wantDisplay}) { + t.Fatalf("Equal() should be true for %s and %s", tt.wantHost, tt.wantDisplay) } - if !gotHost.Equal(&FriendlyHost{Raw: tt.wantNorm}) { - t.Fatalf("Equal() should be true for %s and %s", tt.wantHost, tt.wantNorm) - } - }) } } From 92db96f783ba19fb48bdfbb91f084dd93c1ba011 Mon Sep 17 00:00:00 2001 From: James Bardin Date: Mon, 20 Nov 2017 16:44:50 -0500 Subject: [PATCH 4/8] disallow github and bitbucket --- registry/regsrc/module.go | 26 +++++++++++++++++++++----- registry/regsrc/module_test.go | 10 ++++++++++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/registry/regsrc/module.go b/registry/regsrc/module.go index 3080cddb6..05f9bb201 100644 --- a/registry/regsrc/module.go +++ b/registry/regsrc/module.go @@ -40,6 +40,13 @@ var ( // ProviderRe is a regular expression defining the format allowed for // provider fields in module registry implementations. ProviderRe = regexp.MustCompile("^" + providerSubRe + "$") + + // these hostnames are not allowed as registry sources, because they are + // already special case module sources in terraform. + disallowed = map[string]bool{ + "github.com": true, + "bitbucket.org": true, + } ) // Module describes a Terraform Registry Module source. @@ -60,7 +67,7 @@ type Module struct { // NewModule construct a new module source from separate parts. Pass empty // string if host or submodule are not needed. -func NewModule(host, namespace, name, provider, submodule string) *Module { +func NewModule(host, namespace, name, provider, submodule string) (*Module, error) { m := &Module{ RawNamespace: namespace, RawName: name, @@ -68,9 +75,16 @@ func NewModule(host, namespace, name, provider, submodule string) *Module { RawSubmodule: submodule, } if host != "" { - m.RawHost = NewFriendlyHost(host) + h := NewFriendlyHost(host) + if h != nil { + fmt.Println("HOST:", h) + if !h.Valid() || disallowed[h.Display()] { + return nil, ErrInvalidModuleSource + } + } + m.RawHost = h } - return m + return m, nil } // ParseModuleSource attempts to parse source as a Terraform registry module @@ -85,8 +99,10 @@ func NewModule(host, namespace, name, provider, submodule string) *Module { func ParseModuleSource(source string) (*Module, error) { // See if there is a friendly host prefix. host, rest := ParseFriendlyHost(source) - if host != nil && !host.Valid() { - return nil, ErrInvalidModuleSource + if host != nil { + if !host.Valid() || disallowed[host.Display()] { + return nil, ErrInvalidModuleSource + } } matches := moduleSourceRe.FindStringSubmatch(rest) diff --git a/registry/regsrc/module_test.go b/registry/regsrc/module_test.go index 19d9dfa19..bae502b0d 100644 --- a/registry/regsrc/module_test.go +++ b/registry/regsrc/module_test.go @@ -96,6 +96,16 @@ func TestModule(t *testing.T) { source: "foo.com/var/baz?otherthing", wantErr: true, }, + { + name: "disallow github", + source: "github.com/HashiCorp/Consul/aws", + wantErr: true, + }, + { + name: "disallow bitbucket", + source: "bitbucket.org/HashiCorp/Consul/aws", + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From 87f23d9719afdffb517338dca268978525d96f39 Mon Sep 17 00:00:00 2001 From: James Bardin Date: Mon, 20 Nov 2017 16:48:11 -0500 Subject: [PATCH 5/8] add Module method for module name only --- registry/regsrc/module.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/registry/regsrc/module.go b/registry/regsrc/module.go index 05f9bb201..e267f302a 100644 --- a/registry/regsrc/module.go +++ b/registry/regsrc/module.go @@ -184,3 +184,9 @@ func (m *Module) formatWithPrefix(hostPrefix string, preserveCase bool) string { } return str } + +// Module returns just the registry ID of the module, without a hostname or +// suffix. +func (m *Module) Module() string { + return fmt.Sprintf("%s/%s/%s", m.RawNamespace, m.RawName, m.RawProvider) +} From bd576d780a5082d35bdc93c4f6ee8604a3ce9093 Mon Sep 17 00:00:00 2001 From: James Bardin Date: Mon, 20 Nov 2017 17:42:35 -0500 Subject: [PATCH 6/8] failing test for module not found error The "not found" error should use the raw string directly from the config source, but the existing method was adding the default registry if there was no host indicated. --- config/module/registry_test.go | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/config/module/registry_test.go b/config/module/registry_test.go index 54c8f818e..dab7444c2 100644 --- a/config/module/registry_test.go +++ b/config/module/registry_test.go @@ -2,6 +2,7 @@ package module import ( "os" + "strings" "testing" version "github.com/hashicorp/go-version" @@ -93,6 +94,7 @@ func TestRegistryAuth(t *testing.T) { } } + func TestLookupModuleLocationRelative(t *testing.T) { server := mockRegistry() defer server.Close() @@ -117,6 +119,7 @@ func TestLookupModuleLocationRelative(t *testing.T) { } } + func TestAccLookupModuleVersions(t *testing.T) { if os.Getenv("TF_ACC") == "" { t.Skip() @@ -163,3 +166,29 @@ func TestAccLookupModuleVersions(t *testing.T) { } } } + +// the error should reference the config source exatly, not the discovered path. +func TestLookupLookupModuleError(t *testing.T) { + server := mockRegistry() + defer server.Close() + + regDisco := testDisco(server) + storage := testStorage(t, regDisco) + + // this should not be found in teh registry + src := "bad/local/path" + mod, err := regsrc.ParseModuleSource(src) + if err != nil { + t.Fatal(err) + } + + _, err = storage.lookupModuleLocation(mod, "0.2.0") + if err == nil { + t.Fatal("expected error") + } + + // check for the exact quoted string to ensure we didn't prepend a hostname. + if !strings.Contains(err.Error(), `"bad/local/path"`) { + t.Fatal("error should not include the hostname. got:", err) + } +} From 8091bd627dfa7dbf29b398d56f524a892a7b1e8e Mon Sep 17 00:00:00 2001 From: James Bardin Date: Mon, 20 Nov 2017 17:43:36 -0500 Subject: [PATCH 7/8] move Svchost method to regsrc.Module The level of abstraction that needs the "svchost" is the Module, not the FriendlyHost. Us the new method in the module package for registry interaction. --- config/module/registry.go | 27 ++++++++++++++------------- config/module/storage.go | 4 +++- registry/regsrc/friendly_host.go | 4 ---- registry/regsrc/module.go | 13 +++++++++++++ 4 files changed, 30 insertions(+), 18 deletions(-) diff --git a/config/module/registry.go b/config/module/registry.go index 10209c4bf..da67c5ab9 100644 --- a/config/module/registry.go +++ b/config/module/registry.go @@ -44,8 +44,8 @@ func (e errModuleNotFound) Error() string { return `module "` + string(e) + `" not found` } -func (s *Storage) discoverRegURL(module *regsrc.Module) *url.URL { - regURL := s.Services.DiscoverServiceURL(svchost.Hostname(module.RawHost.Normalized()), serviceID) +func (s *Storage) discoverRegURL(host svchost.Hostname) *url.URL { + regURL := s.Services.DiscoverServiceURL(host, serviceID) if regURL == nil { return nil } @@ -75,13 +75,14 @@ func (s *Storage) addRequestCreds(host svchost.Hostname, req *http.Request) { // Lookup module versions in the registry. func (s *Storage) lookupModuleVersions(module *regsrc.Module) (*response.ModuleVersions, error) { - if module.RawHost == nil { - module.RawHost = regsrc.NewFriendlyHost(defaultRegistry) + host, err := module.SvcHost() + if err != nil { + return nil, err } - service := s.discoverRegURL(module) + service := s.discoverRegURL(host) if service == nil { - return nil, fmt.Errorf("host %s does not provide Terraform modules", module.RawHost.Display()) + return nil, fmt.Errorf("host %s does not provide Terraform modules", host) } p, err := url.Parse(path.Join(module.Module(), "versions")) @@ -98,7 +99,7 @@ func (s *Storage) lookupModuleVersions(module *regsrc.Module) (*response.ModuleV return nil, err } - s.addRequestCreds(svchost.Hostname(module.RawHost.Normalized()), req) + s.addRequestCreds(host, req) req.Header.Set(xTerraformVersion, tfVersion) resp, err := httpClient.Do(req) @@ -134,17 +135,17 @@ func (s *Storage) lookupModuleVersions(module *regsrc.Module) (*response.ModuleV // lookup the location of a specific module version in the registry func (s *Storage) lookupModuleLocation(module *regsrc.Module, version string) (string, error) { - if module.RawHost == nil { - module.RawHost = regsrc.NewFriendlyHost(defaultRegistry) + host, err := module.SvcHost() + if err != nil { + return "", err } - service := s.discoverRegURL(module) + service := s.discoverRegURL(host) if service == nil { - return "", fmt.Errorf("host %s does not provide Terraform modules", module.RawHost.Display()) + return "", fmt.Errorf("host %s does not provide Terraform modules", host.ForDisplay()) } var p *url.URL - var err error if version == "" { p, err = url.Parse(path.Join(module.Module(), "download")) } else { @@ -162,7 +163,7 @@ func (s *Storage) lookupModuleLocation(module *regsrc.Module, version string) (s return "", err } - s.addRequestCreds(svchost.Hostname(module.RawHost.Normalized()), req) + s.addRequestCreds(host, req) req.Header.Set(xTerraformVersion, tfVersion) resp, err := httpClient.Do(req) diff --git a/config/module/storage.go b/config/module/storage.go index 05065b3c6..121719765 100644 --- a/config/module/storage.go +++ b/config/module/storage.go @@ -343,7 +343,9 @@ func (s Storage) findRegistryModule(mSource, constraint string) (moduleRecord, e return rec, err } - s.output(fmt.Sprintf(" Found version %s of %s on %s", rec.Version, mod.Module(), mod.RawHost.Display())) + // we've already validated this by now + host, _ := mod.SvcHost() + s.output(fmt.Sprintf(" Found version %s of %s on %s", rec.Version, mod.Module(), host.ForDisplay())) } return rec, nil diff --git a/registry/regsrc/friendly_host.go b/registry/regsrc/friendly_host.go index ff105bb75..7d3cbdf2a 100644 --- a/registry/regsrc/friendly_host.go +++ b/registry/regsrc/friendly_host.go @@ -128,7 +128,3 @@ func (h *FriendlyHost) Equal(other *FriendlyHost) bool { return h.Normalized() == other.Normalized() } - -func (h *FriendlyHost) SvcHost() (svchost.Hostname, error) { - return svchost.ForComparison(h.Raw) -} diff --git a/registry/regsrc/module.go b/registry/regsrc/module.go index e267f302a..325706ec2 100644 --- a/registry/regsrc/module.go +++ b/registry/regsrc/module.go @@ -5,6 +5,8 @@ import ( "fmt" "regexp" "strings" + + "github.com/hashicorp/terraform/svchost" ) var ( @@ -190,3 +192,14 @@ func (m *Module) formatWithPrefix(hostPrefix string, preserveCase bool) string { func (m *Module) Module() string { return fmt.Sprintf("%s/%s/%s", m.RawNamespace, m.RawName, m.RawProvider) } + +// SvcHost returns the svchost.Hostname for this module. Since FriendlyHost may +// contain an invalid hostname, this also returns an error indicating if it +// could be converted to a svchost.Hostname. If no host is specified, the +// default PublicRegistryHost is returned. +func (m *Module) SvcHost() (svchost.Hostname, error) { + if m.RawHost == nil { + return svchost.ForComparison(PublicRegistryHost.Raw) + } + return svchost.ForComparison(m.RawHost.Raw) +} From 9034fdb050a17084c1198d77d71e2774665b4777 Mon Sep 17 00:00:00 2001 From: James Bardin Date: Mon, 20 Nov 2017 18:09:24 -0500 Subject: [PATCH 8/8] make sure invalid hosts aren't compared Comparing 2 invalid hosts would erroneously return equal, because they would compare the invalid host string. --- registry/regsrc/friendly_host.go | 12 +++++++++++- registry/regsrc/friendly_host_test.go | 23 +++++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/registry/regsrc/friendly_host.go b/registry/regsrc/friendly_host.go index 7d3cbdf2a..14b4dce9c 100644 --- a/registry/regsrc/friendly_host.go +++ b/registry/regsrc/friendly_host.go @@ -126,5 +126,15 @@ func (h *FriendlyHost) Equal(other *FriendlyHost) bool { return false } - return h.Normalized() == other.Normalized() + otherHost, err := svchost.ForComparison(other.Raw) + if err != nil { + return false + } + + host, err := svchost.ForComparison(h.Raw) + if err != nil { + return false + } + + return otherHost == host } diff --git a/registry/regsrc/friendly_host_test.go b/registry/regsrc/friendly_host_test.go index 53cec8f84..37589685d 100644 --- a/registry/regsrc/friendly_host_test.go +++ b/registry/regsrc/friendly_host_test.go @@ -116,3 +116,26 @@ func TestFriendlyHost(t *testing.T) { } } } + +func TestInvalidHostEquals(t *testing.T) { + invalid := NewFriendlyHost("NOT_A_HOST_NAME") + valid := PublicRegistryHost + + // invalid hosts are not comparable + if invalid.Equal(invalid) { + t.Fatal("invalid host names are not comparable") + } + + if valid.Equal(invalid) { + t.Fatalf("%q is not equal to %q", valid, invalid) + } + + puny := NewFriendlyHost("xn--s-fka0wmm0zea7g8b.xn--o-8ta85a3b1dwcda1k.io") + display := NewFriendlyHost("ʎɹʇsıƃǝɹ.ɯɹoɟɐɹɹǝʇ.io") + + // The pre-normalized host is not a valid source, and therefore not + // comparable to the display version. + if display.Equal(puny) { + t.Fatalf("invalid host %q should not be comparable", puny) + } +}