Add userinfo endpoint
Co-authored-by: Yuxing Li <360983+jackielii@users.noreply.github.com> Co-authored-by: Francisco Santiago <1737357+fjbsantiago@users.noreply.github.com>
This commit is contained in:
parent
49e59fb54f
commit
a8d059a237
3 changed files with 108 additions and 2 deletions
|
@ -3,6 +3,7 @@ package server
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
@ -22,6 +23,10 @@ import (
|
|||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
var (
|
||||
errTokenExpired = errors.New("token has expired")
|
||||
)
|
||||
|
||||
// newHealthChecker returns the healthz handler. The handler runs until the
|
||||
// provided context is canceled.
|
||||
func (s *Server) newHealthChecker(ctx context.Context) http.Handler {
|
||||
|
@ -151,6 +156,7 @@ type discovery struct {
|
|||
Auth string `json:"authorization_endpoint"`
|
||||
Token string `json:"token_endpoint"`
|
||||
Keys string `json:"jwks_uri"`
|
||||
UserInfo string `json:"userinfo_endpoint"`
|
||||
ResponseTypes []string `json:"response_types_supported"`
|
||||
Subjects []string `json:"subject_types_supported"`
|
||||
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
|
||||
|
@ -165,6 +171,7 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
|
|||
Auth: s.absURL("/auth"),
|
||||
Token: s.absURL("/token"),
|
||||
Keys: s.absURL("/keys"),
|
||||
Keys: s.absURL("/userinfo"),
|
||||
Subjects: []string{"public"},
|
||||
IDTokenAlgs: []string{string(jose.RS256)},
|
||||
Scopes: []string{"openid", "email", "groups", "profile", "offline_access"},
|
||||
|
@ -559,7 +566,12 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
|
|||
idToken string
|
||||
idTokenExpiry time.Time
|
||||
|
||||
accessToken = storage.NewID()
|
||||
i accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to create new access token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
)
|
||||
|
||||
for _, responseType := range authReq.ResponseTypes {
|
||||
|
@ -965,7 +977,13 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||
Groups: ident.Groups,
|
||||
}
|
||||
|
||||
accessToken := storage.NewID()
|
||||
accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID)
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to create new access token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, refresh.ConnectorID)
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to create ID token: %v", err)
|
||||
|
@ -1026,6 +1044,88 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||
s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry)
|
||||
}
|
||||
|
||||
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||
authorization := r.Header.Get("Authorization")
|
||||
parts := strings.Fields(authorization)
|
||||
|
||||
if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
|
||||
msg := "invalid authorization header"
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="dex", error="%s", error_description="%s"`, errInvalidRequest, msg))
|
||||
s.tokenErrHelper(w, errInvalidRequest, msg, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
token := parts[1]
|
||||
|
||||
verified, err := s.verify(token)
|
||||
if err != nil {
|
||||
if err == errTokenExpired {
|
||||
s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
s.tokenErrHelper(w, errInvalidRequest, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(verified)
|
||||
}
|
||||
|
||||
func (s *Server) verify(token string) ([]byte, error) {
|
||||
keys, err := s.storage.GetKeys()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get keys: %v", err)
|
||||
}
|
||||
|
||||
if keys.SigningKey == nil {
|
||||
return nil, fmt.Errorf("no private keys found")
|
||||
}
|
||||
|
||||
object, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse signed message")
|
||||
}
|
||||
|
||||
// Parse the message to check expiry, as it jose doesn't distinguish expiry error from others
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("square/go-jose: compact JWS format must have three parts")
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: check other claims
|
||||
var tokenInfo struct {
|
||||
Expiry int64 `json:"exp"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(payload, &tokenInfo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tokenInfo.Expiry < s.now().Unix() {
|
||||
return nil, errTokenExpired
|
||||
}
|
||||
|
||||
var allKeys []*jose.JSONWebKey
|
||||
|
||||
allKeys = append(allKeys, keys.SigningKeyPub)
|
||||
for _, key := range keys.VerificationKeys {
|
||||
allKeys = append(allKeys, key.PublicKey)
|
||||
}
|
||||
|
||||
for _, pubKey := range allKeys {
|
||||
verified, err := object.Verify(pubKey)
|
||||
if err == nil {
|
||||
return verified, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("unable to verify jwt")
|
||||
}
|
||||
|
||||
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) {
|
||||
// TODO(ericchiang): figure out an access token story and support the user info
|
||||
// endpoint. For now use a random value so no one depends on the access_token
|
||||
|
|
|
@ -265,6 +265,11 @@ type federatedIDClaims struct {
|
|||
UserID string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Server) newAccessToken(clientID string, claims storage.Claims, scopes []string, nonce, connID string) (accessToken string, err error) {
|
||||
idToken, _, err := s.newIDToken(clientID, claims, scopes, nonce, storage.NewID(), connID)
|
||||
return idToken, err
|
||||
}
|
||||
|
||||
func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce, accessToken, connID string) (idToken string, expiry time.Time, err error) {
|
||||
keys, err := s.storage.GetKeys()
|
||||
if err != nil {
|
||||
|
|
|
@ -270,6 +270,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
|||
// TODO(ericchiang): rate limit certain paths based on IP.
|
||||
handleWithCORS("/token", s.handleToken)
|
||||
handleWithCORS("/keys", s.handlePublicKeys)
|
||||
handleWithCORS("/userinfo", s.handleUserInfo)
|
||||
handleFunc("/auth", s.handleAuthorization)
|
||||
handleFunc("/auth/{connector}", s.handleConnectorLogin)
|
||||
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
Reference in a new issue