diff --git a/db/refresh.go b/db/refresh.go index 66ad4ae2..f2dc193a 100644 --- a/db/refresh.go +++ b/db/refresh.go @@ -83,6 +83,13 @@ func NewRefreshTokenRepo(dbm *gorp.DbMap) refresh.RefreshTokenRepo { } } +func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenGenerator) refresh.RefreshTokenRepo { + return &refreshTokenRepo{ + dbMap: dbm, + tokenGenerator: gen, + } +} + func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) { if userID == "" { return "", refresh.ErrorInvalidUserID diff --git a/integration/oidc_test.go b/integration/oidc_test.go index e4fe1802..4eefcea0 100644 --- a/integration/oidc_test.go +++ b/integration/oidc_test.go @@ -145,10 +145,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) { } passwordInfoRepo := user.NewPasswordInfoRepo() - refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo() srv := &server.Server{ IssuerURL: issuerURL, diff --git a/refresh/refreshtest/repo.go b/refresh/refreshtest/repo.go index e93a852e..a149987f 100644 --- a/refresh/refreshtest/repo.go +++ b/refresh/refreshtest/repo.go @@ -3,16 +3,17 @@ package refreshtest import ( "fmt" + "github.com/coreos/dex/db" "github.com/coreos/dex/refresh" ) // NewTestRefreshTokenRepo returns a test repo whose tokens monotonically increase. // The tokens are in the form { refresh-1, refresh-2 ... refresh-n}. -func NewTestRefreshTokenRepo() (refresh.RefreshTokenRepo, error) { +func NewTestRefreshTokenRepo() refresh.RefreshTokenRepo { var tokenIdx int tokenGenerator := func() ([]byte, error) { tokenIdx++ return []byte(fmt.Sprintf("refresh-%d", tokenIdx)), nil } - return refresh.NewRefreshTokenRepoWithTokenGenerator(tokenGenerator), nil + return db.NewRefreshTokenRepoWithGenerator(db.NewMemDB(), tokenGenerator) } diff --git a/refresh/repo.go b/refresh/repo.go index 03e28e22..0c65c0e6 100644 --- a/refresh/repo.go +++ b/refresh/repo.go @@ -1,13 +1,8 @@ package refresh import ( - "bytes" "crypto/rand" - "encoding/base64" "errors" - "fmt" - "strconv" - "strings" ) const ( @@ -53,121 +48,3 @@ type RefreshTokenRepo interface { // Revoke deletes the refresh token if the token belongs to the given userID. Revoke(userID, token string) error } - -type refreshToken struct { - payload []byte - userID string - clientID string -} - -type memRefreshTokenRepo struct { - store map[int]refreshToken - tokenGenerator RefreshTokenGenerator -} - -// buildToken combines the token ID and token payload to create a new token. -func buildToken(tokenID int, tokenPayload []byte) string { - return fmt.Sprintf("%d%s%s", tokenID, TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload)) -} - -// parseToken parses a token and returns the token ID and token payload. -func parseToken(token string) (int, []byte, error) { - parts := strings.SplitN(token, TokenDelimer, 2) - if len(parts) != 2 { - return -1, nil, ErrorInvalidToken - } - id, err := strconv.Atoi(parts[0]) - if err != nil { - return -1, nil, ErrorInvalidToken - } - tokenPayload, err := base64.URLEncoding.DecodeString(parts[1]) - if err != nil { - return -1, nil, ErrorInvalidToken - } - return id, tokenPayload, nil -} - -// NewRefreshTokenRepo returns an in-memory RefreshTokenRepo useful for development. -func NewRefreshTokenRepo() RefreshTokenRepo { - return NewRefreshTokenRepoWithTokenGenerator(DefaultRefreshTokenGenerator) -} - -func NewRefreshTokenRepoWithTokenGenerator(tokenGenerator RefreshTokenGenerator) RefreshTokenRepo { - repo := &memRefreshTokenRepo{} - repo.store = make(map[int]refreshToken) - repo.tokenGenerator = tokenGenerator - return repo -} - -func (r *memRefreshTokenRepo) Create(userID, clientID string) (string, error) { - // Validate userID. - if userID == "" { - return "", ErrorInvalidUserID - } - - // Validate clientID. - if clientID == "" { - return "", ErrorInvalidClientID - } - - // Generate and store token. - tokenPayload, err := r.tokenGenerator.Generate() - if err != nil { - return "", err - } - - tokenID := len(r.store) // Should only be used in single threaded tests. - - // No limits on the number of tokens per user/client for this in-memory repo. - r.store[tokenID] = refreshToken{ - payload: tokenPayload, - userID: userID, - clientID: clientID, - } - return buildToken(tokenID, tokenPayload), nil -} - -func (r *memRefreshTokenRepo) Verify(clientID, token string) (string, error) { - tokenID, tokenPayload, err := parseToken(token) - if err != nil { - return "", err - } - - record, ok := r.store[tokenID] - if !ok { - return "", ErrorInvalidToken - } - - if !bytes.Equal(record.payload, tokenPayload) { - return "", ErrorInvalidToken - } - - if record.clientID != clientID { - return "", ErrorInvalidClientID - } - - return record.userID, nil -} - -func (r *memRefreshTokenRepo) Revoke(userID, token string) error { - tokenID, tokenPayload, err := parseToken(token) - if err != nil { - return err - } - - record, ok := r.store[tokenID] - if !ok { - return ErrorInvalidToken - } - - if !bytes.Equal(record.payload, tokenPayload) { - return ErrorInvalidToken - } - - if record.userID != userID { - return ErrorInvalidUserID - } - - delete(r.store, tokenID) - return nil -} diff --git a/server/config.go b/server/config.go index 278369a0..60ad64da 100644 --- a/server/config.go +++ b/server/config.go @@ -17,7 +17,6 @@ import ( "github.com/coreos/dex/connector" "github.com/coreos/dex/db" "github.com/coreos/dex/email" - "github.com/coreos/dex/refresh" "github.com/coreos/dex/repo" sessionmanager "github.com/coreos/dex/session/manager" "github.com/coreos/dex/user" @@ -139,7 +138,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error { pwiRepo := user.NewPasswordInfoRepo() - refTokRepo := refresh.NewRefreshTokenRepo() + refTokRepo := db.NewRefreshTokenRepo(db.NewMemDB()) txnFactory := repo.InMemTransactionFactory userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{}) diff --git a/server/server_test.go b/server/server_test.go index 7bedc333..8fef717a 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -351,10 +351,7 @@ func TestServerCodeToken(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo() srv := &Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, @@ -376,8 +373,10 @@ func TestServerCodeToken(t *testing.T) { }, // Have 'offline_access' in scope, should get non-empty refresh token. { + // NOTE(ericchiang): This test assumes that the database ID of the first + // refresh token will be "1". scope: []string{"openid", "offline_access"}, - refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), }, } @@ -475,11 +474,13 @@ func TestServerTokenFail(t *testing.T) { }{ // control test case to make sure fixtures check out { + // NOTE(ericchiang): This test assumes that the database ID of the first + // refresh token will be "1". signer: signerFixture, argCC: ccFixture, argKey: keyFixture, scope: []string{"openid", "offline_access"}, - refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), }, // no 'offline_access' in 'scope', should get empty refresh token @@ -549,10 +550,7 @@ func TestServerTokenFail(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo() srv := &Server{ IssuerURL: issuerURL, @@ -600,6 +598,8 @@ func TestServerRefreshToken(t *testing.T) { signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} + // NOTE(ericchiang): These tests assume that the database ID of the first + // refresh token will be "1". tests := []struct { token string clientID string // The client that associates with the token. @@ -609,7 +609,7 @@ func TestServerRefreshToken(t *testing.T) { }{ // Everything is good. { - fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), "XXX", credXXX, signerFixture, @@ -625,7 +625,7 @@ func TestServerRefreshToken(t *testing.T) { }, // Invalid refresh token(invalid payload content). { - fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))), + fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))), "XXX", credXXX, signerFixture, @@ -633,7 +633,7 @@ func TestServerRefreshToken(t *testing.T) { }, // Invalid refresh token(invalid ID content). { - fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), "XXX", credXXX, signerFixture, @@ -641,7 +641,7 @@ func TestServerRefreshToken(t *testing.T) { }, // Invalid client(client is not associated with the token). { - fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), "XXX", credYYY, signerFixture, @@ -649,7 +649,7 @@ func TestServerRefreshToken(t *testing.T) { }, // Invalid client(no client ID). { - fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), "XXX", oidc.ClientCredentials{ID: "", Secret: "aaa"}, signerFixture, @@ -657,7 +657,7 @@ func TestServerRefreshToken(t *testing.T) { }, // Invalid client(no such client). { - fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), "XXX", oidc.ClientCredentials{ID: "AAA", Secret: "aaa"}, signerFixture, @@ -665,7 +665,7 @@ func TestServerRefreshToken(t *testing.T) { }, // Invalid client(no secrets). { - fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), "XXX", oidc.ClientCredentials{ID: "XXX"}, signerFixture, @@ -673,7 +673,7 @@ func TestServerRefreshToken(t *testing.T) { }, // Invalid client(invalid secret). { - fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), "XXX", oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"}, signerFixture, @@ -681,7 +681,7 @@ func TestServerRefreshToken(t *testing.T) { }, // Signing operation fails. { - fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), "XXX", credXXX, &StaticSigner{sig: nil, err: errors.New("fail")}, @@ -704,10 +704,7 @@ func TestServerRefreshToken(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo() srv := &Server{ IssuerURL: issuerURL, @@ -764,10 +761,7 @@ func TestServerRefreshToken(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo() srv := &Server{ IssuerURL: issuerURL, @@ -788,7 +782,7 @@ func TestServerRefreshToken(t *testing.T) { } srv.UserRepo = userRepo - _, err = srv.RefreshToken(credXXX, fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1")))) + _, err = srv.RefreshToken(credXXX, fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1")))) if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) { t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err) }