This repository has been archived on 2022-08-17. You can view files and clone it, but cannot push or open issues or pull requests.
dex/db/refresh.go
Rubén Soleto Buenvarón c91b37aa9e refresh token rotation
Update refresh token flow to revoke old refresh token and generates a new one.

Fixes #519
2016-08-16 08:05:01 +02:00

280 lines
6.7 KiB
Go

package db
import (
"encoding/base64"
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"github.com/go-gorp/gorp"
"golang.org/x/crypto/bcrypt"
"github.com/coreos/dex/client"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh"
"github.com/coreos/dex/repo"
"github.com/coreos/dex/scope"
)
const (
refreshTokenTableName = "refresh_token"
)
func init() {
register(table{
name: refreshTokenTableName,
model: refreshTokenModel{},
autoinc: true,
pkey: []string{"id"},
})
}
type refreshTokenRepo struct {
*db
tokenGenerator refresh.RefreshTokenGenerator
}
type refreshTokenModel struct {
ID int64 `db:"id"`
PayloadHash []byte `db:"payload_hash"`
UserID string `db:"user_id"`
ClientID string `db:"client_id"`
ConnectorID string `db:"connector_id"`
Scopes string `db:"scopes"`
}
// buildToken combines the token ID and token payload to create a new token.
func buildToken(tokenID int64, tokenPayload []byte) string {
return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
}
// parseToken parses a token and returns the token ID and token payload.
func parseToken(token string) (int64, []byte, error) {
parts := strings.SplitN(token, refresh.TokenDelimer, 2)
if len(parts) != 2 {
return -1, nil, refresh.ErrorInvalidToken
}
id, err := strconv.ParseInt(parts[0], 10, 64)
if err != nil {
return -1, nil, refresh.ErrorInvalidToken
}
tokenPayload, err := base64.URLEncoding.DecodeString(parts[1])
if err != nil {
return -1, nil, refresh.ErrorInvalidToken
}
return id, tokenPayload, nil
}
func checkTokenPayload(payloadHash, payload []byte) error {
if err := bcrypt.CompareHashAndPassword(payloadHash, payload); err != nil {
switch err {
case bcrypt.ErrMismatchedHashAndPassword:
return refresh.ErrorInvalidToken
default:
return err
}
}
return nil
}
func NewRefreshTokenRepo(dbm *gorp.DbMap) refresh.RefreshTokenRepo {
return NewRefreshTokenRepoWithGenerator(dbm, refresh.DefaultRefreshTokenGenerator)
}
func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenGenerator) refresh.RefreshTokenRepo {
return &refreshTokenRepo{
db: &db{dbm},
tokenGenerator: gen,
}
}
func (r *refreshTokenRepo) Create(userID, clientID, connectorID string, scopes []string) (string, error) {
return r.create(nil, userID, clientID, connectorID, scopes)
}
func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) {
return r.verify(nil, clientID, token)
}
func (r *refreshTokenRepo) Revoke(userID, token string) error {
tx, err := r.begin()
if err != nil {
return err
}
defer tx.Rollback()
if err := r.revoke(tx, userID, token); err != nil {
return err
}
return tx.Commit()
}
func (r *refreshTokenRepo) RenewRefreshToken(clientID, userID, oldToken string) (newRefreshToken string, err error) {
// Verify
userID, connectorID, scopes, err := r.verify(nil, clientID, oldToken)
if err != nil {
return "", err
}
// Revoke old refresh token
tx, err := r.begin()
if err != nil {
return "", err
}
defer tx.Rollback()
if err := r.revoke(tx, userID, oldToken); err != nil {
return "", err
}
// Renew refresh token
newRefreshToken, err = r.create(tx, userID, clientID, connectorID, scopes)
if err != nil {
return "", err
}
return newRefreshToken, tx.Commit()
}
func (r *refreshTokenRepo) RevokeTokensForClient(userID, clientID string) error {
q := fmt.Sprintf("DELETE FROM %s WHERE user_id = $1 AND client_id = $2", r.quote(refreshTokenTableName))
_, err := r.executor(nil).Exec(q, userID, clientID)
return err
}
func (r *refreshTokenRepo) ClientsWithRefreshTokens(userID string) ([]client.Client, error) {
q := `SELECT c.* FROM %s as c
INNER JOIN %s as r ON c.id = r.client_id WHERE r.user_id = $1;`
q = fmt.Sprintf(q, r.quote(clientTableName), r.quote(refreshTokenTableName))
var clients []clientModel
if _, err := r.executor(nil).Select(&clients, q, userID); err != nil {
return nil, err
}
c := make([]client.Client, len(clients))
for i, client := range clients {
ident, err := client.Client()
if err != nil {
return nil, err
}
c[i] = *ident
// Do not share the secret.
c[i].Credentials.Secret = ""
}
return c, nil
}
func (r *refreshTokenRepo) get(tx repo.Transaction, tokenID int64) (*refreshTokenModel, error) {
ex := r.executor(tx)
result, err := ex.Get(refreshTokenModel{}, tokenID)
if err != nil {
return nil, err
}
if result == nil {
return nil, refresh.ErrorInvalidToken
}
record, ok := result.(*refreshTokenModel)
if !ok {
log.Errorf("expected refreshTokenModel but found %v", reflect.TypeOf(result))
return nil, errors.New("unrecognized model")
}
return record, nil
}
func (r *refreshTokenRepo) verify(tx repo.Transaction, clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) {
tokenID, tokenPayload, err := parseToken(token)
if err != nil {
return
}
record, err := r.get(tx, tokenID)
if err != nil {
return
}
if record.ClientID != clientID {
return "", "", nil, refresh.ErrorInvalidClientID
}
// Check if the hash of token received is the same stored in database
if err = checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
return
}
var scopes []string
if len(record.Scopes) > 0 {
scopes = strings.Split(record.Scopes, " ")
}
return record.UserID, record.ConnectorID, scopes, nil
}
func (r *refreshTokenRepo) create(tx repo.Transaction, userID, clientID, connectorID string, scopes []string) (string, error) {
if userID == "" {
return "", refresh.ErrorInvalidUserID
}
if clientID == "" {
return "", refresh.ErrorInvalidClientID
}
// TODO(yifan): Check the number of tokens given to the client-user pair.
tokenPayload, err := r.tokenGenerator.Generate()
if err != nil {
return "", err
}
payloadHash, err := bcrypt.GenerateFromPassword(tokenPayload, bcrypt.DefaultCost)
if err != nil {
return "", err
}
record := &refreshTokenModel{
PayloadHash: payloadHash,
UserID: userID,
ClientID: clientID,
ConnectorID: connectorID,
Scopes: strings.Join(scopes, " "),
}
if err := r.executor(tx).Insert(record); err != nil {
return "", err
}
return buildToken(record.ID, tokenPayload), nil
}
func (r *refreshTokenRepo) revoke(tx repo.Transaction, userID, token string) error {
tokenID, tokenPayload, err := parseToken(token)
if err != nil {
return err
}
exec := r.executor(tx)
record, err := r.get(tx, tokenID)
if err != nil {
return err
}
if record.UserID != userID {
return refresh.ErrorInvalidUserID
}
if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
return err
}
deleted, err := exec.Delete(record)
if err != nil {
return err
}
if deleted == 0 {
return refresh.ErrorInvalidToken
}
return nil
}