forked from mystiq/dex
storage: add extra fields to refresh token and update method
This commit is contained in:
parent
c66cce8b40
commit
312ca7491e
7 changed files with 180 additions and 43 deletions
|
@ -208,10 +208,14 @@ func testClientCRUD(t *testing.T, s storage.Storage) {
|
||||||
func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
|
func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
|
||||||
id := storage.NewID()
|
id := storage.NewID()
|
||||||
refresh := storage.RefreshToken{
|
refresh := storage.RefreshToken{
|
||||||
RefreshToken: id,
|
ID: id,
|
||||||
ClientID: "client_id",
|
Token: "bar",
|
||||||
ConnectorID: "client_secret",
|
Nonce: "foo",
|
||||||
Scopes: []string{"openid", "email", "profile"},
|
ClientID: "client_id",
|
||||||
|
ConnectorID: "client_secret",
|
||||||
|
Scopes: []string{"openid", "email", "profile"},
|
||||||
|
CreatedAt: time.Now().UTC().Round(time.Millisecond),
|
||||||
|
LastUsed: time.Now().UTC().Round(time.Millisecond),
|
||||||
Claims: storage.Claims{
|
Claims: storage.Claims{
|
||||||
UserID: "1",
|
UserID: "1",
|
||||||
Username: "jane",
|
Username: "jane",
|
||||||
|
@ -238,6 +242,20 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
|
||||||
|
|
||||||
getAndCompare(id, refresh)
|
getAndCompare(id, refresh)
|
||||||
|
|
||||||
|
updatedAt := time.Now().UTC().Round(time.Millisecond)
|
||||||
|
|
||||||
|
updater := func(r storage.RefreshToken) (storage.RefreshToken, error) {
|
||||||
|
r.Token = "spam"
|
||||||
|
r.LastUsed = updatedAt
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
if err := s.UpdateRefreshToken(id, updater); err != nil {
|
||||||
|
t.Errorf("failed to udpate refresh token: %v", err)
|
||||||
|
}
|
||||||
|
refresh.Token = "spam"
|
||||||
|
refresh.LastUsed = updatedAt
|
||||||
|
getAndCompare(id, refresh)
|
||||||
|
|
||||||
if err := s.DeleteRefresh(id); err != nil {
|
if err := s.DeleteRefresh(id); err != nil {
|
||||||
t.Fatalf("failed to delete refresh request: %v", err)
|
t.Fatalf("failed to delete refresh request: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -153,23 +153,7 @@ func (cli *client) CreatePassword(p storage.Password) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cli *client) CreateRefresh(r storage.RefreshToken) error {
|
func (cli *client) CreateRefresh(r storage.RefreshToken) error {
|
||||||
refresh := RefreshToken{
|
return cli.post(resourceRefreshToken, cli.fromStorageRefreshToken(r))
|
||||||
TypeMeta: k8sapi.TypeMeta{
|
|
||||||
Kind: kindRefreshToken,
|
|
||||||
APIVersion: cli.apiVersion,
|
|
||||||
},
|
|
||||||
ObjectMeta: k8sapi.ObjectMeta{
|
|
||||||
Name: r.RefreshToken,
|
|
||||||
Namespace: cli.namespace,
|
|
||||||
},
|
|
||||||
ClientID: r.ClientID,
|
|
||||||
ConnectorID: r.ConnectorID,
|
|
||||||
Scopes: r.Scopes,
|
|
||||||
Nonce: r.Nonce,
|
|
||||||
Claims: fromStorageClaims(r.Claims),
|
|
||||||
ConnectorData: r.ConnectorData,
|
|
||||||
}
|
|
||||||
return cli.post(resourceRefreshToken, refresh)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) {
|
func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) {
|
||||||
|
@ -239,19 +223,16 @@ func (cli *client) GetKeys() (storage.Keys, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cli *client) GetRefresh(id string) (storage.RefreshToken, error) {
|
func (cli *client) GetRefresh(id string) (storage.RefreshToken, error) {
|
||||||
var r RefreshToken
|
r, err := cli.getRefreshToken(id)
|
||||||
if err := cli.get(resourceRefreshToken, id, &r); err != nil {
|
if err != nil {
|
||||||
return storage.RefreshToken{}, err
|
return storage.RefreshToken{}, err
|
||||||
}
|
}
|
||||||
return storage.RefreshToken{
|
return toStorageRefreshToken(r), nil
|
||||||
RefreshToken: r.ObjectMeta.Name,
|
}
|
||||||
ClientID: r.ClientID,
|
|
||||||
ConnectorID: r.ConnectorID,
|
func (cli *client) getRefreshToken(id string) (r RefreshToken, err error) {
|
||||||
Scopes: r.Scopes,
|
err = cli.get(resourceRefreshToken, id, &r)
|
||||||
Nonce: r.Nonce,
|
return
|
||||||
Claims: toStorageClaims(r.Claims),
|
|
||||||
ConnectorData: r.ConnectorData,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cli *client) ListClients() ([]storage.Client, error) {
|
func (cli *client) ListClients() ([]storage.Client, error) {
|
||||||
|
@ -311,6 +292,22 @@ func (cli *client) DeletePassword(email string) error {
|
||||||
return cli.delete(resourcePassword, p.ObjectMeta.Name)
|
return cli.delete(resourcePassword, p.ObjectMeta.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
|
||||||
|
r, err := cli.getRefreshToken(id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
updated, err := updater(toStorageRefreshToken(r))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
updated.ID = id
|
||||||
|
|
||||||
|
newToken := cli.fromStorageRefreshToken(updated)
|
||||||
|
newToken.ObjectMeta = r.ObjectMeta
|
||||||
|
return cli.put(resourceRefreshToken, r.ObjectMeta.Name, newToken)
|
||||||
|
}
|
||||||
|
|
||||||
func (cli *client) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
|
func (cli *client) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
|
||||||
c, err := cli.getClient(id)
|
c, err := cli.getClient(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -362,9 +362,14 @@ type RefreshToken struct {
|
||||||
k8sapi.TypeMeta `json:",inline"`
|
k8sapi.TypeMeta `json:",inline"`
|
||||||
k8sapi.ObjectMeta `json:"metadata,omitempty"`
|
k8sapi.ObjectMeta `json:"metadata,omitempty"`
|
||||||
|
|
||||||
|
CreatedAt time.Time
|
||||||
|
LastUsed time.Time
|
||||||
|
|
||||||
ClientID string `json:"clientID"`
|
ClientID string `json:"clientID"`
|
||||||
Scopes []string `json:"scopes,omitempty"`
|
Scopes []string `json:"scopes,omitempty"`
|
||||||
|
|
||||||
|
Token string `json:"token,omitempty"`
|
||||||
|
|
||||||
Nonce string `json:"nonce,omitempty"`
|
Nonce string `json:"nonce,omitempty"`
|
||||||
|
|
||||||
Claims Claims `json:"claims,omitempty"`
|
Claims Claims `json:"claims,omitempty"`
|
||||||
|
@ -379,6 +384,43 @@ type RefreshList struct {
|
||||||
RefreshTokens []RefreshToken `json:"items"`
|
RefreshTokens []RefreshToken `json:"items"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toStorageRefreshToken(r RefreshToken) storage.RefreshToken {
|
||||||
|
return storage.RefreshToken{
|
||||||
|
ID: r.ObjectMeta.Name,
|
||||||
|
Token: r.Token,
|
||||||
|
CreatedAt: r.CreatedAt,
|
||||||
|
LastUsed: r.LastUsed,
|
||||||
|
ClientID: r.ClientID,
|
||||||
|
ConnectorID: r.ConnectorID,
|
||||||
|
ConnectorData: r.ConnectorData,
|
||||||
|
Scopes: r.Scopes,
|
||||||
|
Nonce: r.Nonce,
|
||||||
|
Claims: toStorageClaims(r.Claims),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cli *client) fromStorageRefreshToken(r storage.RefreshToken) RefreshToken {
|
||||||
|
return RefreshToken{
|
||||||
|
TypeMeta: k8sapi.TypeMeta{
|
||||||
|
Kind: kindRefreshToken,
|
||||||
|
APIVersion: cli.apiVersion,
|
||||||
|
},
|
||||||
|
ObjectMeta: k8sapi.ObjectMeta{
|
||||||
|
Name: r.ID,
|
||||||
|
Namespace: cli.namespace,
|
||||||
|
},
|
||||||
|
Token: r.Token,
|
||||||
|
CreatedAt: r.CreatedAt,
|
||||||
|
LastUsed: r.LastUsed,
|
||||||
|
ClientID: r.ClientID,
|
||||||
|
ConnectorID: r.ConnectorID,
|
||||||
|
ConnectorData: r.ConnectorData,
|
||||||
|
Scopes: r.Scopes,
|
||||||
|
Nonce: r.Nonce,
|
||||||
|
Claims: fromStorageClaims(r.Claims),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Keys is a mirrored struct from storage with JSON struct tags and Kubernetes
|
// Keys is a mirrored struct from storage with JSON struct tags and Kubernetes
|
||||||
// type metadata.
|
// type metadata.
|
||||||
type Keys struct {
|
type Keys struct {
|
||||||
|
|
|
@ -98,10 +98,10 @@ func (s *memStorage) CreateAuthCode(c storage.AuthCode) (err error) {
|
||||||
|
|
||||||
func (s *memStorage) CreateRefresh(r storage.RefreshToken) (err error) {
|
func (s *memStorage) CreateRefresh(r storage.RefreshToken) (err error) {
|
||||||
s.tx(func() {
|
s.tx(func() {
|
||||||
if _, ok := s.refreshTokens[r.RefreshToken]; ok {
|
if _, ok := s.refreshTokens[r.ID]; ok {
|
||||||
err = storage.ErrAlreadyExists
|
err = storage.ErrAlreadyExists
|
||||||
} else {
|
} else {
|
||||||
s.refreshTokens[r.RefreshToken] = r
|
s.refreshTokens[r.ID] = r
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
@ -324,3 +324,17 @@ func (s *memStorage) UpdatePassword(email string, updater func(p storage.Passwor
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *memStorage) UpdateRefreshToken(id string, updater func(p storage.RefreshToken) (storage.RefreshToken, error)) (err error) {
|
||||||
|
s.tx(func() {
|
||||||
|
r, ok := s.refreshTokens[id]
|
||||||
|
if !ok {
|
||||||
|
err = storage.ErrNotFound
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r, err = updater(r); err == nil {
|
||||||
|
s.refreshTokens[id] = r
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -244,14 +244,16 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
|
||||||
id, client_id, scopes, nonce,
|
id, client_id, scopes, nonce,
|
||||||
claims_user_id, claims_username, claims_email, claims_email_verified,
|
claims_user_id, claims_username, claims_email, claims_email_verified,
|
||||||
claims_groups,
|
claims_groups,
|
||||||
connector_id, connector_data
|
connector_id, connector_data,
|
||||||
|
token, created_at, last_used
|
||||||
)
|
)
|
||||||
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11);
|
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14);
|
||||||
`,
|
`,
|
||||||
r.RefreshToken, r.ClientID, encoder(r.Scopes), r.Nonce,
|
r.ID, r.ClientID, encoder(r.Scopes), r.Nonce,
|
||||||
r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified,
|
r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified,
|
||||||
encoder(r.Claims.Groups),
|
encoder(r.Claims.Groups),
|
||||||
r.ConnectorID, r.ConnectorData,
|
r.ConnectorID, r.ConnectorData,
|
||||||
|
r.Token, r.CreatedAt, r.LastUsed,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("insert refresh_token: %v", err)
|
return fmt.Errorf("insert refresh_token: %v", err)
|
||||||
|
@ -259,13 +261,57 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
|
||||||
|
return c.ExecTx(func(tx *trans) error {
|
||||||
|
r, err := getRefresh(tx, id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if r, err = updater(r); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec(`
|
||||||
|
update refresh_token
|
||||||
|
set
|
||||||
|
client_id = $1,
|
||||||
|
scopes = $2,
|
||||||
|
nonce = $3,
|
||||||
|
claims_user_id = $4,
|
||||||
|
claims_username = $5,
|
||||||
|
claims_email = $6,
|
||||||
|
claims_email_verified = $7,
|
||||||
|
claims_groups = $8,
|
||||||
|
connector_id = $9,
|
||||||
|
connector_data = $10,
|
||||||
|
token = $11,
|
||||||
|
created_at = $12,
|
||||||
|
last_used = $13
|
||||||
|
`,
|
||||||
|
r.ClientID, encoder(r.Scopes), r.Nonce,
|
||||||
|
r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified,
|
||||||
|
encoder(r.Claims.Groups),
|
||||||
|
r.ConnectorID, r.ConnectorData,
|
||||||
|
r.Token, r.CreatedAt, r.LastUsed,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update refresh token: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) {
|
func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) {
|
||||||
return scanRefresh(c.QueryRow(`
|
return getRefresh(c, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRefresh(q querier, id string) (storage.RefreshToken, error) {
|
||||||
|
return scanRefresh(q.QueryRow(`
|
||||||
select
|
select
|
||||||
id, client_id, scopes, nonce,
|
id, client_id, scopes, nonce,
|
||||||
claims_user_id, claims_username, claims_email, claims_email_verified,
|
claims_user_id, claims_username, claims_email, claims_email_verified,
|
||||||
claims_groups,
|
claims_groups,
|
||||||
connector_id, connector_data
|
connector_id, connector_data,
|
||||||
|
token, created_at, last_used
|
||||||
from refresh_token where id = $1;
|
from refresh_token where id = $1;
|
||||||
`, id))
|
`, id))
|
||||||
}
|
}
|
||||||
|
@ -276,7 +322,8 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
|
||||||
id, client_id, scopes, nonce,
|
id, client_id, scopes, nonce,
|
||||||
claims_user_id, claims_username, claims_email, claims_email_verified,
|
claims_user_id, claims_username, claims_email, claims_email_verified,
|
||||||
claims_groups,
|
claims_groups,
|
||||||
connector_id, connector_data
|
connector_id, connector_data,
|
||||||
|
token, created_at, last_used
|
||||||
from refresh_token;
|
from refresh_token;
|
||||||
`)
|
`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -298,10 +345,11 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
|
||||||
|
|
||||||
func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
|
func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
|
||||||
err = s.Scan(
|
err = s.Scan(
|
||||||
&r.RefreshToken, &r.ClientID, decoder(&r.Scopes), &r.Nonce,
|
&r.ID, &r.ClientID, decoder(&r.Scopes), &r.Nonce,
|
||||||
&r.Claims.UserID, &r.Claims.Username, &r.Claims.Email, &r.Claims.EmailVerified,
|
&r.Claims.UserID, &r.Claims.Username, &r.Claims.Email, &r.Claims.EmailVerified,
|
||||||
decoder(&r.Claims.Groups),
|
decoder(&r.Claims.Groups),
|
||||||
&r.ConnectorID, &r.ConnectorData,
|
&r.ConnectorID, &r.ConnectorData,
|
||||||
|
&r.Token, &r.CreatedAt, &r.LastUsed,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
|
|
|
@ -155,4 +155,14 @@ var migrations = []migration{
|
||||||
);
|
);
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
stmt: `
|
||||||
|
alter table refresh_token
|
||||||
|
add column token text not null default '';
|
||||||
|
alter table refresh_token
|
||||||
|
add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC';
|
||||||
|
alter table refresh_token
|
||||||
|
add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC';
|
||||||
|
`,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -94,6 +94,7 @@ type Storage interface {
|
||||||
UpdateClient(id string, updater func(old Client) (Client, error)) error
|
UpdateClient(id string, updater func(old Client) (Client, error)) error
|
||||||
UpdateKeys(updater func(old Keys) (Keys, error)) error
|
UpdateKeys(updater func(old Keys) (Keys, error)) error
|
||||||
UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error
|
UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error
|
||||||
|
UpdateRefreshToken(id string, updater func(r RefreshToken) (RefreshToken, error)) error
|
||||||
UpdatePassword(email string, updater func(p Password) (Password, error)) error
|
UpdatePassword(email string, updater func(p Password) (Password, error)) error
|
||||||
|
|
||||||
// GarbageCollect deletes all expired AuthCodes and AuthRequests.
|
// GarbageCollect deletes all expired AuthCodes and AuthRequests.
|
||||||
|
@ -216,8 +217,15 @@ type AuthCode struct {
|
||||||
// RefreshToken is an OAuth2 refresh token which allows a client to request new
|
// RefreshToken is an OAuth2 refresh token which allows a client to request new
|
||||||
// tokens on the end user's behalf.
|
// tokens on the end user's behalf.
|
||||||
type RefreshToken struct {
|
type RefreshToken struct {
|
||||||
// The actual refresh token.
|
ID string
|
||||||
RefreshToken string
|
|
||||||
|
// A single token that's rotated every time the refresh token is refreshed.
|
||||||
|
//
|
||||||
|
// May be empty.
|
||||||
|
Token string
|
||||||
|
|
||||||
|
CreatedAt time.Time
|
||||||
|
LastUsed time.Time
|
||||||
|
|
||||||
// Client this refresh token is valid for.
|
// Client this refresh token is valid for.
|
||||||
ClientID string
|
ClientID string
|
||||||
|
|
Loading…
Reference in a new issue