From 13be146d2a27a87091d9e36b0c9a89cf7049c839 Mon Sep 17 00:00:00 2001 From: Zach Brown Date: Tue, 2 Jan 2018 22:15:01 -0500 Subject: [PATCH 1/2] Add support for password grant #926 --- cmd/dex/config.go | 2 + cmd/dex/serve.go | 4 + examples/config-dev.yaml | 2 + server/handlers.go | 247 +++++++++++++++++++++++++++++++++++++++ server/oauth2.go | 1 + server/server.go | 6 + 6 files changed, 262 insertions(+) diff --git a/cmd/dex/config.go b/cmd/dex/config.go index 2519f6f5..3162aaa5 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -129,6 +129,8 @@ type OAuth2 struct { SkipApprovalScreen bool `json:"skipApprovalScreen"` // If specified, show the connector selection screen even if there's only one AlwaysShowLoginScreen bool `json:"alwaysShowLoginScreen"` + // This is the connector that can be used for password grant + PasswordConnector string `json:"passwordConnector"` } // Web is the config format for the HTTP server. diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 293d3e66..86f02c78 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -201,6 +201,9 @@ func serve(cmd *cobra.Command, args []string) error { if c.OAuth2.SkipApprovalScreen { logger.Infof("config skipping approval screen") } + if c.OAuth2.PasswordConnector != "" { + logger.Infof("config using password grant connector: %s", c.OAuth2.PasswordConnector) + } if len(c.Web.AllowedOrigins) > 0 { logger.Infof("config allowed origins: %s", c.Web.AllowedOrigins) } @@ -212,6 +215,7 @@ func serve(cmd *cobra.Command, args []string) error { SupportedResponseTypes: c.OAuth2.ResponseTypes, SkipApprovalScreen: c.OAuth2.SkipApprovalScreen, AlwaysShowLoginScreen: c.OAuth2.AlwaysShowLoginScreen, + PasswordConnector: c.OAuth2.PasswordConnector, AllowedOrigins: c.Web.AllowedOrigins, Issuer: c.Issuer, Storage: s, diff --git a/examples/config-dev.yaml b/examples/config-dev.yaml index 111a0224..099624f8 100644 --- a/examples/config-dev.yaml +++ b/examples/config-dev.yaml @@ -53,6 +53,8 @@ telemetry: # go directly to it. For connected IdPs, this redirects the browser away # from application to upstream provider such as the Google login page # alwaysShowLoginScreen: false + # Uncommend the passwordConnector to use a specific connector for password grants +# passwordConnector: local # Instead of reading from an external storage, use this list of clients. # diff --git a/server/handlers.go b/server/handlers.go index 24d39999..0fe8084d 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -756,6 +756,8 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { s.handleAuthCode(w, r, client) case grantTypeRefreshToken: s.handleRefreshToken(w, r, client) + case grantTypePassword: + s.handlePasswordGrant(w, r, client) default: s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest) } @@ -1150,6 +1152,251 @@ func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) { w.Write(claims) } +func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, client storage.Client) { + + // Parse the fields + if err := r.ParseForm(); err != nil { + s.tokenErrHelper(w, errInvalidRequest, "Couldn't parse data", http.StatusBadRequest) + return + } + q := r.Form + + // Get the clientID and secret from basic auth or form variables + clientID, clientSecret, ok := r.BasicAuth() + if ok { + var err error + if clientID, err = url.QueryUnescape(clientID); err != nil { + s.tokenErrHelper(w, errInvalidRequest, "client_id improperly encoded", http.StatusBadRequest) + return + } + if clientSecret, err = url.QueryUnescape(clientSecret); err != nil { + s.tokenErrHelper(w, errInvalidRequest, "client_secret improperly encoded", http.StatusBadRequest) + return + } + } else { + clientID = q.Get("client_id") + clientSecret = q.Get("client_secret") + } + + nonce := q.Get("nonce") + // Some clients, like the old go-oidc, provide extra whitespace. Tolerate this. + scopes := strings.Fields(q.Get("scope")) + + // Get the client from the database + client, err := s.storage.GetClient(clientID) + if err != nil { + if err == storage.ErrNotFound { + s.tokenErrHelper(w, errInvalidClient, fmt.Sprintf("Invalid client_id (%q).", clientID), http.StatusBadRequest) + return + } + s.tokenErrHelper(w, errInvalidClient, fmt.Sprintf("Failed to get client %v.", err), http.StatusBadRequest) + return + } + + // Parse the scopes if they are passed + var ( + unrecognized []string + invalidScopes []string + ) + hasOpenIDScope := false + for _, scope := range scopes { + switch scope { + case scopeOpenID: + hasOpenIDScope = true + case scopeOfflineAccess, scopeEmail, scopeProfile, scopeGroups, scopeFederatedID: + default: + peerID, ok := parseCrossClientScope(scope) + if !ok { + unrecognized = append(unrecognized, scope) + continue + } + + isTrusted, err := s.validateCrossClientTrust(clientID, peerID) + if err != nil { + s.tokenErrHelper(w, errInvalidClient, fmt.Sprintf("Error validating cross client trust %v.", err), http.StatusBadRequest) + return + } + if !isTrusted { + invalidScopes = append(invalidScopes, scope) + } + } + } + if !hasOpenIDScope { + s.tokenErrHelper(w, errInvalidRequest, `Missing required scope(s) ["openid"].`, http.StatusBadRequest) + return + } + if len(unrecognized) > 0 { + s.tokenErrHelper(w, errInvalidRequest, fmt.Sprintf("Unrecognized scope(s) %q", unrecognized), http.StatusBadRequest) + return + } + if len(invalidScopes) > 0 { + s.tokenErrHelper(w, errInvalidRequest, fmt.Sprintf("Client can't request scope(s) %q", invalidScopes), http.StatusBadRequest) + return + } + + // Which connector + connID := s.passwordConnector + conn, err := s.getConnector(connID) + if err != nil { + s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest) + return + } + + passwordConnector, ok := conn.Connector.(connector.PasswordConnector) + if !ok { + s.tokenErrHelper(w, errInvalidRequest, "Requested password connector does not correct type.", http.StatusBadRequest) + return + } + + // Login + username := q.Get("username") + password := q.Get("password") + identity, ok, err := passwordConnector.Login(r.Context(), parseScopes(scopes), username, password) + if err != nil { + s.tokenErrHelper(w, errInvalidRequest, "Could not login user", http.StatusBadRequest) + return + } + if !ok { + s.tokenErrHelper(w, errAccessDenied, "Invalid username or password", http.StatusUnauthorized) + return + } + + // Build the claims to send the id token + claims := storage.Claims{ + UserID: identity.UserID, + Username: identity.Username, + PreferredUsername: identity.PreferredUsername, + Email: identity.Email, + EmailVerified: identity.EmailVerified, + Groups: identity.Groups, + } + + accessToken := storage.NewID() + idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, nonce, accessToken, connID) + if err != nil { + s.tokenErrHelper(w, errServerError, fmt.Sprintf("failed to create ID token: %v", err), http.StatusInternalServerError) + return + } + + reqRefresh := func() bool { + // Ensure the connector supports refresh tokens. + // + // Connectors like `saml` do not implement RefreshConnector. + _, ok := conn.Connector.(connector.RefreshConnector) + if !ok { + return false + } + + for _, scope := range scopes { + if scope == scopeOfflineAccess { + return true + } + } + return false + }() + var refreshToken string + if reqRefresh { + refresh := storage.RefreshToken{ + ID: storage.NewID(), + Token: storage.NewID(), + ClientID: clientID, + ConnectorID: connID, + Scopes: scopes, + Claims: claims, + Nonce: nonce, + // ConnectorData: authCode.ConnectorData, + CreatedAt: s.now(), + LastUsed: s.now(), + } + token := &internal.RefreshToken{ + RefreshId: refresh.ID, + Token: refresh.Token, + } + if refreshToken, err = internal.Marshal(token); err != nil { + s.logger.Errorf("failed to marshal refresh token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + + if err := s.storage.CreateRefresh(refresh); err != nil { + s.logger.Errorf("failed to create refresh token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + + // deleteToken determines if we need to delete the newly created refresh token + // due to a failure in updating/creating the OfflineSession object for the + // corresponding user. + var deleteToken bool + defer func() { + if deleteToken { + // Delete newly created refresh token from storage. + if err := s.storage.DeleteRefresh(refresh.ID); err != nil { + s.logger.Errorf("failed to delete refresh token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + } + }() + + tokenRef := storage.RefreshTokenRef{ + ID: refresh.ID, + ClientID: refresh.ClientID, + CreatedAt: refresh.CreatedAt, + LastUsed: refresh.LastUsed, + } + + // Try to retrieve an existing OfflineSession object for the corresponding user. + if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil { + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get offline session: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + deleteToken = true + return + } + offlineSessions := storage.OfflineSessions{ + UserID: refresh.Claims.UserID, + ConnID: refresh.ConnectorID, + Refresh: make(map[string]*storage.RefreshTokenRef), + } + offlineSessions.Refresh[tokenRef.ClientID] = &tokenRef + + // Create a new OfflineSession object for the user and add a reference object for + // the newly received refreshtoken. + if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil { + s.logger.Errorf("failed to create offline session: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + deleteToken = true + return + } + } else { + if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok { + // Delete old refresh token from storage. + if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil { + s.logger.Errorf("failed to delete refresh token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + deleteToken = true + return + } + } + + // Update existing OfflineSession obj with new RefreshTokenRef. + if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + old.Refresh[tokenRef.ClientID] = &tokenRef + return old, nil + }); err != nil { + s.logger.Errorf("failed to update offline session: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + deleteToken = true + return + } + + } + } + + s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry) +} + func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) { resp := struct { AccessToken string `json:"access_token"` diff --git a/server/oauth2.go b/server/oauth2.go index 0cd26814..ecc619d6 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -121,6 +121,7 @@ const ( const ( grantTypeAuthorizationCode = "authorization_code" grantTypeRefreshToken = "refresh_token" + grantTypePassword = "password" ) const ( diff --git a/server/server.go b/server/server.go index 27d93064..369d60f9 100644 --- a/server/server.go +++ b/server/server.go @@ -76,6 +76,8 @@ type Config struct { RotateKeysAfter time.Duration // Defaults to 6 hours. IDTokensValidFor time.Duration // Defaults to 24 hours AuthRequestsValidFor time.Duration // Defaults to 24 hours + // If set, the server will use this connector to handle password grants + PasswordConnector string GCFrequency time.Duration // Defaults to 5 minutes @@ -145,6 +147,9 @@ type Server struct { // If enabled, show the connector selection screen even if there's only one alwaysShowLogin bool + // Used for password grant + passwordConnector string + supportedResponseTypes map[string]bool now func() time.Time @@ -216,6 +221,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) alwaysShowLogin: c.AlwaysShowLoginScreen, now: now, templates: tmpls, + passwordConnector: c.PasswordConnector, logger: c.Logger, } From 0f9a74f1d055a4780d5be1f93a2e1dba905f0e9c Mon Sep 17 00:00:00 2001 From: Rui Yang Date: Fri, 10 Jan 2020 14:39:08 -0500 Subject: [PATCH 2/2] Remove uneccesary client verification --- server/handlers.go | 34 ++-------------------------------- 1 file changed, 2 insertions(+), 32 deletions(-) diff --git a/server/handlers.go b/server/handlers.go index 0fe8084d..694ababb 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -1153,7 +1153,6 @@ func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) { } func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, client storage.Client) { - // Parse the fields if err := r.ParseForm(); err != nil { s.tokenErrHelper(w, errInvalidRequest, "Couldn't parse data", http.StatusBadRequest) @@ -1161,38 +1160,10 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli } q := r.Form - // Get the clientID and secret from basic auth or form variables - clientID, clientSecret, ok := r.BasicAuth() - if ok { - var err error - if clientID, err = url.QueryUnescape(clientID); err != nil { - s.tokenErrHelper(w, errInvalidRequest, "client_id improperly encoded", http.StatusBadRequest) - return - } - if clientSecret, err = url.QueryUnescape(clientSecret); err != nil { - s.tokenErrHelper(w, errInvalidRequest, "client_secret improperly encoded", http.StatusBadRequest) - return - } - } else { - clientID = q.Get("client_id") - clientSecret = q.Get("client_secret") - } - nonce := q.Get("nonce") // Some clients, like the old go-oidc, provide extra whitespace. Tolerate this. scopes := strings.Fields(q.Get("scope")) - // Get the client from the database - client, err := s.storage.GetClient(clientID) - if err != nil { - if err == storage.ErrNotFound { - s.tokenErrHelper(w, errInvalidClient, fmt.Sprintf("Invalid client_id (%q).", clientID), http.StatusBadRequest) - return - } - s.tokenErrHelper(w, errInvalidClient, fmt.Sprintf("Failed to get client %v.", err), http.StatusBadRequest) - return - } - // Parse the scopes if they are passed var ( unrecognized []string @@ -1211,7 +1182,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli continue } - isTrusted, err := s.validateCrossClientTrust(clientID, peerID) + isTrusted, err := s.validateCrossClientTrust(client.ID, peerID) if err != nil { s.tokenErrHelper(w, errInvalidClient, fmt.Sprintf("Error validating cross client trust %v.", err), http.StatusBadRequest) return @@ -1299,7 +1270,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli refresh := storage.RefreshToken{ ID: storage.NewID(), Token: storage.NewID(), - ClientID: clientID, + ClientID: client.ID, ConnectorID: connID, Scopes: scopes, Claims: claims, @@ -1390,7 +1361,6 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli deleteToken = true return } - } }