From 31a9790080514744500656c9e229033e2d90e0e7 Mon Sep 17 00:00:00 2001 From: Martin Atkins Date: Wed, 7 Aug 2019 16:30:56 -0700 Subject: [PATCH] svchost/disco: Allow oauth client services to specify grant types Previously we just assumed support for the authorization code grant type, but now we'll allow the host to declare which grant types it supports to allow for more flexibility in host login implementations. We may extend the set of supported grant types in future. --- svchost/disco/host.go | 33 ++++- svchost/disco/host_test.go | 238 ++++++++++++++++++++++++++++++++++ svchost/disco/oauth_client.go | 123 ++++++++++++++++++ 3 files changed, 391 insertions(+), 3 deletions(-) diff --git a/svchost/disco/host.go b/svchost/disco/host.go index eee84809f..228eadeef 100644 --- a/svchost/disco/host.go +++ b/svchost/disco/host.go @@ -166,7 +166,30 @@ func (h *Host) ServiceOAuthClient(id string) (*OAuthClient, error) { return nil, fmt.Errorf("Service %s must be declared with an object value in the service discovery document", id) } - ret := &OAuthClient{} + var grantTypes OAuthGrantTypeSet + if rawGTs, ok := raw["grant_types"]; ok { + if gts, ok := rawGTs.([]interface{}); ok { + var kws []string + for _, gtI := range gts { + gt, ok := gtI.(string) + if !ok { + // We'll ignore this so that we can potentially introduce + // other types into this array later if we need to. + continue + } + kws = append(kws, gt) + } + grantTypes = NewOAuthGrantTypeSet(kws...) + } else { + return nil, fmt.Errorf("Service %s is defined with invalid grant_types property: must be an array of grant type strings", id) + } + } else { + grantTypes = NewOAuthGrantTypeSet("authz_code") + } + + ret := &OAuthClient{ + SupportedGrantTypes: grantTypes, + } if clientIDStr, ok := raw["client"].(string); ok { ret.ID = clientIDStr } else { @@ -179,7 +202,9 @@ func (h *Host) ServiceOAuthClient(id string) (*OAuthClient, error) { } ret.AuthorizationURL = u } else { - return nil, fmt.Errorf("Service %s definition is missing required property \"authz\"", id) + if grantTypes.RequiresAuthorizationEndpoint() { + return nil, fmt.Errorf("Service %s definition is missing required property \"authz\"", id) + } } if urlStr, ok := raw["token"].(string); ok { u, err := h.parseURL(urlStr) @@ -188,7 +213,9 @@ func (h *Host) ServiceOAuthClient(id string) (*OAuthClient, error) { } ret.TokenURL = u } else { - return nil, fmt.Errorf("Service %s definition is missing required property \"token\"", id) + if grantTypes.RequiresTokenEndpoint() { + return nil, fmt.Errorf("Service %s definition is missing required property \"token\"", id) + } } if portsRaw, ok := raw["ports"].([]interface{}); ok { if len(portsRaw) != 2 { diff --git a/svchost/disco/host_test.go b/svchost/disco/host_test.go index 5a96aa418..91f7861ec 100644 --- a/svchost/disco/host_test.go +++ b/svchost/disco/host_test.go @@ -11,6 +11,8 @@ import ( "strconv" "strings" "testing" + + "github.com/google/go-cmp/cmp" ) func TestHostServiceURL(t *testing.T) { @@ -69,6 +71,242 @@ func TestHostServiceURL(t *testing.T) { } } +func TestHostServiceOAuthClient(t *testing.T) { + baseURL, _ := url.Parse("https://example.com/disco/foo.json") + host := Host{ + discoURL: baseURL, + hostname: "test-server", + services: map[string]interface{}{ + "explicitgranttype.v1": map[string]interface{}{ + "client": "explicitgranttype", + "authz": "./authz", + "token": "./token", + "grant_types": []interface{}{"authz_code", "password", "tbd"}, + }, + "customports.v1": map[string]interface{}{ + "client": "customports", + "authz": "./authz", + "token": "./token", + "ports": []interface{}{1025, 1026}, + }, + "invalidports.v1": map[string]interface{}{ + "client": "invalidports", + "authz": "./authz", + "token": "./token", + "ports": []interface{}{1, 65535}, + }, + "missingauthz.v1": map[string]interface{}{ + "client": "missingauthz", + "token": "./token", + }, + "missingtoken.v1": map[string]interface{}{ + "client": "missingtoken", + "authz": "./authz", + }, + "passwordmissingauthz.v1": map[string]interface{}{ + "client": "passwordmissingauthz", + "token": "./token", + "grant_types": []interface{}{"password"}, + }, + "absolute.v1": map[string]interface{}{ + "client": "absolute", + "authz": "http://example.net/foo/authz", + "token": "http://example.net/foo/token", + }, + "absolutewithport.v1": map[string]interface{}{ + "client": "absolutewithport", + "authz": "http://example.net:8000/foo/authz", + "token": "http://example.net:8000/foo/token", + }, + "relative.v1": map[string]interface{}{ + "client": "relative", + "authz": "./authz", + "token": "./token", + }, + "rootrelative.v1": map[string]interface{}{ + "client": "rootrelative", + "authz": "/authz", + "token": "/token", + }, + "protorelative.v1": map[string]interface{}{ + "client": "protorelative", + "authz": "//example.net/authz", + "token": "//example.net/token", + }, + "nothttp.v1": map[string]interface{}{ + "client": "nothttp", + "authz": "ftp://127.0.0.1/pub/authz", + "token": "ftp://127.0.0.1/pub/token", + }, + "invalidauthz.v1": map[string]interface{}{ + "client": "invalidauthz", + "authz": "***not A URL at all!:/<@@@@>***", + "token": "/foo", + }, + "invalidtoken.v1": map[string]interface{}{ + "client": "invalidauthz", + "authz": "/foo", + "token": "***not A URL at all!:/<@@@@>***", + }, + }, + } + + mustURL := func(t *testing.T, s string) *url.URL { + t.Helper() + u, err := url.Parse(s) + if err != nil { + t.Fatalf("invalid wanted URL %s in test case: %s", s, err) + } + return u + } + + tests := []struct { + ID string + want *OAuthClient + err string + }{ + { + "explicitgranttype.v1", + &OAuthClient{ + ID: "explicitgranttype", + AuthorizationURL: mustURL(t, "https://example.com/disco/authz"), + TokenURL: mustURL(t, "https://example.com/disco/token"), + MinPort: 1024, + MaxPort: 65535, + SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code", "password", "tbd"), + }, + "", + }, + { + "customports.v1", + &OAuthClient{ + ID: "customports", + AuthorizationURL: mustURL(t, "https://example.com/disco/authz"), + TokenURL: mustURL(t, "https://example.com/disco/token"), + MinPort: 1025, + MaxPort: 1026, + SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"), + }, + "", + }, + { + "invalidports.v1", + nil, + `Invalid "ports" definition for service invalidports.v1: both ports must be whole numbers between 1024 and 65535`, + }, + { + "missingauthz.v1", + nil, + `Service missingauthz.v1 definition is missing required property "authz"`, + }, + { + "missingtoken.v1", + nil, + `Service missingtoken.v1 definition is missing required property "token"`, + }, + { + "passwordmissingauthz.v1", + &OAuthClient{ + ID: "passwordmissingauthz", + TokenURL: mustURL(t, "https://example.com/disco/token"), + MinPort: 1024, + MaxPort: 65535, + SupportedGrantTypes: NewOAuthGrantTypeSet("password"), + }, + "", + }, + { + "absolute.v1", + &OAuthClient{ + ID: "absolute", + AuthorizationURL: mustURL(t, "http://example.net/foo/authz"), + TokenURL: mustURL(t, "http://example.net/foo/token"), + MinPort: 1024, + MaxPort: 65535, + SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"), + }, + "", + }, + { + "absolutewithport.v1", + &OAuthClient{ + ID: "absolutewithport", + AuthorizationURL: mustURL(t, "http://example.net:8000/foo/authz"), + TokenURL: mustURL(t, "http://example.net:8000/foo/token"), + MinPort: 1024, + MaxPort: 65535, + SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"), + }, + "", + }, + { + "relative.v1", + &OAuthClient{ + ID: "relative", + AuthorizationURL: mustURL(t, "https://example.com/disco/authz"), + TokenURL: mustURL(t, "https://example.com/disco/token"), + MinPort: 1024, + MaxPort: 65535, + SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"), + }, + "", + }, + { + "rootrelative.v1", + &OAuthClient{ + ID: "rootrelative", + AuthorizationURL: mustURL(t, "https://example.com/authz"), + TokenURL: mustURL(t, "https://example.com/token"), + MinPort: 1024, + MaxPort: 65535, + SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"), + }, + "", + }, + { + "protorelative.v1", + &OAuthClient{ + ID: "protorelative", + AuthorizationURL: mustURL(t, "https://example.net/authz"), + TokenURL: mustURL(t, "https://example.net/token"), + MinPort: 1024, + MaxPort: 65535, + SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"), + }, + "", + }, + { + "nothttp.v1", + nil, + "Failed to parse authorization URL: unsupported scheme ftp", + }, + { + "invalidauthz.v1", + nil, + "Failed to parse authorization URL: parse ***not A URL at all!:/<@@@@>***: first path segment in URL cannot contain colon", + }, + { + "invalidtoken.v1", + nil, + "Failed to parse token URL: parse ***not A URL at all!:/<@@@@>***: first path segment in URL cannot contain colon", + }, + } + + for _, test := range tests { + t.Run(test.ID, func(t *testing.T) { + got, err := host.ServiceOAuthClient(test.ID) + if (err != nil || test.err != "") && + (err == nil || !strings.Contains(err.Error(), test.err)) { + t.Fatalf("unexpected service URL error: %s", err) + } + + if diff := cmp.Diff(test.want, got); diff != "" { + t.Errorf("wrong result\n%s", diff) + } + }) + } +} + func TestVersionConstrains(t *testing.T) { baseURL, _ := url.Parse("https://example.com/disco/foo.json") diff --git a/svchost/disco/oauth_client.go b/svchost/disco/oauth_client.go index 77376cde7..0dc8a6d09 100644 --- a/svchost/disco/oauth_client.go +++ b/svchost/disco/oauth_client.go @@ -1,7 +1,9 @@ package disco import ( + "fmt" "net/url" + "strings" "golang.org/x/oauth2" ) @@ -16,10 +18,16 @@ type OAuthClient struct { // Authorization URL is the URL of the authorization endpoint that must // be used for this OAuth client, as defined in the OAuth2 specifications. + // + // Not all grant types use the authorization endpoint, so it may be omitted + // if none of the grant types in SupportedGrantTypes require it. AuthorizationURL *url.URL // Token URL is the URL of the token endpoint that must be used for this // OAuth client, as defined in the OAuth2 specifications. + // + // Not all grant types use the token endpoint, so it may be omitted + // if none of the grant types in SupportedGrantTypes require it. TokenURL *url.URL // MinPort and MaxPort define a range of TCP ports on localhost that this @@ -32,6 +40,12 @@ type OAuthClient struct { // to respect the common convention (enforced on some operating systems) // that lower port numbers are reserved for "privileged" services. MinPort, MaxPort uint16 + + // SupportedGrantTypes is a set of the grant types that the client may + // choose from. This includes an entry for each distinct type advertised + // by the server, even if a particular keyword is not supported by the + // current version of Terraform. + SupportedGrantTypes OAuthGrantTypeSet } // Endpoint returns an oauth2.Endpoint value ready to be used with the oauth2 @@ -47,3 +61,112 @@ func (c *OAuthClient) Endpoint() oauth2.Endpoint { AuthStyle: oauth2.AuthStyleInParams, } } + +// OAuthGrantType is an enumeration of grant type strings that a host can +// advertise support for. +// +// Values of this type don't necessarily match with a known constant of the +// type, because they may represent grant type keywords defined in a later +// version of Terraform which this version doesn't yet know about. +type OAuthGrantType string + +const ( + // OAuthAuthzCodeGrant represents an authorization code grant, as + // defined in IETF RFC 6749 section 4.1. + OAuthAuthzCodeGrant = OAuthGrantType("authz_code") + + // OAuthOwnerPasswordGrant represents a resource owner password + // credentials grant, as defined in IETF RFC 6749 section 4.3. + OAuthOwnerPasswordGrant = OAuthGrantType("password") +) + +// UsesAuthorizationEndpoint returns true if the receiving grant type makes +// use of the authorization endpoint from the client configuration, and thus +// if the authorization endpoint ought to be required. +func (t OAuthGrantType) UsesAuthorizationEndpoint() bool { + switch t { + case OAuthAuthzCodeGrant: + return true + case OAuthOwnerPasswordGrant: + return false + default: + // We'll default to false so that we don't impose any requirements + // on any grant type keywords that might be defined for future + // versions of Terraform. + return false + } +} + +// UsesTokenEndpoint returns true if the receiving grant type makes +// use of the token endpoint from the client configuration, and thus +// if the authorization endpoint ought to be required. +func (t OAuthGrantType) UsesTokenEndpoint() bool { + switch t { + case OAuthAuthzCodeGrant: + return true + case OAuthOwnerPasswordGrant: + return true + default: + // We'll default to false so that we don't impose any requirements + // on any grant type keywords that might be defined for future + // versions of Terraform. + return false + } +} + +// OAuthGrantTypeSet represents a set of OAuthGrantType values. +type OAuthGrantTypeSet map[OAuthGrantType]struct{} + +// NewOAuthGrantTypeSet constructs a new grant type set from the given list +// of grant type keyword strings. Any duplicates in the list are ignored. +func NewOAuthGrantTypeSet(keywords ...string) OAuthGrantTypeSet { + ret := make(OAuthGrantTypeSet, len(keywords)) + for _, kw := range keywords { + ret[OAuthGrantType(kw)] = struct{}{} + } + return ret +} + +// Has returns true if the given grant type is in the receiving set. +func (s OAuthGrantTypeSet) Has(t OAuthGrantType) bool { + _, ok := s[t] + return ok +} + +// RequiresAuthorizationEndpoint returns true if any of the grant types in +// the set are known to require an authorization endpoint. +func (s OAuthGrantTypeSet) RequiresAuthorizationEndpoint() bool { + for t := range s { + if t.UsesAuthorizationEndpoint() { + return true + } + } + return false +} + +// RequiresTokenEndpoint returns true if any of the grant types in +// the set are known to require a token endpoint. +func (s OAuthGrantTypeSet) RequiresTokenEndpoint() bool { + for t := range s { + if t.UsesTokenEndpoint() { + return true + } + } + return false +} + +// GoString implements fmt.GoStringer. +func (s OAuthGrantTypeSet) GoString() string { + var buf strings.Builder + i := 0 + buf.WriteString("disco.NewOAuthGrantTypeSet(") + for t := range s { + if i > 0 { + buf.WriteString(", ") + } + fmt.Fprintf(&buf, "%q", string(t)) + i++ + } + buf.WriteString(")") + return buf.String() +}