forked from mystiq/dex
*: remove in memory refresh repo
This commit is contained in:
parent
7bac93aa20
commit
95560404a3
6 changed files with 34 additions and 159 deletions
|
@ -83,6 +83,13 @@ func NewRefreshTokenRepo(dbm *gorp.DbMap) refresh.RefreshTokenRepo {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenGenerator) refresh.RefreshTokenRepo {
|
||||||
|
return &refreshTokenRepo{
|
||||||
|
dbMap: dbm,
|
||||||
|
tokenGenerator: gen,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
|
func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
return "", refresh.ErrorInvalidUserID
|
return "", refresh.ErrorInvalidUserID
|
||||||
|
|
|
@ -145,10 +145,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
passwordInfoRepo := user.NewPasswordInfoRepo()
|
passwordInfoRepo := user.NewPasswordInfoRepo()
|
||||||
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
|
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
srv := &server.Server{
|
srv := &server.Server{
|
||||||
IssuerURL: issuerURL,
|
IssuerURL: issuerURL,
|
||||||
|
|
|
@ -3,16 +3,17 @@ package refreshtest
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/refresh"
|
"github.com/coreos/dex/refresh"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewTestRefreshTokenRepo returns a test repo whose tokens monotonically increase.
|
// NewTestRefreshTokenRepo returns a test repo whose tokens monotonically increase.
|
||||||
// The tokens are in the form { refresh-1, refresh-2 ... refresh-n}.
|
// The tokens are in the form { refresh-1, refresh-2 ... refresh-n}.
|
||||||
func NewTestRefreshTokenRepo() (refresh.RefreshTokenRepo, error) {
|
func NewTestRefreshTokenRepo() refresh.RefreshTokenRepo {
|
||||||
var tokenIdx int
|
var tokenIdx int
|
||||||
tokenGenerator := func() ([]byte, error) {
|
tokenGenerator := func() ([]byte, error) {
|
||||||
tokenIdx++
|
tokenIdx++
|
||||||
return []byte(fmt.Sprintf("refresh-%d", tokenIdx)), nil
|
return []byte(fmt.Sprintf("refresh-%d", tokenIdx)), nil
|
||||||
}
|
}
|
||||||
return refresh.NewRefreshTokenRepoWithTokenGenerator(tokenGenerator), nil
|
return db.NewRefreshTokenRepoWithGenerator(db.NewMemDB(), tokenGenerator)
|
||||||
}
|
}
|
||||||
|
|
123
refresh/repo.go
123
refresh/repo.go
|
@ -1,13 +1,8 @@
|
||||||
package refresh
|
package refresh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -53,121 +48,3 @@ type RefreshTokenRepo interface {
|
||||||
// Revoke deletes the refresh token if the token belongs to the given userID.
|
// Revoke deletes the refresh token if the token belongs to the given userID.
|
||||||
Revoke(userID, token string) error
|
Revoke(userID, token string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type refreshToken struct {
|
|
||||||
payload []byte
|
|
||||||
userID string
|
|
||||||
clientID string
|
|
||||||
}
|
|
||||||
|
|
||||||
type memRefreshTokenRepo struct {
|
|
||||||
store map[int]refreshToken
|
|
||||||
tokenGenerator RefreshTokenGenerator
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildToken combines the token ID and token payload to create a new token.
|
|
||||||
func buildToken(tokenID int, tokenPayload []byte) string {
|
|
||||||
return fmt.Sprintf("%d%s%s", tokenID, TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseToken parses a token and returns the token ID and token payload.
|
|
||||||
func parseToken(token string) (int, []byte, error) {
|
|
||||||
parts := strings.SplitN(token, TokenDelimer, 2)
|
|
||||||
if len(parts) != 2 {
|
|
||||||
return -1, nil, ErrorInvalidToken
|
|
||||||
}
|
|
||||||
id, err := strconv.Atoi(parts[0])
|
|
||||||
if err != nil {
|
|
||||||
return -1, nil, ErrorInvalidToken
|
|
||||||
}
|
|
||||||
tokenPayload, err := base64.URLEncoding.DecodeString(parts[1])
|
|
||||||
if err != nil {
|
|
||||||
return -1, nil, ErrorInvalidToken
|
|
||||||
}
|
|
||||||
return id, tokenPayload, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRefreshTokenRepo returns an in-memory RefreshTokenRepo useful for development.
|
|
||||||
func NewRefreshTokenRepo() RefreshTokenRepo {
|
|
||||||
return NewRefreshTokenRepoWithTokenGenerator(DefaultRefreshTokenGenerator)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRefreshTokenRepoWithTokenGenerator(tokenGenerator RefreshTokenGenerator) RefreshTokenRepo {
|
|
||||||
repo := &memRefreshTokenRepo{}
|
|
||||||
repo.store = make(map[int]refreshToken)
|
|
||||||
repo.tokenGenerator = tokenGenerator
|
|
||||||
return repo
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *memRefreshTokenRepo) Create(userID, clientID string) (string, error) {
|
|
||||||
// Validate userID.
|
|
||||||
if userID == "" {
|
|
||||||
return "", ErrorInvalidUserID
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate clientID.
|
|
||||||
if clientID == "" {
|
|
||||||
return "", ErrorInvalidClientID
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate and store token.
|
|
||||||
tokenPayload, err := r.tokenGenerator.Generate()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenID := len(r.store) // Should only be used in single threaded tests.
|
|
||||||
|
|
||||||
// No limits on the number of tokens per user/client for this in-memory repo.
|
|
||||||
r.store[tokenID] = refreshToken{
|
|
||||||
payload: tokenPayload,
|
|
||||||
userID: userID,
|
|
||||||
clientID: clientID,
|
|
||||||
}
|
|
||||||
return buildToken(tokenID, tokenPayload), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *memRefreshTokenRepo) Verify(clientID, token string) (string, error) {
|
|
||||||
tokenID, tokenPayload, err := parseToken(token)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
record, ok := r.store[tokenID]
|
|
||||||
if !ok {
|
|
||||||
return "", ErrorInvalidToken
|
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.Equal(record.payload, tokenPayload) {
|
|
||||||
return "", ErrorInvalidToken
|
|
||||||
}
|
|
||||||
|
|
||||||
if record.clientID != clientID {
|
|
||||||
return "", ErrorInvalidClientID
|
|
||||||
}
|
|
||||||
|
|
||||||
return record.userID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *memRefreshTokenRepo) Revoke(userID, token string) error {
|
|
||||||
tokenID, tokenPayload, err := parseToken(token)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
record, ok := r.store[tokenID]
|
|
||||||
if !ok {
|
|
||||||
return ErrorInvalidToken
|
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.Equal(record.payload, tokenPayload) {
|
|
||||||
return ErrorInvalidToken
|
|
||||||
}
|
|
||||||
|
|
||||||
if record.userID != userID {
|
|
||||||
return ErrorInvalidUserID
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(r.store, tokenID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -17,7 +17,6 @@ import (
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/email"
|
"github.com/coreos/dex/email"
|
||||||
"github.com/coreos/dex/refresh"
|
|
||||||
"github.com/coreos/dex/repo"
|
"github.com/coreos/dex/repo"
|
||||||
sessionmanager "github.com/coreos/dex/session/manager"
|
sessionmanager "github.com/coreos/dex/session/manager"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
|
@ -139,7 +138,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
||||||
|
|
||||||
pwiRepo := user.NewPasswordInfoRepo()
|
pwiRepo := user.NewPasswordInfoRepo()
|
||||||
|
|
||||||
refTokRepo := refresh.NewRefreshTokenRepo()
|
refTokRepo := db.NewRefreshTokenRepo(db.NewMemDB())
|
||||||
|
|
||||||
txnFactory := repo.InMemTransactionFactory
|
txnFactory := repo.InMemTransactionFactory
|
||||||
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
|
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
|
||||||
|
|
|
@ -351,10 +351,7 @@ func TestServerCodeToken(t *testing.T) {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
|
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||||
|
@ -376,8 +373,10 @@ func TestServerCodeToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Have 'offline_access' in scope, should get non-empty refresh token.
|
// Have 'offline_access' in scope, should get non-empty refresh token.
|
||||||
{
|
{
|
||||||
|
// NOTE(ericchiang): This test assumes that the database ID of the first
|
||||||
|
// refresh token will be "1".
|
||||||
scope: []string{"openid", "offline_access"},
|
scope: []string{"openid", "offline_access"},
|
||||||
refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -475,11 +474,13 @@ func TestServerTokenFail(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
// control test case to make sure fixtures check out
|
// control test case to make sure fixtures check out
|
||||||
{
|
{
|
||||||
|
// NOTE(ericchiang): This test assumes that the database ID of the first
|
||||||
|
// refresh token will be "1".
|
||||||
signer: signerFixture,
|
signer: signerFixture,
|
||||||
argCC: ccFixture,
|
argCC: ccFixture,
|
||||||
argKey: keyFixture,
|
argKey: keyFixture,
|
||||||
scope: []string{"openid", "offline_access"},
|
scope: []string{"openid", "offline_access"},
|
||||||
refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
},
|
},
|
||||||
|
|
||||||
// no 'offline_access' in 'scope', should get empty refresh token
|
// no 'offline_access' in 'scope', should get empty refresh token
|
||||||
|
@ -549,10 +550,7 @@ func TestServerTokenFail(t *testing.T) {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
|
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
IssuerURL: issuerURL,
|
IssuerURL: issuerURL,
|
||||||
|
@ -600,6 +598,8 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
|
|
||||||
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
|
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
|
||||||
|
|
||||||
|
// NOTE(ericchiang): These tests assume that the database ID of the first
|
||||||
|
// refresh token will be "1".
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
token string
|
token string
|
||||||
clientID string // The client that associates with the token.
|
clientID string // The client that associates with the token.
|
||||||
|
@ -609,7 +609,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
// Everything is good.
|
// Everything is good.
|
||||||
{
|
{
|
||||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
"XXX",
|
"XXX",
|
||||||
credXXX,
|
credXXX,
|
||||||
signerFixture,
|
signerFixture,
|
||||||
|
@ -625,7 +625,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Invalid refresh token(invalid payload content).
|
// Invalid refresh token(invalid payload content).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
|
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
|
||||||
"XXX",
|
"XXX",
|
||||||
credXXX,
|
credXXX,
|
||||||
signerFixture,
|
signerFixture,
|
||||||
|
@ -633,7 +633,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Invalid refresh token(invalid ID content).
|
// Invalid refresh token(invalid ID content).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
"XXX",
|
"XXX",
|
||||||
credXXX,
|
credXXX,
|
||||||
signerFixture,
|
signerFixture,
|
||||||
|
@ -641,7 +641,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Invalid client(client is not associated with the token).
|
// Invalid client(client is not associated with the token).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
"XXX",
|
"XXX",
|
||||||
credYYY,
|
credYYY,
|
||||||
signerFixture,
|
signerFixture,
|
||||||
|
@ -649,7 +649,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Invalid client(no client ID).
|
// Invalid client(no client ID).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
"XXX",
|
"XXX",
|
||||||
oidc.ClientCredentials{ID: "", Secret: "aaa"},
|
oidc.ClientCredentials{ID: "", Secret: "aaa"},
|
||||||
signerFixture,
|
signerFixture,
|
||||||
|
@ -657,7 +657,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Invalid client(no such client).
|
// Invalid client(no such client).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
"XXX",
|
"XXX",
|
||||||
oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
|
oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
|
||||||
signerFixture,
|
signerFixture,
|
||||||
|
@ -665,7 +665,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Invalid client(no secrets).
|
// Invalid client(no secrets).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
"XXX",
|
"XXX",
|
||||||
oidc.ClientCredentials{ID: "XXX"},
|
oidc.ClientCredentials{ID: "XXX"},
|
||||||
signerFixture,
|
signerFixture,
|
||||||
|
@ -673,7 +673,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Invalid client(invalid secret).
|
// Invalid client(invalid secret).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
"XXX",
|
"XXX",
|
||||||
oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"},
|
oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"},
|
||||||
signerFixture,
|
signerFixture,
|
||||||
|
@ -681,7 +681,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Signing operation fails.
|
// Signing operation fails.
|
||||||
{
|
{
|
||||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
"XXX",
|
"XXX",
|
||||||
credXXX,
|
credXXX,
|
||||||
&StaticSigner{sig: nil, err: errors.New("fail")},
|
&StaticSigner{sig: nil, err: errors.New("fail")},
|
||||||
|
@ -704,10 +704,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
|
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
IssuerURL: issuerURL,
|
IssuerURL: issuerURL,
|
||||||
|
@ -764,10 +761,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
|
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
IssuerURL: issuerURL,
|
IssuerURL: issuerURL,
|
||||||
|
@ -788,7 +782,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
}
|
}
|
||||||
srv.UserRepo = userRepo
|
srv.UserRepo = userRepo
|
||||||
|
|
||||||
_, err = srv.RefreshToken(credXXX, fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))))
|
_, err = srv.RefreshToken(credXXX, fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))))
|
||||||
if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) {
|
if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) {
|
||||||
t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err)
|
t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue