diff --git a/server/cross_client_test.go b/server/cross_client_test.go new file mode 100644 index 00000000..bb1a9771 --- /dev/null +++ b/server/cross_client_test.go @@ -0,0 +1,202 @@ +package server + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/coreos/go-oidc/oidc" + + "github.com/coreos/dex/client" + "github.com/coreos/dex/connector" +) + +func makeCrossClientTestFixtures() (*testFixtures, error) { + f, err := makeTestFixtures() + if err != nil { + return nil, fmt.Errorf("couldn't make test fixtures: %v", err) + } + + creds := map[string]oidc.ClientCredentials{} + for _, cliData := range []struct { + id string + authorized []string + }{ + { + id: "client_a", + }, { + id: "client_b", + authorized: []string{"client_a"}, + }, { + id: "client_c", + authorized: []string{"client_a", "client_b"}, + }, + } { + u := url.URL{ + Scheme: "https://", + Path: cliData.id, + Host: "auth.example.com", + } + cliCreds, err := f.clientRepo.New(client.Client{ + Credentials: oidc.ClientCredentials{ + ID: cliData.id, + }, + Metadata: oidc.ClientMetadata{ + RedirectURIs: []url.URL{u}, + }, + }) + if err != nil { + return nil, fmt.Errorf("Unexpected error creating clients: %v", err) + } + creds[cliData.id] = *cliCreds + err = f.clientRepo.SetTrustedPeers(cliData.id, cliData.authorized) + if err != nil { + return nil, fmt.Errorf("Unexpected error setting cross-client authorizers: %v", err) + } + } + return f, nil +} + +func TestServerCrossClientAuthAllowed(t *testing.T) { + f, err := makeCrossClientTestFixtures() + if err != nil { + t.Fatalf("couldn't make test fixtures: %v", err) + } + + tests := []struct { + reqClient string + authClient string + wantAuthorized bool + wantErr bool + }{ + { + reqClient: "client_b", + authClient: "client_a", + wantAuthorized: false, + wantErr: false, + }, + { + reqClient: "client_a", + authClient: "client_b", + wantAuthorized: true, + wantErr: false, + }, + { + reqClient: "client_a", + authClient: "client_c", + wantAuthorized: true, + wantErr: false, + }, + { + reqClient: "client_c", + authClient: "client_b", + wantAuthorized: false, + wantErr: false, + }, + { + reqClient: "client_c", + authClient: "nope", + wantErr: false, + }, + } + for i, tt := range tests { + got, err := f.srv.CrossClientAuthAllowed(tt.reqClient, tt.authClient) + if tt.wantErr { + if err == nil { + t.Errorf("case %d: want non-nil err", i) + } + continue + } + if err != nil { + t.Errorf("case %d: unexpected err %v: ", i, err) + } + + if got != tt.wantAuthorized { + t.Errorf("case %d: want=%v, got=%v", i, tt.wantAuthorized, got) + } + } +} + +func TestHandleAuthCrossClient(t *testing.T) { + f, err := makeCrossClientTestFixtures() + if err != nil { + t.Fatalf("couldn't make test fixtures: %v", err) + } + + tests := []struct { + scopes []string + clientID string + wantCode int + }{ + { + scopes: []string{ScopeGoogleCrossClient + "client_a"}, + clientID: "client_b", + wantCode: http.StatusBadRequest, + }, + { + scopes: []string{ScopeGoogleCrossClient + "client_b"}, + clientID: "client_a", + wantCode: http.StatusFound, + }, + { + scopes: []string{ScopeGoogleCrossClient + "client_b"}, + clientID: "client_a", + wantCode: http.StatusFound, + }, + { + scopes: []string{ScopeGoogleCrossClient + "client_c"}, + clientID: "client_a", + wantCode: http.StatusFound, + }, + { + // Two clients that client_a is authorized to mint tokens for. + scopes: []string{ + ScopeGoogleCrossClient + "client_c", + ScopeGoogleCrossClient + "client_b", + }, + clientID: "client_a", + wantCode: http.StatusFound, + }, + { + // Two clients that client_a is authorized to mint tokens for. + scopes: []string{ + ScopeGoogleCrossClient + "client_c", + ScopeGoogleCrossClient + "client_a", + }, + clientID: "client_b", + wantCode: http.StatusBadRequest, + }, + } + + idpcs := []connector.Connector{ + &fakeConnector{loginURL: "http://fake.example.com"}, + } + + for i, tt := range tests { + hdlr := handleAuthFunc(f.srv, idpcs, nil, true) + w := httptest.NewRecorder() + + query := url.Values{ + "response_type": []string{"code"}, + "client_id": []string{tt.clientID}, + "connector_id": []string{"fake"}, + "scope": []string{strings.Join(append([]string{"openid"}, tt.scopes...), " ")}, + } + u := fmt.Sprintf("http://server.example.com?%s", query.Encode()) + req, err := http.NewRequest("GET", u, nil) + if err != nil { + t.Errorf("case %d: unable to form HTTP request: %v", i, err) + continue + } + + hdlr.ServeHTTP(w, req) + if tt.wantCode != w.Code { + t.Errorf("case %d: HTTP code mismatch: want=%d got=%d", i, tt.wantCode, w.Code) + continue + } + } + +} diff --git a/server/http.go b/server/http.go index a9d2bc46..2753806a 100644 --- a/server/http.go +++ b/server/http.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/url" + "sort" "strings" "time" @@ -263,7 +264,7 @@ func renderLoginPage(w http.ResponseWriter, r *http.Request, srv OIDCServer, idp execTemplate(w, tpl, td) } -func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.Template, registrationEnabled bool) http.HandlerFunc { +func handleAuthFunc(srv DexServer, idpcs []connector.Connector, tpl *template.Template, registrationEnabled bool) http.HandlerFunc { idx := makeConnectorMap(idpcs) return func(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { @@ -341,30 +342,9 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T } // Check scopes. - var scopes []string - foundOpenIDScope := false - for _, scope := range acr.Scope { - switch scope { - case "openid": - foundOpenIDScope = true - scopes = append(scopes, scope) - case "offline_access": - // According to the spec, for offline_access scope, the client must - // use a response_type value that would result in an Authorization Code. - // Currently oauth2.ResponseTypeCode is the only supported response type, - // and it's been checked above, so we don't need to check it again here. - // - // TODO(yifan): Verify that 'consent' should be in 'prompt'. - scopes = append(scopes, scope) - default: - // Pass all other scopes. - scopes = append(scopes, scope) - } - } - - if !foundOpenIDScope { - log.Errorf("Invalid auth request: missing 'openid' in 'scope'") - writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State) + if scopeErr := validateScopes(srv, acr.ClientID, acr.Scope); scopeErr != nil { + log.Error(scopeErr) + writeAuthError(w, scopeErr, acr.State) return } @@ -410,6 +390,69 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T } } +func validateScopes(srv DexServer, clientID string, scopes []string) error { + foundOpenIDScope := false + sort.Strings(scopes) + for i, scope := range scopes { + if i > 0 && scope == scopes[i-1] { + err := oauth2.NewError(oauth2.ErrorInvalidRequest) + err.Description = fmt.Sprintf( + "Duplicate scopes are not allowed: %q", + scope) + return err + } + + switch { + case strings.HasPrefix(scope, ScopeGoogleCrossClient): + otherClient := scope[len(ScopeGoogleCrossClient):] + + var allowed bool + var err error + if otherClient == clientID { + allowed = true + } else { + allowed, err = srv.CrossClientAuthAllowed(clientID, otherClient) + if err != nil { + return err + } + } + + if !allowed { + err := oauth2.NewError(oauth2.ErrorInvalidRequest) + err.Description = fmt.Sprintf( + "%q is not authorized to perform cross-client requests for %q", + clientID, otherClient) + return err + } + case scope == "openid": + foundOpenIDScope = true + case scope == "profile": + case scope == "email": + case scope == "offline_access": + // According to the spec, for offline_access scope, the client must + // use a response_type value that would result in an Authorization + // Code. Currently oauth2.ResponseTypeCode is the only supported + // response type, and it's been checked above, so we don't need to + // check it again here. + // + // TODO(yifan): Verify that 'consent' should be in 'prompt'. + default: + // Reject all other scopes. + err := oauth2.NewError(oauth2.ErrorInvalidRequest) + err.Description = fmt.Sprintf("%q is not a recognized scope", scope) + return err + } + } + + if !foundOpenIDScope { + log.Errorf("Invalid auth request: missing 'openid' in 'scope'") + err := oauth2.NewError(oauth2.ErrorInvalidRequest) + err.Description = "Invalid auth request: missing 'openid' in 'scope'" + return err + } + return nil +} + func handleTokenFunc(srv OIDCServer) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { diff --git a/server/http_test.go b/server/http_test.go index 0dda1847..67b9e105 100644 --- a/server/http_test.go +++ b/server/http_test.go @@ -308,8 +308,110 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) { } } -func TestHandleTokenFunc(t *testing.T) { +func TestValidateScopes(t *testing.T) { + f, err := makeCrossClientTestFixtures() + if err != nil { + t.Fatalf("couldn't make test fixtures: %v", err) + } + tests := []struct { + clientID string + scopes []string + wantErr bool + }{ + { + // ERR: no openid scope + clientID: "XXX", + scopes: []string{}, + wantErr: true, + }, + { + // OK: minimum scopes + clientID: "XXX", + scopes: []string{"openid"}, + wantErr: false, + }, + { + // OK: offline_access + clientID: "XXX", + scopes: []string{"openid", "offline_access"}, + wantErr: false, + }, + { + // ERR: unknown scope + clientID: "XXX", + scopes: []string{"openid", "wat"}, + wantErr: true, + }, + { + // ERR: invalid cross client auth + clientID: "XXX", + scopes: []string{"openid", ScopeGoogleCrossClient + "client_a"}, + wantErr: true, + }, + { + // OK: valid cross client auth (though perverse - a client + // requesting cross-client auth for itself) + clientID: "client_a", + scopes: []string{"openid", ScopeGoogleCrossClient + "client_a"}, + wantErr: false, + }, + { + + // OK: valid cross client auth + clientID: "client_a", + scopes: []string{"openid", ScopeGoogleCrossClient + "client_b"}, + wantErr: false, + }, + { + + // ERR: valid cross client auth...but duplicated scope. + clientID: "client_a", + scopes: []string{"openid", + ScopeGoogleCrossClient + "client_b", + ScopeGoogleCrossClient + "client_b", + }, + wantErr: true, + }, + { + // OK: valid cross client auth with >1 clients including itself + clientID: "client_a", + scopes: []string{ + "openid", + ScopeGoogleCrossClient + "client_a", + ScopeGoogleCrossClient + "client_b", + ScopeGoogleCrossClient + "client_c", + }, + wantErr: false, + }, + { + // ERR: valid cross client auth with >1 clients including itself...but no openid! + clientID: "client_a", + scopes: []string{ + ScopeGoogleCrossClient + "client_a", + ScopeGoogleCrossClient + "client_b", + ScopeGoogleCrossClient + "client_c", + }, + wantErr: true, + }, + } + + for i, tt := range tests { + err := validateScopes(f.srv, tt.clientID, tt.scopes) + if tt.wantErr { + if err == nil { + t.Errorf("case %d: want non-nil err", i) + } + continue + } + + if err != nil { + t.Errorf("case %d: unexpected err: %v", i, err) + } + } +} + +func TestHandleTokenFunc(t *testing.T) { fx, err := makeTestFixtures() if err != nil { t.Fatalf("could not run test fixtures: %v", err) diff --git a/server/server.go b/server/server.go index 2b794862..6eaf8b5e 100644 --- a/server/server.go +++ b/server/server.go @@ -39,21 +39,37 @@ const ( ResetPasswordTemplateName = "reset-password.html" APIVersion = "v1" + + // Scope prefix which indicates initiation of a cross-client authentication flow. + // See https://developers.google.com/identity/protocols/CrossClientAuth + ScopeGoogleCrossClient = "audience:server:client_id:" ) type OIDCServer interface { ClientMetadata(string) (*oidc.ClientMetadata, error) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) Login(oidc.Identity, string) (string, error) + // CodeToken exchanges a code for an ID token and a refresh token string on success. CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) + ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) + // RefreshToken takes a previously generated refresh token and returns a new ID token // if the token is valid. RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) + KillSession(string) error } +// DexServer is an OIDCServer that also has dex-specific features. +type DexServer interface { + OIDCServer + + // CrossClientAuthAllowed + CrossClientAuthAllowed(requestingClientID, authorizingClientID string) (bool, error) +} + type JWTVerifierFactory func(clientID string) oidc.JWTVerifier type Server struct { @@ -521,6 +537,19 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose return jwt, nil } +func (s *Server) CrossClientAuthAllowed(requestingClientID, authorizingClientID string) (bool, error) { + alloweds, err := s.ClientRepo.GetTrustedPeers(authorizingClientID) + if err != nil { + return false, err + } + for _, allowed := range alloweds { + if requestingClientID == allowed { + return true, nil + } + } + return false, nil +} + func (s *Server) JWTVerifierFactory() JWTVerifierFactory { noop := func() error { return nil }