forked from mystiq/dex
*: remove in memory refresh repo
This commit is contained in:
parent
7bac93aa20
commit
95560404a3
6 changed files with 34 additions and 159 deletions
|
@ -83,6 +83,13 @@ func NewRefreshTokenRepo(dbm *gorp.DbMap) refresh.RefreshTokenRepo {
|
|||
}
|
||||
}
|
||||
|
||||
func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenGenerator) refresh.RefreshTokenRepo {
|
||||
return &refreshTokenRepo{
|
||||
dbMap: dbm,
|
||||
tokenGenerator: gen,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
|
||||
if userID == "" {
|
||||
return "", refresh.ErrorInvalidUserID
|
||||
|
|
|
@ -145,10 +145,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
|
|||
}
|
||||
|
||||
passwordInfoRepo := user.NewPasswordInfoRepo()
|
||||
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
|
||||
|
||||
srv := &server.Server{
|
||||
IssuerURL: issuerURL,
|
||||
|
|
|
@ -3,16 +3,17 @@ package refreshtest
|
|||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/refresh"
|
||||
)
|
||||
|
||||
// NewTestRefreshTokenRepo returns a test repo whose tokens monotonically increase.
|
||||
// The tokens are in the form { refresh-1, refresh-2 ... refresh-n}.
|
||||
func NewTestRefreshTokenRepo() (refresh.RefreshTokenRepo, error) {
|
||||
func NewTestRefreshTokenRepo() refresh.RefreshTokenRepo {
|
||||
var tokenIdx int
|
||||
tokenGenerator := func() ([]byte, error) {
|
||||
tokenIdx++
|
||||
return []byte(fmt.Sprintf("refresh-%d", tokenIdx)), nil
|
||||
}
|
||||
return refresh.NewRefreshTokenRepoWithTokenGenerator(tokenGenerator), nil
|
||||
return db.NewRefreshTokenRepoWithGenerator(db.NewMemDB(), tokenGenerator)
|
||||
}
|
||||
|
|
123
refresh/repo.go
123
refresh/repo.go
|
@ -1,13 +1,8 @@
|
|||
package refresh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -53,121 +48,3 @@ type RefreshTokenRepo interface {
|
|||
// Revoke deletes the refresh token if the token belongs to the given userID.
|
||||
Revoke(userID, token string) error
|
||||
}
|
||||
|
||||
type refreshToken struct {
|
||||
payload []byte
|
||||
userID string
|
||||
clientID string
|
||||
}
|
||||
|
||||
type memRefreshTokenRepo struct {
|
||||
store map[int]refreshToken
|
||||
tokenGenerator RefreshTokenGenerator
|
||||
}
|
||||
|
||||
// buildToken combines the token ID and token payload to create a new token.
|
||||
func buildToken(tokenID int, tokenPayload []byte) string {
|
||||
return fmt.Sprintf("%d%s%s", tokenID, TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
|
||||
}
|
||||
|
||||
// parseToken parses a token and returns the token ID and token payload.
|
||||
func parseToken(token string) (int, []byte, error) {
|
||||
parts := strings.SplitN(token, TokenDelimer, 2)
|
||||
if len(parts) != 2 {
|
||||
return -1, nil, ErrorInvalidToken
|
||||
}
|
||||
id, err := strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
return -1, nil, ErrorInvalidToken
|
||||
}
|
||||
tokenPayload, err := base64.URLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return -1, nil, ErrorInvalidToken
|
||||
}
|
||||
return id, tokenPayload, nil
|
||||
}
|
||||
|
||||
// NewRefreshTokenRepo returns an in-memory RefreshTokenRepo useful for development.
|
||||
func NewRefreshTokenRepo() RefreshTokenRepo {
|
||||
return NewRefreshTokenRepoWithTokenGenerator(DefaultRefreshTokenGenerator)
|
||||
}
|
||||
|
||||
func NewRefreshTokenRepoWithTokenGenerator(tokenGenerator RefreshTokenGenerator) RefreshTokenRepo {
|
||||
repo := &memRefreshTokenRepo{}
|
||||
repo.store = make(map[int]refreshToken)
|
||||
repo.tokenGenerator = tokenGenerator
|
||||
return repo
|
||||
}
|
||||
|
||||
func (r *memRefreshTokenRepo) Create(userID, clientID string) (string, error) {
|
||||
// Validate userID.
|
||||
if userID == "" {
|
||||
return "", ErrorInvalidUserID
|
||||
}
|
||||
|
||||
// Validate clientID.
|
||||
if clientID == "" {
|
||||
return "", ErrorInvalidClientID
|
||||
}
|
||||
|
||||
// Generate and store token.
|
||||
tokenPayload, err := r.tokenGenerator.Generate()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
tokenID := len(r.store) // Should only be used in single threaded tests.
|
||||
|
||||
// No limits on the number of tokens per user/client for this in-memory repo.
|
||||
r.store[tokenID] = refreshToken{
|
||||
payload: tokenPayload,
|
||||
userID: userID,
|
||||
clientID: clientID,
|
||||
}
|
||||
return buildToken(tokenID, tokenPayload), nil
|
||||
}
|
||||
|
||||
func (r *memRefreshTokenRepo) Verify(clientID, token string) (string, error) {
|
||||
tokenID, tokenPayload, err := parseToken(token)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
record, ok := r.store[tokenID]
|
||||
if !ok {
|
||||
return "", ErrorInvalidToken
|
||||
}
|
||||
|
||||
if !bytes.Equal(record.payload, tokenPayload) {
|
||||
return "", ErrorInvalidToken
|
||||
}
|
||||
|
||||
if record.clientID != clientID {
|
||||
return "", ErrorInvalidClientID
|
||||
}
|
||||
|
||||
return record.userID, nil
|
||||
}
|
||||
|
||||
func (r *memRefreshTokenRepo) Revoke(userID, token string) error {
|
||||
tokenID, tokenPayload, err := parseToken(token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
record, ok := r.store[tokenID]
|
||||
if !ok {
|
||||
return ErrorInvalidToken
|
||||
}
|
||||
|
||||
if !bytes.Equal(record.payload, tokenPayload) {
|
||||
return ErrorInvalidToken
|
||||
}
|
||||
|
||||
if record.userID != userID {
|
||||
return ErrorInvalidUserID
|
||||
}
|
||||
|
||||
delete(r.store, tokenID)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@ import (
|
|||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/email"
|
||||
"github.com/coreos/dex/refresh"
|
||||
"github.com/coreos/dex/repo"
|
||||
sessionmanager "github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/dex/user"
|
||||
|
@ -139,7 +138,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
|||
|
||||
pwiRepo := user.NewPasswordInfoRepo()
|
||||
|
||||
refTokRepo := refresh.NewRefreshTokenRepo()
|
||||
refTokRepo := db.NewRefreshTokenRepo(db.NewMemDB())
|
||||
|
||||
txnFactory := repo.InMemTransactionFactory
|
||||
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
|
||||
|
|
|
@ -351,10 +351,7 @@ func TestServerCodeToken(t *testing.T) {
|
|||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
|
||||
|
||||
srv := &Server{
|
||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
|
@ -376,8 +373,10 @@ func TestServerCodeToken(t *testing.T) {
|
|||
},
|
||||
// Have 'offline_access' in scope, should get non-empty refresh token.
|
||||
{
|
||||
// NOTE(ericchiang): This test assumes that the database ID of the first
|
||||
// refresh token will be "1".
|
||||
scope: []string{"openid", "offline_access"},
|
||||
refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -475,11 +474,13 @@ func TestServerTokenFail(t *testing.T) {
|
|||
}{
|
||||
// control test case to make sure fixtures check out
|
||||
{
|
||||
// NOTE(ericchiang): This test assumes that the database ID of the first
|
||||
// refresh token will be "1".
|
||||
signer: signerFixture,
|
||||
argCC: ccFixture,
|
||||
argKey: keyFixture,
|
||||
scope: []string{"openid", "offline_access"},
|
||||
refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
},
|
||||
|
||||
// no 'offline_access' in 'scope', should get empty refresh token
|
||||
|
@ -549,10 +550,7 @@ func TestServerTokenFail(t *testing.T) {
|
|||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
|
||||
|
||||
srv := &Server{
|
||||
IssuerURL: issuerURL,
|
||||
|
@ -600,6 +598,8 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
|
||||
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
|
||||
|
||||
// NOTE(ericchiang): These tests assume that the database ID of the first
|
||||
// refresh token will be "1".
|
||||
tests := []struct {
|
||||
token string
|
||||
clientID string // The client that associates with the token.
|
||||
|
@ -609,7 +609,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
}{
|
||||
// Everything is good.
|
||||
{
|
||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
credXXX,
|
||||
signerFixture,
|
||||
|
@ -625,7 +625,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
},
|
||||
// Invalid refresh token(invalid payload content).
|
||||
{
|
||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
|
||||
"XXX",
|
||||
credXXX,
|
||||
signerFixture,
|
||||
|
@ -633,7 +633,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
},
|
||||
// Invalid refresh token(invalid ID content).
|
||||
{
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
credXXX,
|
||||
signerFixture,
|
||||
|
@ -641,7 +641,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
},
|
||||
// Invalid client(client is not associated with the token).
|
||||
{
|
||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
credYYY,
|
||||
signerFixture,
|
||||
|
@ -649,7 +649,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
},
|
||||
// Invalid client(no client ID).
|
||||
{
|
||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
oidc.ClientCredentials{ID: "", Secret: "aaa"},
|
||||
signerFixture,
|
||||
|
@ -657,7 +657,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
},
|
||||
// Invalid client(no such client).
|
||||
{
|
||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
|
||||
signerFixture,
|
||||
|
@ -665,7 +665,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
},
|
||||
// Invalid client(no secrets).
|
||||
{
|
||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
oidc.ClientCredentials{ID: "XXX"},
|
||||
signerFixture,
|
||||
|
@ -673,7 +673,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
},
|
||||
// Invalid client(invalid secret).
|
||||
{
|
||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"},
|
||||
signerFixture,
|
||||
|
@ -681,7 +681,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
},
|
||||
// Signing operation fails.
|
||||
{
|
||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
credXXX,
|
||||
&StaticSigner{sig: nil, err: errors.New("fail")},
|
||||
|
@ -704,10 +704,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
|
||||
|
||||
srv := &Server{
|
||||
IssuerURL: issuerURL,
|
||||
|
@ -764,10 +761,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
|
||||
|
||||
srv := &Server{
|
||||
IssuerURL: issuerURL,
|
||||
|
@ -788,7 +782,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
}
|
||||
srv.UserRepo = userRepo
|
||||
|
||||
_, err = srv.RefreshToken(credXXX, fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))))
|
||||
_, err = srv.RefreshToken(credXXX, fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))))
|
||||
if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) {
|
||||
t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue