Merge pull request #793 from rithujohn191/token-revocation

storage: Add OfflineSession object to backend storage.
This commit is contained in:
rithu leena john 2017-02-09 19:46:00 -08:00 committed by GitHub
commit 53e383670a
10 changed files with 580 additions and 32 deletions

View file

@ -682,6 +682,75 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
// deleteToken determines if we need to delete the newly created refresh token
// due to a failure in updating/creating the OfflineSession object for the
// corresponding user.
var deleteToken bool
defer func() {
if deleteToken {
// Delete newly created refresh token from storage.
if err := s.storage.DeleteRefresh(refresh.ID); err != nil {
s.logger.Errorf("failed to delete refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
}
}()
tokenRef := storage.RefreshTokenRef{
ID: refresh.ID,
ClientID: refresh.ClientID,
CreatedAt: refresh.CreatedAt,
LastUsed: refresh.LastUsed,
}
// Try to retrieve an existing OfflineSession object for the corresponding user.
if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
}
offlineSessions := storage.OfflineSessions{
UserID: refresh.Claims.UserID,
ConnID: refresh.ConnectorID,
Refresh: make(map[string]*storage.RefreshTokenRef),
}
offlineSessions.Refresh[tokenRef.ClientID] = &tokenRef
// Create a new OfflineSession object for the user and add a reference object for
// the newly recieved refreshtoken.
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
}
} else {
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
// Delete old refresh token from storage.
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil {
s.logger.Errorf("failed to delete refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
}
}
// Update existing OfflineSession obj with new RefreshTokenRef.
if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
old.Refresh[tokenRef.ClientID] = &tokenRef
return old, nil
}); err != nil {
s.logger.Errorf("failed to update offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
}
}
}
s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry)
}
@ -815,6 +884,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return
}
lastUsed := s.now()
updater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
if old.Token != refresh.Token {
return old, errors.New("refresh token claimed twice")
@ -828,14 +898,31 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
old.Claims.EmailVerified = ident.EmailVerified
old.Claims.Groups = ident.Groups
old.ConnectorData = ident.ConnectorData
old.LastUsed = s.now()
old.LastUsed = lastUsed
return old, nil
}
// Update LastUsed time stamp in refresh token reference object
// in offline session for the user.
if err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
if old.Refresh[refresh.ClientID].ID != refresh.ID {
return old, errors.New("refresh token invalid")
}
old.Refresh[refresh.ClientID].LastUsed = lastUsed
return old, nil
}); err != nil {
s.logger.Errorf("failed to update offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
// Update refresh token in the storage.
if err := s.storage.UpdateRefreshToken(refresh.ID, updater); err != nil {
s.logger.Errorf("failed to update refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry)
}

View file

@ -971,3 +971,108 @@ func TestKeyCacher(t *testing.T) {
}
}
}
type oauth2Client struct {
config *oauth2.Config
token *oauth2.Token
server *httptest.Server
}
// TestRefreshTokenFlow tests the refresh token code flow for oauth2. The test verifies
// that only valid refresh tokens can be used to refresh an expired token.
func TestRefreshTokenFlow(t *testing.T) {
state := "state"
now := func() time.Time { return time.Now() }
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, s := newTestServer(ctx, t, func(c *Config) {
c.Now = now
})
defer httpServer.Close()
p, err := oidc.NewProvider(ctx, httpServer.URL)
if err != nil {
t.Fatalf("failed to get provider: %v", err)
}
var oauth2Client oauth2Client
oauth2Client.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/callback" {
// User is visiting app first time. Redirect to dex.
http.Redirect(w, r, oauth2Client.config.AuthCodeURL(state), http.StatusSeeOther)
return
}
// User is at '/callback' so they were just redirected _from_ dex.
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
if desc := q.Get("error_description"); desc != "" {
t.Errorf("got error from server %s: %s", errType, desc)
} else {
t.Errorf("got error from server %s", errType)
}
w.WriteHeader(http.StatusInternalServerError)
return
}
// Grab code, exchange for token.
if code := q.Get("code"); code != "" {
token, err := oauth2Client.config.Exchange(ctx, code)
if err != nil {
t.Errorf("failed to exchange code for token: %v", err)
return
}
oauth2Client.token = token
}
// Ensure state matches.
if gotState := q.Get("state"); gotState != state {
t.Errorf("state did not match, want=%q got=%q", state, gotState)
}
w.WriteHeader(http.StatusOK)
return
}))
defer oauth2Client.server.Close()
// Register the client above with dex.
redirectURL := oauth2Client.server.URL + "/callback"
client := storage.Client{
ID: "testclient",
Secret: "testclientsecret",
RedirectURIs: []string{redirectURL},
}
if err := s.storage.CreateClient(client); err != nil {
t.Fatalf("failed to create client: %v", err)
}
oauth2Client.config = &oauth2.Config{
ClientID: client.ID,
ClientSecret: client.Secret,
Endpoint: p.Endpoint(),
Scopes: []string{oidc.ScopeOpenID, "email", "offline_access"},
RedirectURL: redirectURL,
}
if _, err = http.Get(oauth2Client.server.URL + "/login"); err != nil {
t.Fatalf("get failed: %v", err)
}
tok := &oauth2.Token{
RefreshToken: oauth2Client.token.RefreshToken,
Expiry: time.Now().Add(-time.Hour),
}
// Login in again to recieve a new token.
if _, err = http.Get(oauth2Client.server.URL + "/login"); err != nil {
t.Fatalf("get failed: %v", err)
}
// try to refresh expired token with old refresh token.
newToken, err := oauth2Client.config.TokenSource(ctx, tok).Token()
if newToken != nil {
t.Errorf("Token refreshed with invalid refresh token.")
}
}

