Use oidc.Verifier to verify tokens
This commit is contained in:
parent
157c359f3e
commit
46f5726d11
4 changed files with 154 additions and 75 deletions
|
@ -2,7 +2,6 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -15,6 +14,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
oidc "github.com/coreos/go-oidc"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
jose "gopkg.in/square/go-jose.v2"
|
jose "gopkg.in/square/go-jose.v2"
|
||||||
|
|
||||||
|
@ -23,10 +23,6 @@ import (
|
||||||
"github.com/dexidp/dex/storage"
|
"github.com/dexidp/dex/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
errTokenExpired = errors.New("token has expired")
|
|
||||||
)
|
|
||||||
|
|
||||||
// newHealthChecker returns the healthz handler. The handler runs until the
|
// newHealthChecker returns the healthz handler. The handler runs until the
|
||||||
// provided context is canceled.
|
// provided context is canceled.
|
||||||
func (s *Server) newHealthChecker(ctx context.Context) http.Handler {
|
func (s *Server) newHealthChecker(ctx context.Context) http.Handler {
|
||||||
|
@ -1055,84 +1051,31 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||||
authorization := r.Header.Get("Authorization")
|
const prefix = "Bearer "
|
||||||
parts := strings.Fields(authorization)
|
|
||||||
|
|
||||||
if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
|
auth := r.Header.Get("authorization")
|
||||||
msg := "invalid authorization header"
|
if len(auth) < len(prefix) || !strings.EqualFold(prefix, auth[:len(prefix)]) {
|
||||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="dex", error="%s", error_description="%s"`, errInvalidRequest, msg))
|
w.Header().Set("WWW-Authenticate", "Bearer")
|
||||||
s.tokenErrHelper(w, errInvalidRequest, msg, http.StatusBadRequest)
|
s.tokenErrHelper(w, errAccessDenied, "Invalid bearer token.", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
rawIDToken := auth[len(prefix):]
|
||||||
|
|
||||||
token := parts[1]
|
verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true})
|
||||||
|
idToken, err := verifier.Verify(r.Context(), rawIDToken)
|
||||||
verified, err := s.verify(token)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == errTokenExpired {
|
s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusForbidden)
|
||||||
s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusUnauthorized)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.tokenErrHelper(w, errInvalidRequest, err.Error(), http.StatusBadRequest)
|
|
||||||
|
var claims json.RawMessage
|
||||||
|
if err := idToken.Claims(&claims); err != nil {
|
||||||
|
s.tokenErrHelper(w, errServerError, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.Write(verified)
|
w.Write(claims)
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) verify(token string) ([]byte, error) {
|
|
||||||
keys, err := s.storage.GetKeys()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get keys: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if keys.SigningKey == nil {
|
|
||||||
return nil, fmt.Errorf("no private keys found")
|
|
||||||
}
|
|
||||||
|
|
||||||
object, err := jose.ParseSigned(token)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to parse signed message")
|
|
||||||
}
|
|
||||||
|
|
||||||
parts := strings.Split(token, ".")
|
|
||||||
if len(parts) != 3 {
|
|
||||||
return nil, fmt.Errorf("compact JWS format must have three parts")
|
|
||||||
}
|
|
||||||
|
|
||||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: check other claims
|
|
||||||
var tokenInfo struct {
|
|
||||||
Expiry int64 `json:"exp"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(payload, &tokenInfo); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if tokenInfo.Expiry < s.now().Unix() {
|
|
||||||
return nil, errTokenExpired
|
|
||||||
}
|
|
||||||
|
|
||||||
var allKeys []*jose.JSONWebKey
|
|
||||||
|
|
||||||
allKeys = append(allKeys, keys.SigningKeyPub)
|
|
||||||
for _, key := range keys.VerificationKeys {
|
|
||||||
allKeys = append(allKeys, key.PublicKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, pubKey := range allKeys {
|
|
||||||
verified, err := object.Verify(pubKey)
|
|
||||||
if err == nil {
|
|
||||||
return verified, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, errors.New("unable to verify jwt")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) {
|
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/elliptic"
|
"crypto/elliptic"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
@ -566,3 +567,41 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool {
|
||||||
host, _, err := net.SplitHostPort(u.Host)
|
host, _, err := net.SplitHostPort(u.Host)
|
||||||
return err == nil && host == "localhost"
|
return err == nil && host == "localhost"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// storageKeySet implements the oidc.KeySet interface backed by Dex storage
|
||||||
|
type storageKeySet struct {
|
||||||
|
storage.Storage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *storageKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) {
|
||||||
|
jws, err := jose.ParseSigned(jwt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
keyID := ""
|
||||||
|
for _, sig := range jws.Signatures {
|
||||||
|
keyID = sig.Header.KeyID
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
skeys, err := s.Storage.GetKeys()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := []*jose.JSONWebKey{skeys.SigningKeyPub}
|
||||||
|
for _, vk := range skeys.VerificationKeys {
|
||||||
|
keys = append(keys, vk.PublicKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, key := range keys {
|
||||||
|
if keyID == "" || key.KeyID == keyID {
|
||||||
|
if payload, err := jws.Verify(key); err == nil {
|
||||||
|
return payload, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("failed to verify id token signature")
|
||||||
|
}
|
||||||
|
|
|
@ -2,6 +2,8 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -11,6 +13,7 @@ import (
|
||||||
jose "gopkg.in/square/go-jose.v2"
|
jose "gopkg.in/square/go-jose.v2"
|
||||||
|
|
||||||
"github.com/dexidp/dex/storage"
|
"github.com/dexidp/dex/storage"
|
||||||
|
"github.com/dexidp/dex/storage/memory"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseAuthorizationRequest(t *testing.T) {
|
func TestParseAuthorizationRequest(t *testing.T) {
|
||||||
|
@ -259,3 +262,87 @@ func TestValidRedirectURI(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStorageKeySet(t *testing.T) {
|
||||||
|
s := memory.New(logger)
|
||||||
|
if err := s.UpdateKeys(func(keys storage.Keys) (storage.Keys, error) {
|
||||||
|
keys.SigningKey = &jose.JSONWebKey{
|
||||||
|
Key: testKey,
|
||||||
|
KeyID: "testkey",
|
||||||
|
Algorithm: "RS256",
|
||||||
|
Use: "sig",
|
||||||
|
}
|
||||||
|
keys.SigningKeyPub = &jose.JSONWebKey{
|
||||||
|
Key: testKey.Public(),
|
||||||
|
KeyID: "testkey",
|
||||||
|
Algorithm: "RS256",
|
||||||
|
Use: "sig",
|
||||||
|
}
|
||||||
|
return keys, nil
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tokenGenerator func() (jwt string, err error)
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid token",
|
||||||
|
tokenGenerator: func() (string, error) {
|
||||||
|
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: testKey}, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
jws, err := signer.Sign([]byte("payload"))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return jws.CompactSerialize()
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "token signed by different key",
|
||||||
|
tokenGenerator: func() (string, error) {
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: key}, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
jws, err := signer.Sign([]byte("payload"))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return jws.CompactSerialize()
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
jwt, err := tc.tokenGenerator()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
keySet := &storageKeySet{s}
|
||||||
|
|
||||||
|
_, err = keySet.VerifySignature(context.Background(), jwt)
|
||||||
|
if (err != nil && !tc.wantErr) || (err == nil && tc.wantErr) {
|
||||||
|
t.Fatalf("wantErr = %v, but got err = %v", tc.wantErr, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -200,6 +200,16 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "fetch userinfo",
|
||||||
|
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
||||||
|
_, err := p.UserInfo(ctx, config.TokenSource(ctx, token))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to fetch userinfo: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "verify id token and oauth2 token expiry",
|
name: "verify id token and oauth2 token expiry",
|
||||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
||||||
|
|
Reference in a new issue