forked from mystiq/dex
Use oidc.Verifier to verify tokens
This commit is contained in:
parent
157c359f3e
commit
46f5726d11
4 changed files with 154 additions and 75 deletions
|
@ -2,7 +2,6 @@ package server
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -15,6 +14,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
oidc "github.com/coreos/go-oidc"
|
||||
"github.com/gorilla/mux"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
|
||||
|
@ -23,10 +23,6 @@ import (
|
|||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
var (
|
||||
errTokenExpired = errors.New("token has expired")
|
||||
)
|
||||
|
||||
// newHealthChecker returns the healthz handler. The handler runs until the
|
||||
// provided context is canceled.
|
||||
func (s *Server) newHealthChecker(ctx context.Context) http.Handler {
|
||||
|
@ -1055,84 +1051,31 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||
}
|
||||
|
||||
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||
authorization := r.Header.Get("Authorization")
|
||||
parts := strings.Fields(authorization)
|
||||
const prefix = "Bearer "
|
||||
|
||||
if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
|
||||
msg := "invalid authorization header"
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="dex", error="%s", error_description="%s"`, errInvalidRequest, msg))
|
||||
s.tokenErrHelper(w, errInvalidRequest, msg, http.StatusBadRequest)
|
||||
auth := r.Header.Get("authorization")
|
||||
if len(auth) < len(prefix) || !strings.EqualFold(prefix, auth[:len(prefix)]) {
|
||||
w.Header().Set("WWW-Authenticate", "Bearer")
|
||||
s.tokenErrHelper(w, errAccessDenied, "Invalid bearer token.", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
rawIDToken := auth[len(prefix):]
|
||||
|
||||
verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true})
|
||||
idToken, err := verifier.Verify(r.Context(), rawIDToken)
|
||||
if err != nil {
|
||||
s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
token := parts[1]
|
||||
|
||||
verified, err := s.verify(token)
|
||||
if err != nil {
|
||||
if err == errTokenExpired {
|
||||
s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
s.tokenErrHelper(w, errInvalidRequest, err.Error(), http.StatusBadRequest)
|
||||
var claims json.RawMessage
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
s.tokenErrHelper(w, errServerError, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(verified)
|
||||
}
|
||||
|
||||
func (s *Server) verify(token string) ([]byte, error) {
|
||||
keys, err := s.storage.GetKeys()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get keys: %v", err)
|
||||
}
|
||||
|
||||
if keys.SigningKey == nil {
|
||||
return nil, fmt.Errorf("no private keys found")
|
||||
}
|
||||
|
||||
object, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse signed message")
|
||||
}
|
||||
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("compact JWS format must have three parts")
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: check other claims
|
||||
var tokenInfo struct {
|
||||
Expiry int64 `json:"exp"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(payload, &tokenInfo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tokenInfo.Expiry < s.now().Unix() {
|
||||
return nil, errTokenExpired
|
||||
}
|
||||
|
||||
var allKeys []*jose.JSONWebKey
|
||||
|
||||
allKeys = append(allKeys, keys.SigningKeyPub)
|
||||
for _, key := range keys.VerificationKeys {
|
||||
allKeys = append(allKeys, key.PublicKey)
|
||||
}
|
||||
|
||||
for _, pubKey := range allKeys {
|
||||
verified, err := object.Verify(pubKey)
|
||||
if err == nil {
|
||||
return verified, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("unable to verify jwt")
|
||||
w.Write(claims)
|
||||
}
|
||||
|
||||
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
|
@ -566,3 +567,41 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool {
|
|||
host, _, err := net.SplitHostPort(u.Host)
|
||||
return err == nil && host == "localhost"
|
||||
}
|
||||
|
||||
// storageKeySet implements the oidc.KeySet interface backed by Dex storage
|
||||
type storageKeySet struct {
|
||||
storage.Storage
|
||||
}
|
||||
|
||||
func (s *storageKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) {
|
||||
jws, err := jose.ParseSigned(jwt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keyID := ""
|
||||
for _, sig := range jws.Signatures {
|
||||
keyID = sig.Header.KeyID
|
||||
break
|
||||
}
|
||||
|
||||
skeys, err := s.Storage.GetKeys()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys := []*jose.JSONWebKey{skeys.SigningKeyPub}
|
||||
for _, vk := range skeys.VerificationKeys {
|
||||
keys = append(keys, vk.PublicKey)
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
if keyID == "" || key.KeyID == keyID {
|
||||
if payload, err := jws.Verify(key); err == nil {
|
||||
return payload, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("failed to verify id token signature")
|
||||
}
|
||||
|
|
|
@ -2,6 +2,8 @@ package server
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
|
@ -11,6 +13,7 @@ import (
|
|||
jose "gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/memory"
|
||||
)
|
||||
|
||||
func TestParseAuthorizationRequest(t *testing.T) {
|
||||
|
@ -259,3 +262,87 @@ func TestValidRedirectURI(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorageKeySet(t *testing.T) {
|
||||
s := memory.New(logger)
|
||||
if err := s.UpdateKeys(func(keys storage.Keys) (storage.Keys, error) {
|
||||
keys.SigningKey = &jose.JSONWebKey{
|
||||
Key: testKey,
|
||||
KeyID: "testkey",
|
||||
Algorithm: "RS256",
|
||||
Use: "sig",
|
||||
}
|
||||
keys.SigningKeyPub = &jose.JSONWebKey{
|
||||
Key: testKey.Public(),
|
||||
KeyID: "testkey",
|
||||
Algorithm: "RS256",
|
||||
Use: "sig",
|
||||
}
|
||||
return keys, nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenGenerator func() (jwt string, err error)
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid token",
|
||||
tokenGenerator: func() (string, error) {
|
||||
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: testKey}, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
jws, err := signer.Sign([]byte("payload"))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return jws.CompactSerialize()
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "token signed by different key",
|
||||
tokenGenerator: func() (string, error) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: key}, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
jws, err := signer.Sign([]byte("payload"))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return jws.CompactSerialize()
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
jwt, err := tc.tokenGenerator()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
keySet := &storageKeySet{s}
|
||||
|
||||
_, err = keySet.VerifySignature(context.Background(), jwt)
|
||||
if (err != nil && !tc.wantErr) || (err == nil && tc.wantErr) {
|
||||
t.Fatalf("wantErr = %v, but got err = %v", tc.wantErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -200,6 +200,16 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fetch userinfo",
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
||||
_, err := p.UserInfo(ctx, config.TokenSource(ctx, token))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch userinfo: %v", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "verify id token and oauth2 token expiry",
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
||||
|
|
Loading…
Reference in a new issue