forked from mystiq/dex
refresh: bcrypt raw bytes rather than base64 encoded string.
This enables us to control the length of the bytes that will be bcrypted, by default it's 64. Also changed the token's stored form from string('text') to []byte('bytea') and added some test cases for different types of invalid tokens.
This commit is contained in:
parent
081bfdd13d
commit
44c6cb44f5
7 changed files with 153 additions and 69 deletions
|
@ -1,7 +1,7 @@
|
|||
-- +migrate Up
|
||||
CREATE TABLE refresh_token (
|
||||
id bigint NOT NULL,
|
||||
payload_hash text,
|
||||
payload_hash bytea,
|
||||
user_id text,
|
||||
client_id text
|
||||
);
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -1,6 +1,7 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
@ -31,7 +32,7 @@ type refreshTokenRepo struct {
|
|||
|
||||
type refreshTokenModel struct {
|
||||
ID int64 `db:"id"`
|
||||
PayloadHash string `db:"payload_hash"`
|
||||
PayloadHash []byte `db:"payload_hash"`
|
||||
// TODO(yifan): Use some sort of foreign key to manage database level
|
||||
// data integrity.
|
||||
UserID string `db:"user_id"`
|
||||
|
@ -39,25 +40,29 @@ type refreshTokenModel struct {
|
|||
}
|
||||
|
||||
// buildToken combines the token ID and token payload to create a new token.
|
||||
func buildToken(tokenID int64, tokenPayload string) string {
|
||||
return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, tokenPayload)
|
||||
func buildToken(tokenID int64, tokenPayload []byte) string {
|
||||
return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
|
||||
}
|
||||
|
||||
// parseToken parses a token and returns the token ID and token payload.
|
||||
func parseToken(token string) (int64, string, error) {
|
||||
func parseToken(token string) (int64, []byte, error) {
|
||||
parts := strings.SplitN(token, refresh.TokenDelimer, 2)
|
||||
if len(parts) != 2 {
|
||||
return -1, "", refresh.ErrorInvalidToken
|
||||
return -1, nil, refresh.ErrorInvalidToken
|
||||
}
|
||||
id, err := strconv.ParseInt(parts[0], 0, 64)
|
||||
if err != nil {
|
||||
return -1, "", refresh.ErrorInvalidToken
|
||||
return -1, nil, refresh.ErrorInvalidToken
|
||||
}
|
||||
return id, parts[1], nil
|
||||
tokenPayload, err := base64.URLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return -1, nil, refresh.ErrorInvalidToken
|
||||
}
|
||||
return id, tokenPayload, nil
|
||||
}
|
||||
|
||||
func checkTokenPayload(payloadHash, payload string) error {
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(payloadHash), []byte(payload)); err != nil {
|
||||
func checkTokenPayload(payloadHash, payload []byte) error {
|
||||
if err := bcrypt.CompareHashAndPassword(payloadHash, payload); err != nil {
|
||||
switch err {
|
||||
case bcrypt.ErrMismatchedHashAndPassword:
|
||||
return refresh.ErrorInvalidToken
|
||||
|
@ -89,13 +94,13 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
|
|||
return "", err
|
||||
}
|
||||
|
||||
payloadHash, err := bcrypt.GenerateFromPassword([]byte(tokenPayload), bcrypt.DefaultCost)
|
||||
payloadHash, err := bcrypt.GenerateFromPassword(tokenPayload, bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
record := &refreshTokenModel{
|
||||
PayloadHash: string(payloadHash),
|
||||
PayloadHash: payloadHash,
|
||||
UserID: userID,
|
||||
ClientID: clientID,
|
||||
}
|
||||
|
@ -109,6 +114,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
|
|||
|
||||
func (r *refreshTokenRepo) Verify(clientID, token string) (string, error) {
|
||||
tokenID, tokenPayload, err := parseToken(token)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package functional
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
|
@ -342,6 +343,12 @@ func TestDBClientIdentityAll(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// buildRefreshToken combines the token ID and token payload to create a new token.
|
||||
// used in the tests to created a refresh token.
|
||||
func buildRefreshToken(tokenID int64, tokenPayload []byte) string {
|
||||
return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
|
||||
}
|
||||
|
||||
func TestDBRefreshRepoCreate(t *testing.T) {
|
||||
r := db.NewRefreshTokenRepo(connect(t))
|
||||
|
||||
|
@ -383,6 +390,13 @@ func TestDBRefreshRepoVerify(t *testing.T) {
|
|||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
tokenWithBadID := "404" + token[1:]
|
||||
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
|
||||
|
||||
tests := []struct {
|
||||
token string
|
||||
creds oidc.ClientCredentials
|
||||
|
@ -390,7 +404,39 @@ func TestDBRefreshRepoVerify(t *testing.T) {
|
|||
expected string
|
||||
}{
|
||||
{
|
||||
"invalid-token-foo",
|
||||
"invalid-token-format",
|
||||
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
||||
refresh.ErrorInvalidToken,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"b/invalid-base64-encoded-format",
|
||||
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
||||
refresh.ErrorInvalidToken,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"1/invalid-base64-encoded-format",
|
||||
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
||||
refresh.ErrorInvalidToken,
|
||||
"",
|
||||
},
|
||||
{
|
||||
token + "corrupted-token-payload",
|
||||
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
||||
refresh.ErrorInvalidToken,
|
||||
"",
|
||||
},
|
||||
{
|
||||
// The token's ID content is invalid.
|
||||
tokenWithBadID,
|
||||
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
||||
refresh.ErrorInvalidToken,
|
||||
"",
|
||||
},
|
||||
{
|
||||
// The token's payload content is invalid.
|
||||
tokenWithBadPayload,
|
||||
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
||||
refresh.ErrorInvalidToken,
|
||||
"",
|
||||
|
@ -428,13 +474,42 @@ func TestDBRefreshRepoRevoke(t *testing.T) {
|
|||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
tokenWithBadID := "404" + token[1:]
|
||||
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
|
||||
|
||||
tests := []struct {
|
||||
token string
|
||||
userID string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
"invalid-token-foo",
|
||||
"invalid-token-format",
|
||||
"user-foo",
|
||||
refresh.ErrorInvalidToken,
|
||||
},
|
||||
{
|
||||
"1/invalid-base64-encoded-format",
|
||||
"user-foo",
|
||||
refresh.ErrorInvalidToken,
|
||||
},
|
||||
{
|
||||
token + "corrupted-token-payload",
|
||||
"user-foo",
|
||||
refresh.ErrorInvalidToken,
|
||||
},
|
||||
{
|
||||
// The token's ID is invalid.
|
||||
tokenWithBadID,
|
||||
"user-foo",
|
||||
refresh.ErrorInvalidToken,
|
||||
},
|
||||
{
|
||||
// The token's payload is invalid.
|
||||
tokenWithBadPayload,
|
||||
"user-foo",
|
||||
refresh.ErrorInvalidToken,
|
||||
},
|
||||
|
|
|
@ -10,9 +10,9 @@ import (
|
|||
// The tokens are in the form { refresh-1, refresh-2 ... refresh-n}.
|
||||
func NewTestRefreshTokenRepo() (refresh.RefreshTokenRepo, error) {
|
||||
var tokenIdx int
|
||||
tokenGenerator := func() (string, error) {
|
||||
tokenGenerator := func() ([]byte, error) {
|
||||
tokenIdx++
|
||||
return fmt.Sprintf("refresh-%d", tokenIdx), nil
|
||||
return []byte(fmt.Sprintf("refresh-%d", tokenIdx)), nil
|
||||
}
|
||||
return refresh.NewRefreshTokenRepoWithTokenGenerator(tokenGenerator), nil
|
||||
}
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
package refresh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
@ -17,25 +17,26 @@ const (
|
|||
var (
|
||||
ErrorInvalidUserID = errors.New("invalid user ID")
|
||||
ErrorInvalidClientID = errors.New("invalid client ID")
|
||||
ErrorInvalidToken = errors.New("invalid token")
|
||||
|
||||
ErrorInvalidToken = errors.New("invalid token")
|
||||
)
|
||||
|
||||
type RefreshTokenGenerator func() (string, error)
|
||||
type RefreshTokenGenerator func() ([]byte, error)
|
||||
|
||||
func (g RefreshTokenGenerator) Generate() (string, error) {
|
||||
func (g RefreshTokenGenerator) Generate() ([]byte, error) {
|
||||
return g()
|
||||
}
|
||||
|
||||
func DefaultRefreshTokenGenerator() (string, error) {
|
||||
func DefaultRefreshTokenGenerator() ([]byte, error) {
|
||||
// TODO(yifan) Remove this duplicated token generate function.
|
||||
b := make([]byte, DefaultRefreshTokenPayloadLength)
|
||||
n, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
} else if n != DefaultRefreshTokenPayloadLength {
|
||||
return "", errors.New("unable to read enough random bytes")
|
||||
return nil, errors.New("unable to read enough random bytes")
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(b), nil
|
||||
return b, nil
|
||||
}
|
||||
|
||||
type RefreshTokenRepo interface {
|
||||
|
@ -52,7 +53,7 @@ type RefreshTokenRepo interface {
|
|||
}
|
||||
|
||||
type refreshToken struct {
|
||||
payload string
|
||||
payload []byte
|
||||
userID string
|
||||
clientID string
|
||||
}
|
||||
|
@ -63,21 +64,21 @@ type memRefreshTokenRepo struct {
|
|||
}
|
||||
|
||||
// buildToken combines the token ID and token payload to create a new token.
|
||||
func buildToken(tokenID int, tokenPayload string) string {
|
||||
func buildToken(tokenID int, tokenPayload []byte) string {
|
||||
return fmt.Sprintf("%d%s%s", tokenID, TokenDelimer, tokenPayload)
|
||||
}
|
||||
|
||||
// parseToken parses a token and returns the token ID and token payload.
|
||||
func parseToken(token string) (int, string, error) {
|
||||
func parseToken(token string) (int, []byte, error) {
|
||||
parts := strings.SplitN(token, TokenDelimer, 2)
|
||||
if len(parts) != 2 {
|
||||
return -1, "", ErrorInvalidToken
|
||||
return -1, nil, ErrorInvalidToken
|
||||
}
|
||||
id, err := strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
return -1, "", ErrorInvalidToken
|
||||
return -1, nil, ErrorInvalidToken
|
||||
}
|
||||
return id, parts[1], nil
|
||||
return id, []byte(parts[1]), nil
|
||||
}
|
||||
|
||||
// NewRefreshTokenRepo returns an in-memory RefreshTokenRepo useful for development.
|
||||
|
@ -131,7 +132,7 @@ func (r *memRefreshTokenRepo) Verify(clientID, token string) (string, error) {
|
|||
return "", ErrorInvalidToken
|
||||
}
|
||||
|
||||
if record.payload != tokenPayload {
|
||||
if !bytes.Equal(record.payload, tokenPayload) {
|
||||
return "", ErrorInvalidToken
|
||||
}
|
||||
|
||||
|
@ -153,7 +154,7 @@ func (r *memRefreshTokenRepo) Revoke(userID, token string) error {
|
|||
return ErrorInvalidToken
|
||||
}
|
||||
|
||||
if record.payload != tokenPayload {
|
||||
if !bytes.Equal(record.payload, tokenPayload) {
|
||||
return ErrorInvalidToken
|
||||
}
|
||||
|
||||
|
|
|
@ -397,7 +397,7 @@ func TestServerTokenFail(t *testing.T) {
|
|||
signer jose.Signer
|
||||
argCC oidc.ClientCredentials
|
||||
argKey string
|
||||
err string
|
||||
err error
|
||||
scope []string
|
||||
refreshToken string
|
||||
}{
|
||||
|
@ -423,7 +423,7 @@ func TestServerTokenFail(t *testing.T) {
|
|||
signer: signerFixture,
|
||||
argCC: ccFixture,
|
||||
argKey: "foo",
|
||||
err: oauth2.ErrorInvalidGrant,
|
||||
err: oauth2.NewError(oauth2.ErrorInvalidGrant),
|
||||
scope: []string{"openid", "offline_access"},
|
||||
},
|
||||
|
||||
|
@ -432,7 +432,7 @@ func TestServerTokenFail(t *testing.T) {
|
|||
signer: signerFixture,
|
||||
argCC: oidc.ClientCredentials{ID: "YYY"},
|
||||
argKey: keyFixture,
|
||||
err: oauth2.ErrorInvalidClient,
|
||||
err: oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||
scope: []string{"openid", "offline_access"},
|
||||
},
|
||||
|
||||
|
@ -441,7 +441,7 @@ func TestServerTokenFail(t *testing.T) {
|
|||
signer: &StaticSigner{sig: nil, err: errors.New("fail")},
|
||||
argCC: ccFixture,
|
||||
argKey: keyFixture,
|
||||
err: oauth2.ErrorServerError,
|
||||
err: oauth2.NewError(oauth2.ErrorServerError),
|
||||
scope: []string{"openid", "offline_access"},
|
||||
},
|
||||
}
|
||||
|
@ -502,18 +502,14 @@ func TestServerTokenFail(t *testing.T) {
|
|||
t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
|
||||
panic("")
|
||||
}
|
||||
if tt.err == "" {
|
||||
if err != nil {
|
||||
t.Errorf("case %d: got non-nil error: %v", i, err)
|
||||
} else if jwt == nil {
|
||||
t.Errorf("case %d: got nil JWT", i)
|
||||
}
|
||||
} else {
|
||||
if err.Error() != tt.err {
|
||||
t.Errorf("case %d: want err %q, got %q", i, tt.err, err.Error())
|
||||
} else if jwt != nil {
|
||||
t.Errorf("case %d: got non-nil JWT", i)
|
||||
}
|
||||
if !reflect.DeepEqual(err, tt.err) {
|
||||
t.Errorf("case %d: expect %v, got %v", i, tt.err, err)
|
||||
}
|
||||
if err == nil && jwt == nil {
|
||||
t.Errorf("case %d: got nil JWT", i)
|
||||
}
|
||||
if err != nil && jwt != nil {
|
||||
t.Errorf("case %d: got non-nil JWT %v", i, jwt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -537,7 +533,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
clientID string // The client that associates with the token.
|
||||
creds oidc.ClientCredentials
|
||||
signer jose.Signer
|
||||
err string
|
||||
err error
|
||||
}{
|
||||
// Everything is good.
|
||||
{
|
||||
|
@ -545,7 +541,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
"XXX",
|
||||
credXXX,
|
||||
signerFixture,
|
||||
"",
|
||||
nil,
|
||||
},
|
||||
// Invalid refresh token(malformatted).
|
||||
{
|
||||
|
@ -553,15 +549,23 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
"XXX",
|
||||
credXXX,
|
||||
signerFixture,
|
||||
oauth2.ErrorInvalidRequest,
|
||||
oauth2.NewError(oauth2.ErrorInvalidRequest),
|
||||
},
|
||||
// Invalid refresh token.
|
||||
// Invalid refresh token(invalid payload content).
|
||||
{
|
||||
"0/refresh-1",
|
||||
"0/refresh-2",
|
||||
"XXX",
|
||||
credXXX,
|
||||
signerFixture,
|
||||
oauth2.ErrorInvalidRequest,
|
||||
oauth2.NewError(oauth2.ErrorInvalidRequest),
|
||||
},
|
||||
// Invalid refresh token(invalid ID content).
|
||||
{
|
||||
"1/refresh-2",
|
||||
"XXX",
|
||||
credXXX,
|
||||
signerFixture,
|
||||
oauth2.NewError(oauth2.ErrorInvalidRequest),
|
||||
},
|
||||
// Invalid client(client is not associated with the token).
|
||||
{
|
||||
|
@ -569,7 +573,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
"XXX",
|
||||
credYYY,
|
||||
signerFixture,
|
||||
oauth2.ErrorInvalidClient,
|
||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||
},
|
||||
// Invalid client(no client ID).
|
||||
{
|
||||
|
@ -577,7 +581,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
"XXX",
|
||||
oidc.ClientCredentials{ID: "", Secret: "aaa"},
|
||||
signerFixture,
|
||||
oauth2.ErrorInvalidClient,
|
||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||
},
|
||||
// Invalid client(no such client).
|
||||
{
|
||||
|
@ -585,7 +589,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
"XXX",
|
||||
oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
|
||||
signerFixture,
|
||||
oauth2.ErrorInvalidClient,
|
||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||
},
|
||||
// Invalid client(no secrets).
|
||||
{
|
||||
|
@ -593,7 +597,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
"XXX",
|
||||
oidc.ClientCredentials{ID: "XXX"},
|
||||
signerFixture,
|
||||
oauth2.ErrorInvalidClient,
|
||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||
},
|
||||
// Invalid client(invalid secret).
|
||||
{
|
||||
|
@ -601,7 +605,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
"XXX",
|
||||
oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"},
|
||||
signerFixture,
|
||||
oauth2.ErrorInvalidClient,
|
||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||
},
|
||||
// Signing operation fails.
|
||||
{
|
||||
|
@ -609,7 +613,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
"XXX",
|
||||
credXXX,
|
||||
&StaticSigner{sig: nil, err: errors.New("fail")},
|
||||
oauth2.ErrorServerError,
|
||||
oauth2.NewError(oauth2.ErrorServerError),
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -646,10 +650,8 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
}
|
||||
|
||||
jwt, err := srv.RefreshToken(tt.creds, tt.token)
|
||||
if err != nil {
|
||||
if err.Error() != tt.err {
|
||||
t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err)
|
||||
}
|
||||
if !reflect.DeepEqual(err, tt.err) {
|
||||
t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err)
|
||||
}
|
||||
|
||||
if jwt != nil {
|
||||
|
@ -715,7 +717,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
srv.UserRepo = userRepo
|
||||
|
||||
_, err = srv.RefreshToken(credXXX, "0/refresh-1")
|
||||
if err == nil || err.Error() != oauth2.ErrorServerError {
|
||||
t.Errorf("Expect: %v, got: %v", oauth2.ErrorServerError, err)
|
||||
if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) {
|
||||
t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue