From a8d059a237d6eb18bf9ac0577a1f486123228837 Mon Sep 17 00:00:00 2001 From: Maarten den Braber Date: Mon, 27 May 2019 09:17:39 +0200 Subject: [PATCH] Add userinfo endpoint Co-authored-by: Yuxing Li <360983+jackielii@users.noreply.github.com> Co-authored-by: Francisco Santiago <1737357+fjbsantiago@users.noreply.github.com> --- server/handlers.go | 104 ++++++++++++++++++++++++++++++++++++++++++++- server/oauth2.go | 5 +++ server/server.go | 1 + 3 files changed, 108 insertions(+), 2 deletions(-) diff --git a/server/handlers.go b/server/handlers.go index ae6a21db..058ea3ab 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -3,6 +3,7 @@ package server import ( "context" "encoding/json" + "encoding/base64" "errors" "fmt" "net/http" @@ -22,6 +23,10 @@ import ( "github.com/dexidp/dex/storage" ) +var ( + errTokenExpired = errors.New("token has expired") +) + // newHealthChecker returns the healthz handler. The handler runs until the // provided context is canceled. func (s *Server) newHealthChecker(ctx context.Context) http.Handler { @@ -151,6 +156,7 @@ type discovery struct { Auth string `json:"authorization_endpoint"` Token string `json:"token_endpoint"` Keys string `json:"jwks_uri"` + UserInfo string `json:"userinfo_endpoint"` ResponseTypes []string `json:"response_types_supported"` Subjects []string `json:"subject_types_supported"` IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"` @@ -165,6 +171,7 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) { Auth: s.absURL("/auth"), Token: s.absURL("/token"), Keys: s.absURL("/keys"), + Keys: s.absURL("/userinfo"), Subjects: []string{"public"}, IDTokenAlgs: []string{string(jose.RS256)}, Scopes: []string{"openid", "email", "groups", "profile", "offline_access"}, @@ -559,7 +566,12 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe idToken string idTokenExpiry time.Time - accessToken = storage.NewID() +i accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID) + if err != nil { + s.logger.Errorf("failed to create new access token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } ) for _, responseType := range authReq.ResponseTypes { @@ -965,7 +977,13 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie Groups: ident.Groups, } - accessToken := storage.NewID() + accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID) + if err != nil { + s.logger.Errorf("failed to create new access token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, refresh.ConnectorID) if err != nil { s.logger.Errorf("failed to create ID token: %v", err) @@ -1026,6 +1044,88 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry) } +func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) { + authorization := r.Header.Get("Authorization") + parts := strings.Fields(authorization) + + if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") { + msg := "invalid authorization header" + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="dex", error="%s", error_description="%s"`, errInvalidRequest, msg)) + s.tokenErrHelper(w, errInvalidRequest, msg, http.StatusBadRequest) + return + } + + token := parts[1] + + verified, err := s.verify(token) + if err != nil { + if err == errTokenExpired { + s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusUnauthorized) + return + } + s.tokenErrHelper(w, errInvalidRequest, err.Error(), http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Write(verified) + } + +func (s *Server) verify(token string) ([]byte, error) { + keys, err := s.storage.GetKeys() + if err != nil { + return nil, fmt.Errorf("failed to get keys: %v", err) + } + + if keys.SigningKey == nil { + return nil, fmt.Errorf("no private keys found") + } + + object, err := jose.ParseSigned(token) + if err != nil { + return nil, fmt.Errorf("unable to parse signed message") + } + + // Parse the message to check expiry, as it jose doesn't distinguish expiry error from others + parts := strings.Split(token, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("square/go-jose: compact JWS format must have three parts") + } + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, err + } + + // TODO: check other claims + var tokenInfo struct { + Expiry int64 `json:"exp"` + } + + if err := json.Unmarshal(payload, &tokenInfo); err != nil { + return nil, err + } + + if tokenInfo.Expiry < s.now().Unix() { + return nil, errTokenExpired + } + + var allKeys []*jose.JSONWebKey + + allKeys = append(allKeys, keys.SigningKeyPub) + for _, key := range keys.VerificationKeys { + allKeys = append(allKeys, key.PublicKey) + } + + for _, pubKey := range allKeys { + verified, err := object.Verify(pubKey) + if err == nil { + return verified, nil + } + } + return nil, errors.New("unable to verify jwt") +} + func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) { // TODO(ericchiang): figure out an access token story and support the user info // endpoint. For now use a random value so no one depends on the access_token diff --git a/server/oauth2.go b/server/oauth2.go index 8c9494f5..26d152f4 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -265,6 +265,11 @@ type federatedIDClaims struct { UserID string `json:"user_id,omitempty"` } +func (s *Server) newAccessToken(clientID string, claims storage.Claims, scopes []string, nonce, connID string) (accessToken string, err error) { + idToken, _, err := s.newIDToken(clientID, claims, scopes, nonce, storage.NewID(), connID) + return idToken, err +} + func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce, accessToken, connID string) (idToken string, expiry time.Time, err error) { keys, err := s.storage.GetKeys() if err != nil { diff --git a/server/server.go b/server/server.go index 9dc259fb..69b4d0d7 100644 --- a/server/server.go +++ b/server/server.go @@ -270,6 +270,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) // TODO(ericchiang): rate limit certain paths based on IP. handleWithCORS("/token", s.handleToken) handleWithCORS("/keys", s.handlePublicKeys) + handleWithCORS("/userinfo", s.handleUserInfo) handleFunc("/auth", s.handleAuthorization) handleFunc("/auth/{connector}", s.handleConnectorLogin) r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {