dex/db/refresh.go
Yifan Gu 44c6cb44f5 refresh: bcrypt raw bytes rather than base64 encoded string.
This lets us control how many bytes are bcrypted
(64 by default).

Also changed the token's stored form from string('text') to []byte('bytea')
and added some test cases for different types of invalid tokens.
2015-09-02 14:23:20 -07:00

191 lines
4.1 KiB
Go

package db
import (
"encoding/base64"
"errors"
"fmt"
"strconv"
"strings"
"github.com/coreos/dex/refresh"
"github.com/go-gorp/gorp"
"golang.org/x/crypto/bcrypt"
)
// refreshTokenTableName is the name of the table that stores refresh tokens.
const refreshTokenTableName = "refresh_token"
func init() {
register(table{
name: refreshTokenTableName,
model: refreshTokenModel{},
autoinc: true,
pkey: []string{"id"},
})
}
// refreshTokenRepo is the database-backed refresh.RefreshTokenRepo that
// NewRefreshTokenRepo returns.
type refreshTokenRepo struct {
	// dbMap is the executor used when no transaction is supplied.
	dbMap *gorp.DbMap
	// tokenGenerator produces the raw payload bytes for new tokens. Only a
	// bcrypt hash of the payload is persisted (see Create).
	tokenGenerator refresh.RefreshTokenGenerator
}
// refreshTokenModel is one row of the refresh_token table, which holds one
// row per issued refresh token.
type refreshTokenModel struct {
	// ID is the auto-incremented primary key. It also forms the first part
	// of the token string returned to callers (see buildToken).
	ID int64 `db:"id"`
	// PayloadHash is the bcrypt hash of the token's raw payload bytes. The
	// payload itself is not stored.
	PayloadHash []byte `db:"payload_hash"`
	// TODO(yifan): Use some sort of foreign key to manage database level
	// data integrity.
	// UserID and ClientID identify the token's owner. Revoke checks UserID
	// and Verify checks ClientID.
	UserID   string `db:"user_id"`
	ClientID string `db:"client_id"`
}
// buildToken returns the client-facing token string. It is the decimal token
// ID, then refresh.TokenDelimer, then the URL-safe base64 encoding of the raw
// payload. parseToken reverses this.
func buildToken(tokenID int64, tokenPayload []byte) string {
	encodedPayload := base64.URLEncoding.EncodeToString(tokenPayload)
	return strconv.FormatInt(tokenID, 10) + refresh.TokenDelimer + encodedPayload
}
// parseToken reverses buildToken. It splits token at the first
// refresh.TokenDelimer, parses the ID and base64-decodes the payload.
//
// It returns the token ID and raw payload bytes. Any malformed input yields
// (-1, nil, refresh.ErrorInvalidToken).
func parseToken(token string) (int64, []byte, error) {
	parts := strings.SplitN(token, refresh.TokenDelimer, 2)
	if len(parts) != 2 {
		return -1, nil, refresh.ErrorInvalidToken
	}
	// buildToken always writes the ID in base 10, so parse it in base 10.
	// Base 0 would let "010" resolve to ID 8 (octal) and accept prefixed
	// forms such as "0x1f" or "0b1".
	id, err := strconv.ParseInt(parts[0], 10, 64)
	if err != nil {
		return -1, nil, refresh.ErrorInvalidToken
	}
	tokenPayload, err := base64.URLEncoding.DecodeString(parts[1])
	if err != nil {
		return -1, nil, refresh.ErrorInvalidToken
	}
	return id, tokenPayload, nil
}
// checkTokenPayload reports whether payload matches the stored bcrypt
// payloadHash.
//
// A mismatch is reported as refresh.ErrorInvalidToken. Any other bcrypt error
// is returned unchanged, and a match returns nil.
func checkTokenPayload(payloadHash, payload []byte) error {
	cmpErr := bcrypt.CompareHashAndPassword(payloadHash, payload)
	if cmpErr == bcrypt.ErrMismatchedHashAndPassword {
		return refresh.ErrorInvalidToken
	}
	return cmpErr
}
// NewRefreshTokenRepo returns a refresh.RefreshTokenRepo that stores tokens
// through dbm. New token payloads come from
// refresh.DefaultRefreshTokenGenerator.
func NewRefreshTokenRepo(dbm *gorp.DbMap) refresh.RefreshTokenRepo {
	repo := new(refreshTokenRepo)
	repo.dbMap = dbm
	repo.tokenGenerator = refresh.DefaultRefreshTokenGenerator
	return repo
}
// Create issues a new refresh token for the given user and client.
//
// It generates a random payload and stores only its bcrypt hash with the
// user and client IDs. It returns the token string built from the new row ID
// and the raw payload. An empty userID or clientID is rejected with
// refresh.ErrorInvalidUserID or refresh.ErrorInvalidClientID.
func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
	switch {
	case userID == "":
		return "", refresh.ErrorInvalidUserID
	case clientID == "":
		return "", refresh.ErrorInvalidClientID
	}

	// TODO(yifan): Check the number of tokens given to the client-user pair.
	payload, err := r.tokenGenerator.Generate()
	if err != nil {
		return "", err
	}

	hashed, err := bcrypt.GenerateFromPassword(payload, bcrypt.DefaultCost)
	if err != nil {
		return "", err
	}

	// Insert fills in row.ID (the table is auto-incrementing), and the ID is
	// then embedded in the returned token.
	row := &refreshTokenModel{
		PayloadHash: hashed,
		UserID:      userID,
		ClientID:    clientID,
	}
	if err := r.dbMap.Insert(row); err != nil {
		return "", err
	}

	return buildToken(row.ID, payload), nil
}
// Verify checks that token is a valid refresh token issued to clientID. If it
// is, Verify returns the ID of the user the token belongs to.
//
// The owning client is compared before the payload hash, so a token held by
// another client yields refresh.ErrorInvalidClientID.
func (r *refreshTokenRepo) Verify(clientID, token string) (string, error) {
	id, payload, err := parseToken(token)
	if err != nil {
		return "", err
	}

	row, err := r.get(nil, id)
	if err != nil {
		return "", err
	}

	if row.ClientID != clientID {
		return "", refresh.ErrorInvalidClientID
	}

	if err = checkTokenPayload(row.PayloadHash, payload); err != nil {
		return "", err
	}

	return row.UserID, nil
}
// Revoke deletes token if it is valid and belongs to userID.
//
// A token owned by another user yields refresh.ErrorInvalidUserID. A payload
// mismatch, or a row that no longer exists at delete time, yields
// refresh.ErrorInvalidToken.
func (r *refreshTokenRepo) Revoke(userID, token string) error {
	id, payload, err := parseToken(token)
	if err != nil {
		return err
	}

	row, err := r.get(nil, id)
	if err != nil {
		return err
	}

	if row.UserID != userID {
		return refresh.ErrorInvalidUserID
	}

	if err = checkTokenPayload(row.PayloadHash, payload); err != nil {
		return err
	}

	count, err := r.dbMap.Delete(row)
	switch {
	case err != nil:
		return err
	case count == 0:
		return refresh.ErrorInvalidToken
	default:
		return nil
	}
}
// executor returns tx when one is supplied, and otherwise falls back to the
// repo's DbMap.
func (r *refreshTokenRepo) executor(tx *gorp.Transaction) gorp.SqlExecutor {
	if tx != nil {
		return tx
	}
	return r.dbMap
}
// get loads the refresh-token row with the given ID, using tx if it is
// non-nil.
//
// A missing row is reported as refresh.ErrorInvalidToken. A result of an
// unexpected type produces an "unrecognized model" error.
func (r *refreshTokenRepo) get(tx *gorp.Transaction, tokenID int64) (*refreshTokenModel, error) {
	obj, err := r.executor(tx).Get(refreshTokenModel{}, tokenID)
	switch {
	case err != nil:
		return nil, err
	case obj == nil:
		return nil, refresh.ErrorInvalidToken
	}

	model, ok := obj.(*refreshTokenModel)
	if !ok {
		return nil, errors.New("unrecognized model")
	}
	return model, nil
}