From 312ca7491e16107f23d0235be709c3e11473716e Mon Sep 17 00:00:00 2001 From: Eric Chiang Date: Thu, 22 Dec 2016 15:56:09 -0800 Subject: [PATCH] storage: add extra fields to refresh token and update method --- storage/conformance/conformance.go | 26 +++++++++++-- storage/kubernetes/storage.go | 53 ++++++++++++------------- storage/kubernetes/types.go | 42 ++++++++++++++++++++ storage/memory/memory.go | 18 ++++++++- storage/sql/crud.go | 62 ++++++++++++++++++++++++++---- storage/sql/migrate.go | 10 +++++ storage/storage.go | 12 +++++- 7 files changed, 180 insertions(+), 43 deletions(-) diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 8cb911aa..0a6fe1c9 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -208,10 +208,14 @@ func testClientCRUD(t *testing.T, s storage.Storage) { func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { id := storage.NewID() refresh := storage.RefreshToken{ - RefreshToken: id, - ClientID: "client_id", - ConnectorID: "client_secret", - Scopes: []string{"openid", "email", "profile"}, + ID: id, + Token: "bar", + Nonce: "foo", + ClientID: "client_id", + ConnectorID: "client_secret", + Scopes: []string{"openid", "email", "profile"}, + CreatedAt: time.Now().UTC().Round(time.Millisecond), + LastUsed: time.Now().UTC().Round(time.Millisecond), Claims: storage.Claims{ UserID: "1", Username: "jane", @@ -238,6 +242,20 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { getAndCompare(id, refresh) + updatedAt := time.Now().UTC().Round(time.Millisecond) + + updater := func(r storage.RefreshToken) (storage.RefreshToken, error) { + r.Token = "spam" + r.LastUsed = updatedAt + return r, nil + } + if err := s.UpdateRefreshToken(id, updater); err != nil { + t.Errorf("failed to udpate refresh token: %v", err) + } + refresh.Token = "spam" + refresh.LastUsed = updatedAt + getAndCompare(id, refresh) + if err := s.DeleteRefresh(id); err != nil { t.Fatalf("failed to delete refresh request: %v", err) } diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index e744ab2d..102a7494 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -153,23 +153,7 @@ func (cli *client) CreatePassword(p storage.Password) error { } func (cli *client) CreateRefresh(r storage.RefreshToken) error { - refresh := RefreshToken{ - TypeMeta: k8sapi.TypeMeta{ - Kind: kindRefreshToken, - APIVersion: cli.apiVersion, - }, - ObjectMeta: k8sapi.ObjectMeta{ - Name: r.RefreshToken, - Namespace: cli.namespace, - }, - ClientID: r.ClientID, - ConnectorID: r.ConnectorID, - Scopes: r.Scopes, - Nonce: r.Nonce, - Claims: fromStorageClaims(r.Claims), - ConnectorData: r.ConnectorData, - } - return cli.post(resourceRefreshToken, refresh) + return cli.post(resourceRefreshToken, cli.fromStorageRefreshToken(r)) } func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) { @@ -239,19 +223,16 @@ func (cli *client) GetKeys() (storage.Keys, error) { } func (cli *client) GetRefresh(id string) (storage.RefreshToken, error) { - var r RefreshToken - if err := cli.get(resourceRefreshToken, id, &r); err != nil { + r, err := cli.getRefreshToken(id) + if err != nil { return storage.RefreshToken{}, err } - return storage.RefreshToken{ - RefreshToken: r.ObjectMeta.Name, - ClientID: r.ClientID, - ConnectorID: r.ConnectorID, - Scopes: r.Scopes, - Nonce: r.Nonce, - Claims: toStorageClaims(r.Claims), - ConnectorData: r.ConnectorData, - }, nil + return toStorageRefreshToken(r), nil +} + +func (cli *client) getRefreshToken(id string) (r RefreshToken, err error) { + err = cli.get(resourceRefreshToken, id, &r) + return } func (cli *client) ListClients() ([]storage.Client, error) { @@ -311,6 +292,22 @@ func (cli *client) DeletePassword(email string) error { return cli.delete(resourcePassword, p.ObjectMeta.Name) } +func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { + r, err := cli.getRefreshToken(id) + if err != nil { + return err + } + updated, err := updater(toStorageRefreshToken(r)) + if err != nil { + return err + } + updated.ID = id + + newToken := cli.fromStorageRefreshToken(updated) + newToken.ObjectMeta = r.ObjectMeta + return cli.put(resourceRefreshToken, r.ObjectMeta.Name, newToken) +} + func (cli *client) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { c, err := cli.getClient(id) if err != nil { diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index 9009c800..660f86d8 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -362,9 +362,14 @@ type RefreshToken struct { k8sapi.TypeMeta `json:",inline"` k8sapi.ObjectMeta `json:"metadata,omitempty"` + CreatedAt time.Time + LastUsed time.Time + ClientID string `json:"clientID"` Scopes []string `json:"scopes,omitempty"` + Token string `json:"token,omitempty"` + Nonce string `json:"nonce,omitempty"` Claims Claims `json:"claims,omitempty"` @@ -379,6 +384,43 @@ type RefreshList struct { RefreshTokens []RefreshToken `json:"items"` } +func toStorageRefreshToken(r RefreshToken) storage.RefreshToken { + return storage.RefreshToken{ + ID: r.ObjectMeta.Name, + Token: r.Token, + CreatedAt: r.CreatedAt, + LastUsed: r.LastUsed, + ClientID: r.ClientID, + ConnectorID: r.ConnectorID, + ConnectorData: r.ConnectorData, + Scopes: r.Scopes, + Nonce: r.Nonce, + Claims: toStorageClaims(r.Claims), + } +} + +func (cli *client) fromStorageRefreshToken(r storage.RefreshToken) RefreshToken { + return RefreshToken{ + TypeMeta: k8sapi.TypeMeta{ + Kind: kindRefreshToken, + APIVersion: cli.apiVersion, + }, + ObjectMeta: k8sapi.ObjectMeta{ + Name: r.ID, + Namespace: cli.namespace, + }, + Token: r.Token, + CreatedAt: r.CreatedAt, + LastUsed: r.LastUsed, + ClientID: r.ClientID, + ConnectorID: r.ConnectorID, + ConnectorData: r.ConnectorData, + Scopes: r.Scopes, + Nonce: r.Nonce, + Claims: fromStorageClaims(r.Claims), + } +} + // Keys is a mirrored struct from storage with JSON struct tags and Kubernetes // type metadata. type Keys struct { diff --git a/storage/memory/memory.go b/storage/memory/memory.go index 6d609717..8bfbdce2 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -98,10 +98,10 @@ func (s *memStorage) CreateAuthCode(c storage.AuthCode) (err error) { func (s *memStorage) CreateRefresh(r storage.RefreshToken) (err error) { s.tx(func() { - if _, ok := s.refreshTokens[r.RefreshToken]; ok { + if _, ok := s.refreshTokens[r.ID]; ok { err = storage.ErrAlreadyExists } else { - s.refreshTokens[r.RefreshToken] = r + s.refreshTokens[r.ID] = r } }) return @@ -324,3 +324,17 @@ func (s *memStorage) UpdatePassword(email string, updater func(p storage.Passwor }) return } + +func (s *memStorage) UpdateRefreshToken(id string, updater func(p storage.RefreshToken) (storage.RefreshToken, error)) (err error) { + s.tx(func() { + r, ok := s.refreshTokens[id] + if !ok { + err = storage.ErrNotFound + return + } + if r, err = updater(r); err == nil { + s.refreshTokens[id] = r + } + }) + return +} diff --git a/storage/sql/crud.go b/storage/sql/crud.go index e3270363..494f1c20 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -244,14 +244,16 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error { id, client_id, scopes, nonce, claims_user_id, claims_username, claims_email, claims_email_verified, claims_groups, - connector_id, connector_data + connector_id, connector_data, + token, created_at, last_used ) - values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11); + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14); `, - r.RefreshToken, r.ClientID, encoder(r.Scopes), r.Nonce, + r.ID, r.ClientID, encoder(r.Scopes), r.Nonce, r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified, encoder(r.Claims.Groups), r.ConnectorID, r.ConnectorData, + r.Token, r.CreatedAt, r.LastUsed, ) if err != nil { return fmt.Errorf("insert refresh_token: %v", err) @@ -259,13 +261,57 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error { return nil } +func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { + return c.ExecTx(func(tx *trans) error { + r, err := getRefresh(tx, id) + if err != nil { + return err + } + if r, err = updater(r); err != nil { + return err + } + _, err = tx.Exec(` + update refresh_token + set + client_id = $1, + scopes = $2, + nonce = $3, + claims_user_id = $4, + claims_username = $5, + claims_email = $6, + claims_email_verified = $7, + claims_groups = $8, + connector_id = $9, + connector_data = $10, + token = $11, + created_at = $12, + last_used = $13 + `, + r.ClientID, encoder(r.Scopes), r.Nonce, + r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified, + encoder(r.Claims.Groups), + r.ConnectorID, r.ConnectorData, + r.Token, r.CreatedAt, r.LastUsed, + ) + if err != nil { + return fmt.Errorf("update refresh token: %v", err) + } + return nil + }) +} + func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) { - return scanRefresh(c.QueryRow(` + return getRefresh(c, id) +} + +func getRefresh(q querier, id string) (storage.RefreshToken, error) { + return scanRefresh(q.QueryRow(` select id, client_id, scopes, nonce, claims_user_id, claims_username, claims_email, claims_email_verified, claims_groups, - connector_id, connector_data + connector_id, connector_data, + token, created_at, last_used from refresh_token where id = $1; `, id)) } @@ -276,7 +322,8 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { id, client_id, scopes, nonce, claims_user_id, claims_username, claims_email, claims_email_verified, claims_groups, - connector_id, connector_data + connector_id, connector_data, + token, created_at, last_used from refresh_token; `) if err != nil { @@ -298,10 +345,11 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { func scanRefresh(s scanner) (r storage.RefreshToken, err error) { err = s.Scan( - &r.RefreshToken, &r.ClientID, decoder(&r.Scopes), &r.Nonce, + &r.ID, &r.ClientID, decoder(&r.Scopes), &r.Nonce, &r.Claims.UserID, &r.Claims.Username, &r.Claims.Email, &r.Claims.EmailVerified, decoder(&r.Claims.Groups), &r.ConnectorID, &r.ConnectorData, + &r.Token, &r.CreatedAt, &r.LastUsed, ) if err != nil { if err == sql.ErrNoRows { diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 3bb410aa..b2b66d39 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -155,4 +155,14 @@ var migrations = []migration{ ); `, }, + { + stmt: ` + alter table refresh_token + add column token text not null default ''; + alter table refresh_token + add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC'; + alter table refresh_token + add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC'; + `, + }, } diff --git a/storage/storage.go b/storage/storage.go index 22a9ea50..47f5dcc6 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -94,6 +94,7 @@ type Storage interface { UpdateClient(id string, updater func(old Client) (Client, error)) error UpdateKeys(updater func(old Keys) (Keys, error)) error UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error + UpdateRefreshToken(id string, updater func(r RefreshToken) (RefreshToken, error)) error UpdatePassword(email string, updater func(p Password) (Password, error)) error // GarbageCollect deletes all expired AuthCodes and AuthRequests. @@ -216,8 +217,15 @@ type AuthCode struct { // RefreshToken is an OAuth2 refresh token which allows a client to request new // tokens on the end user's behalf. type RefreshToken struct { - // The actual refresh token. - RefreshToken string + ID string + + // A single token that's rotated every time the refresh token is refreshed. + // + // May be empty. + Token string + + CreatedAt time.Time + LastUsed time.Time // Client this refresh token is valid for. ClientID string