View file

@ -47,6 +47,7 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) {
{"RefreshTokenCRUD", testRefreshTokenCRUD},
{"PasswordCRUD", testPasswordCRUD},
{"KeysCRUD", testKeysCRUD},
{"OfflineSessionCRUD", testOfflineSessionCRUD},
{"GarbageCollection", testGC},
{"TimezoneSupport", testTimezones},
})
@ -340,6 +341,60 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
}
func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
session := storage.OfflineSessions{
UserID: "User",
ConnID: "Conn",
Refresh: make(map[string]*storage.RefreshTokenRef),
}
// Creating an OfflineSession with an empty Refresh list to ensure that
// an empty map is translated as expected by the storage.
if err := s.CreateOfflineSessions(session); err != nil {
t.Fatalf("create offline session: %v", err)
}
getAndCompare := func(userID string, connID string, want storage.OfflineSessions) {
gr, err := s.GetOfflineSessions(userID, connID)
if err != nil {
t.Errorf("get offline session: %v", err)
return
}
if diff := pretty.Compare(want, gr); diff != "" {
t.Errorf("offline session retrieved from storage did not match: %s", diff)
}
}
getAndCompare("User", "Conn", session)
id := storage.NewID()
tokenRef := storage.RefreshTokenRef{
ID: id,
ClientID: "client_id",
CreatedAt: time.Now().UTC().Round(time.Millisecond),
LastUsed: time.Now().UTC().Round(time.Millisecond),
}
session.Refresh[tokenRef.ClientID] = &tokenRef
if err := s.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
old.Refresh[tokenRef.ClientID] = &tokenRef
return old, nil
}); err != nil {
t.Fatalf("failed to update offline session: %v", err)
}
getAndCompare("User", "Conn", session)
if err := s.DeleteOfflineSessions(session.UserID, session.ConnID); err != nil {
t.Fatalf("failed to delete offline session: %v", err)
}
if _, err := s.GetOfflineSessions(session.UserID, session.ConnID); err != storage.ErrNotFound {
t.Errorf("after deleting offline session expected storage.ErrNotFound, got %v", err)
}
}
func testKeysCRUD(t *testing.T, s storage.Storage) {
updateAndCompare := func(k storage.Keys) {
err := s.UpdateKeys(func(oldKeys storage.Keys) (storage.Keys, error) {

View file

@ -58,6 +58,12 @@ func (c *client) idToName(s string) string {
return idToName(s, c.hash)
}
// offlineTokenName maps two arbitrary IDs, to a single Kubernetes object name.
// This is used when more than one field is used to uniquely identify the object.
func (c *client) offlineTokenName(userID string, connID string) string {
return offlineTokenName(userID, connID, c.hash)
}
// Kubernetes names must match the regexp '[a-z0-9]([-a-z0-9]*[a-z0-9])?'.
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")
@ -65,6 +71,12 @@ func idToName(s string, h func() hash.Hash) string {
return strings.TrimRight(encoding.EncodeToString(h().Sum([]byte(s))), "=")
}
func offlineTokenName(userID string, connID string, h func() hash.Hash) string {
h().Write([]byte(userID))
h().Write([]byte(connID))
return strings.TrimRight(encoding.EncodeToString(h().Sum(nil)), "=")
}
func (c *client) urlFor(apiVersion, namespace, resource, name string) string {
basePath := "apis/"
if apiVersion == "v1" {

View file

@ -15,21 +15,23 @@ import (
)
const (
kindAuthCode = "AuthCode"
kindAuthRequest = "AuthRequest"
kindClient = "OAuth2Client"
kindRefreshToken = "RefreshToken"
kindKeys = "SigningKey"
kindPassword = "Password"
kindAuthCode = "AuthCode"
kindAuthRequest = "AuthRequest"
kindClient = "OAuth2Client"
kindRefreshToken = "RefreshToken"
kindKeys = "SigningKey"
kindPassword = "Password"
kindOfflineSessions = "OfflineSessions"
)
const (
resourceAuthCode = "authcodes"
resourceAuthRequest = "authrequests"
resourceClient = "oauth2clients"
resourceRefreshToken = "refreshtokens"
resourceKeys = "signingkeies" // Kubernetes attempts to pluralize.
resourcePassword = "passwords"
resourceAuthCode = "authcodes"
resourceAuthRequest = "authrequests"
resourceClient = "oauth2clients"
resourceRefreshToken = "refreshtokens"
resourceKeys = "signingkeies" // Kubernetes attempts to pluralize.
resourcePassword = "passwords"
resourceOfflineSessions = "offlinesessions"
)
// Config values for the Kubernetes storage type.
@ -156,6 +158,10 @@ func (cli *client) CreateRefresh(r storage.RefreshToken) error {
return cli.post(resourceRefreshToken, cli.fromStorageRefreshToken(r))
}
func (cli *client) CreateOfflineSessions(o storage.OfflineSessions) error {
return cli.post(resourceOfflineSessions, cli.fromStorageOfflineSessions(o))
}
func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) {
var req AuthRequest
if err := cli.get(resourceAuthRequest, id, &req); err != nil {
@ -235,6 +241,25 @@ func (cli *client) getRefreshToken(id string) (r RefreshToken, err error) {
return
}
func (cli *client) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
o, err := cli.getOfflineSessions(userID, connID)
if err != nil {
return storage.OfflineSessions{}, err
}
return toStorageOfflineSessions(o), nil
}
func (cli *client) getOfflineSessions(userID string, connID string) (o OfflineSessions, err error) {
name := cli.offlineTokenName(userID, connID)
if err = cli.get(resourceOfflineSessions, name, &o); err != nil {
return OfflineSessions{}, err
}
if userID != o.UserID || connID != o.ConnID {
return OfflineSessions{}, fmt.Errorf("get offline session: wrong session retrieved")
}
return o, nil
}
func (cli *client) ListClients() ([]storage.Client, error) {
return nil, errors.New("not implemented")
}
@ -292,6 +317,15 @@ func (cli *client) DeletePassword(email string) error {
return cli.delete(resourcePassword, p.ObjectMeta.Name)
}
func (cli *client) DeleteOfflineSessions(userID string, connID string) error {
// Check for hash collition.
o, err := cli.getOfflineSessions(userID, connID)
if err != nil {
return err
}
return cli.delete(resourceOfflineSessions, o.ObjectMeta.Name)
}
func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
r, err := cli.getRefreshToken(id)
if err != nil {
@ -342,6 +376,22 @@ func (cli *client) UpdatePassword(email string, updater func(old storage.Passwor
return cli.put(resourcePassword, p.ObjectMeta.Name, newPassword)
}
func (cli *client) UpdateOfflineSessions(userID string, connID string, updater func(old storage.OfflineSessions) (storage.OfflineSessions, error)) error {
o, err := cli.getOfflineSessions(userID, connID)
if err != nil {
return err
}
updated, err := updater(toStorageOfflineSessions(o))
if err != nil {
return err
}
newOfflineSessions := cli.fromStorageOfflineSessions(updated)
newOfflineSessions.ObjectMeta = o.ObjectMeta
return cli.put(resourceOfflineSessions, o.ObjectMeta.Name, newOfflineSessions)
}
func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
firstUpdate := false
var keys Keys

View file

@ -66,6 +66,14 @@ var thirdPartyResources = []k8sapi.ThirdPartyResource{
Description: "Passwords managed by the OIDC server.",
Versions: []k8sapi.APIVersion{{Name: "v1"}},
},
{
ObjectMeta: k8sapi.ObjectMeta{
Name: "offline-sessions.oidc.coreos.com",
},
TypeMeta: tprMeta,
Description: "User sessions with an active refresh token.",
Versions: []k8sapi.APIVersion{{Name: "v1"}},
},
}
// There will only ever be a single keys resource. Maintain this by setting a
@ -465,3 +473,38 @@ func toStorageKeys(keys Keys) storage.Keys {
NextRotation: keys.NextRotation,
}
}
// OfflineSessions is a mirrored struct from storage with JSON struct tags and Kubernetes
// type metadata.
type OfflineSessions struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"`
UserID string `json:"userID,omitempty"`
ConnID string `json:"connID,omitempty"`
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
}
func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
return OfflineSessions{
TypeMeta: k8sapi.TypeMeta{
Kind: kindOfflineSessions,
APIVersion: cli.apiVersion,
},
ObjectMeta: k8sapi.ObjectMeta{
Name: cli.offlineTokenName(o.UserID, o.ConnID),
Namespace: cli.namespace,
},
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
}
}
func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
return storage.OfflineSessions{
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
}
}

View file

@ -13,12 +13,13 @@ import (
// New returns an in memory storage.
func New(logger logrus.FieldLogger) storage.Storage {
return &memStorage{
clients: make(map[string]storage.Client),
authCodes: make(map[string]storage.AuthCode),
refreshTokens: make(map[string]storage.RefreshToken),
authReqs: make(map[string]storage.AuthRequest),
passwords: make(map[string]storage.Password),
logger: logger,
clients: make(map[string]storage.Client),
authCodes: make(map[string]storage.AuthCode),
refreshTokens: make(map[string]storage.RefreshToken),
authReqs: make(map[string]storage.AuthRequest),
passwords: make(map[string]storage.Password),
offlineSessions: make(map[offlineSessionID]storage.OfflineSessions),
logger: logger,
}
}
@ -37,17 +38,23 @@ func (c *Config) Open(logger logrus.FieldLogger) (storage.Storage, error) {
type memStorage struct {
mu sync.Mutex
clients map[string]storage.Client
authCodes map[string]storage.AuthCode
refreshTokens map[string]storage.RefreshToken
authReqs map[string]storage.AuthRequest
passwords map[string]storage.Password
clients map[string]storage.Client
authCodes map[string]storage.AuthCode
refreshTokens map[string]storage.RefreshToken
authReqs map[string]storage.AuthRequest
passwords map[string]storage.Password
offlineSessions map[offlineSessionID]storage.OfflineSessions
keys storage.Keys
logger logrus.FieldLogger
}
type offlineSessionID struct {
userID string
connID string
}
func (s *memStorage) tx(f func()) {
s.mu.Lock()
defer s.mu.Unlock()
@ -130,6 +137,32 @@ func (s *memStorage) CreatePassword(p storage.Password) (err error) {
return
}
func (s *memStorage) CreateOfflineSessions(o storage.OfflineSessions) (err error) {
id := offlineSessionID{
userID: o.UserID,
connID: o.ConnID,
}
s.tx(func() {
if _, ok := s.offlineSessions[id]; ok {
err = storage.ErrAlreadyExists
} else {
s.offlineSessions[id] = o
}
})
return
}
func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) {
s.tx(func() {
var ok bool
if c, ok = s.authCodes[id]; !ok {
err = storage.ErrNotFound
return
}
})
return
}
func (s *memStorage) GetPassword(email string) (p storage.Password, err error) {
email = strings.ToLower(email)
s.tx(func() {
@ -156,10 +189,10 @@ func (s *memStorage) GetKeys() (keys storage.Keys, err error) {
return
}
func (s *memStorage) GetRefresh(token string) (tok storage.RefreshToken, err error) {
func (s *memStorage) GetRefresh(id string) (tok storage.RefreshToken, err error) {
s.tx(func() {
var ok bool
if tok, ok = s.refreshTokens[token]; !ok {
if tok, ok = s.refreshTokens[id]; !ok {
err = storage.ErrNotFound
return
}
@ -178,6 +211,21 @@ func (s *memStorage) GetAuthRequest(id string) (req storage.AuthRequest, err err
return
}
func (s *memStorage) GetOfflineSessions(userID string, connID string) (o storage.OfflineSessions, err error) {
id := offlineSessionID{
userID: userID,
connID: connID,
}
s.tx(func() {
var ok bool
if o, ok = s.offlineSessions[id]; !ok {
err = storage.ErrNotFound
return
}
})
return
}
func (s *memStorage) ListClients() (clients []storage.Client, err error) {
s.tx(func() {
for _, client := range s.clients {
@ -228,13 +276,13 @@ func (s *memStorage) DeleteClient(id string) (err error) {
return
}
func (s *memStorage) DeleteRefresh(token string) (err error) {
func (s *memStorage) DeleteRefresh(id string) (err error) {
s.tx(func() {
if _, ok := s.refreshTokens[token]; !ok {
if _, ok := s.refreshTokens[id]; !ok {
err = storage.ErrNotFound
return
}
delete(s.refreshTokens, token)
delete(s.refreshTokens, id)
})
return
}
@ -261,13 +309,17 @@ func (s *memStorage) DeleteAuthRequest(id string) (err error) {
return
}
func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) {
func (s *memStorage) DeleteOfflineSessions(userID string, connID string) (err error) {
id := offlineSessionID{
userID: userID,
connID: connID,
}
s.tx(func() {
var ok bool
if c, ok = s.authCodes[id]; !ok {
if _, ok := s.offlineSessions[id]; !ok {
err = storage.ErrNotFound
return
}
delete(s.offlineSessions, id)
})
return
}
@ -338,3 +390,21 @@ func (s *memStorage) UpdateRefreshToken(id string, updater func(p storage.Refres
})
return
}
func (s *memStorage) UpdateOfflineSessions(userID string, connID string, updater func(o storage.OfflineSessions) (storage.OfflineSessions, error)) (err error) {
id := offlineSessionID{
userID: userID,
connID: connID,
}
s.tx(func() {
r, ok := s.offlineSessions[id]
if !ok {
err = storage.ErrNotFound
return
}
if r, err = updater(r); err == nil {
s.offlineSessions[id] = r
}
})
return
}

View file

@ -624,6 +624,75 @@ func scanPassword(s scanner) (p storage.Password, err error) {
return p, nil
}
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
_, err := c.Exec(`
insert into offline_session (
user_id, conn_id, refresh
)
values (
$1, $2, $3
);
`,
s.UserID, s.ConnID, encoder(s.Refresh),
)
if err != nil {
return fmt.Errorf("insert offline session: %v", err)
}
return nil
}
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
return c.ExecTx(func(tx *trans) error {
s, err := getOfflineSessions(tx, userID, connID)
if err != nil {
return err
}
newSession, err := updater(s)
if err != nil {
return err
}
_, err = tx.Exec(`
update offline_session
set
refresh = $1
where user_id = $2 AND conn_id = $3;
`,
encoder(newSession.Refresh), s.UserID, s.ConnID,
)
if err != nil {
return fmt.Errorf("update offline session: %v", err)
}
return nil
})
}
func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
return getOfflineSessions(c, userID, connID)
}
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
return scanOfflineSessions(q.QueryRow(`
select
user_id, conn_id, refresh
from offline_session
where user_id = $1 AND conn_id = $2;
`, userID, connID))
}
func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
err = s.Scan(
&o.UserID, &o.ConnID, decoder(&o.Refresh),
)
if err != nil {
if err == sql.ErrNoRows {
return o, storage.ErrNotFound
}
return o, fmt.Errorf("select offline session: %v", err)
}
return o, nil
}
func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", "id", id) }
func (c *conn) DeleteAuthCode(id string) error { return c.delete("auth_code", "id", id) }
func (c *conn) DeleteClient(id string) error { return c.delete("client", "id", id) }
@ -632,6 +701,24 @@ func (c *conn) DeletePassword(email string) error {
return c.delete("password", "email", strings.ToLower(email))
}
func (c *conn) DeleteOfflineSessions(userID string, connID string) error {
result, err := c.Exec(`delete from offline_session where user_id = $1 AND conn_id = $2`, userID, connID)
if err != nil {
return fmt.Errorf("delete offline_session: user_id = %s, conn_id = %s", userID, connID)
}
// For now mandate that the driver implements RowsAffected. If we ever need to support
// a driver that doesn't implement this, we can run this in a transaction with a get beforehand.
n, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("rows affected: %v", err)
}
if n < 1 {
return storage.ErrNotFound
}
return nil
}
// Do NOT call directly. Does not escape table.
func (c *conn) delete(table, field, id string) error {
result, err := c.Exec(`delete from `+table+` where `+field+` = $1`, id)

View file

@ -153,6 +153,7 @@ var migrations = []migration{
signing_key_pub bytea not null, -- JSON object
next_rotation timestamptz not null
);
`,
},
{
@ -165,4 +166,14 @@ var migrations = []migration{
add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC';
`,
},
{
stmt: `
create table offline_session (
user_id text not null,
conn_id text not null,
refresh bytea not null,
PRIMARY KEY (user_id, conn_id)
);
`,
},
}

View file

@ -52,6 +52,7 @@ type Storage interface {
CreateAuthCode(c AuthCode) error
CreateRefresh(r RefreshToken) error
CreatePassword(p Password) error
CreateOfflineSessions(s OfflineSessions) error
// TODO(ericchiang): return (T, bool, error) so we can indicate not found
// requests that way instead of using ErrNotFound.
@ -61,6 +62,7 @@ type Storage interface {
GetKeys() (Keys, error)
GetRefresh(id string) (RefreshToken, error)
GetPassword(email string) (Password, error)
GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
ListClients() ([]Client, error)
ListRefreshTokens() ([]RefreshToken, error)
@ -72,6 +74,7 @@ type Storage interface {
DeleteClient(id string) error
DeleteRefresh(id string) error
DeletePassword(email string) error
DeleteOfflineSessions(userID string, connID string) error
// Update methods take a function for updating an object then performs that update within
// a transaction. "updater" functions may be called multiple times by a single update call.
@ -92,6 +95,7 @@ type Storage interface {
UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error
UpdateRefreshToken(id string, updater func(r RefreshToken) (RefreshToken, error)) error
UpdatePassword(email string, updater func(p Password) (Password, error)) error
UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error
// GarbageCollect deletes all expired AuthCodes and AuthRequests.
GarbageCollect(now time.Time) (GCResult, error)
@ -241,6 +245,30 @@ type RefreshToken struct {
Nonce string
}
// RefreshTokenRef is a reference object that contains metadata about refresh tokens.
type RefreshTokenRef struct {
ID string
// Client the refresh token is valid for.
ClientID string
CreatedAt time.Time
LastUsed time.Time
}
// OfflineSessions objects are sessions pertaining to users with refresh tokens.
type OfflineSessions struct {
// UserID of an end user who has logged in to the server.
UserID string
// The ID of the connector used to login the user.
ConnID string
// Refresh is a hash table of refresh token reference objects
// indexed by the ClientID of the refresh token.
Refresh map[string]*RefreshTokenRef
}
// Password is an email to password mapping managed by the storage.
type Password struct {
// Email and identifying name of the password. Emails are assumed to be valid and