storage: Add OfflineSession object to backend storage.
This commit is contained in:
parent
49f446c1a7
commit
d928ac0677
10 changed files with 580 additions and 32 deletions
|
@ -682,6 +682,75 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
|||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// deleteToken determines if we need to delete the newly created refresh token
|
||||
// due to a failure in updating/creating the OfflineSession object for the
|
||||
// corresponding user.
|
||||
var deleteToken bool
|
||||
defer func() {
|
||||
if deleteToken {
|
||||
// Delete newly created refresh token from storage.
|
||||
if err := s.storage.DeleteRefresh(refresh.ID); err != nil {
|
||||
s.logger.Errorf("failed to delete refresh token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
tokenRef := storage.RefreshTokenRef{
|
||||
ID: refresh.ID,
|
||||
ClientID: refresh.ClientID,
|
||||
CreatedAt: refresh.CreatedAt,
|
||||
LastUsed: refresh.LastUsed,
|
||||
}
|
||||
|
||||
// Try to retrieve an existing OfflineSession object for the corresponding user.
|
||||
if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get offline session: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
deleteToken = true
|
||||
return
|
||||
}
|
||||
offlineSessions := storage.OfflineSessions{
|
||||
UserID: refresh.Claims.UserID,
|
||||
ConnID: refresh.ConnectorID,
|
||||
Refresh: make(map[string]*storage.RefreshTokenRef),
|
||||
}
|
||||
offlineSessions.Refresh[tokenRef.ClientID] = &tokenRef
|
||||
|
||||
// Create a new OfflineSession object for the user and add a reference object for
|
||||
// the newly recieved refreshtoken.
|
||||
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
|
||||
s.logger.Errorf("failed to create offline session: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
deleteToken = true
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
|
||||
// Delete old refresh token from storage.
|
||||
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil {
|
||||
s.logger.Errorf("failed to delete refresh token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
deleteToken = true
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Update existing OfflineSession obj with new RefreshTokenRef.
|
||||
if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
|
||||
old.Refresh[tokenRef.ClientID] = &tokenRef
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
s.logger.Errorf("failed to update offline session: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
deleteToken = true
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry)
|
||||
}
|
||||
|
@ -815,6 +884,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||
return
|
||||
}
|
||||
|
||||
lastUsed := s.now()
|
||||
updater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
|
||||
if old.Token != refresh.Token {
|
||||
return old, errors.New("refresh token claimed twice")
|
||||
|
@ -828,14 +898,31 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||
old.Claims.EmailVerified = ident.EmailVerified
|
||||
old.Claims.Groups = ident.Groups
|
||||
old.ConnectorData = ident.ConnectorData
|
||||
old.LastUsed = s.now()
|
||||
old.LastUsed = lastUsed
|
||||
return old, nil
|
||||
}
|
||||
|
||||
// Update LastUsed time stamp in refresh token reference object
|
||||
// in offline session for the user.
|
||||
if err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
|
||||
if old.Refresh[refresh.ClientID].ID != refresh.ID {
|
||||
return old, errors.New("refresh token invalid")
|
||||
}
|
||||
old.Refresh[refresh.ClientID].LastUsed = lastUsed
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
s.logger.Errorf("failed to update offline session: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Update refresh token in the storage.
|
||||
if err := s.storage.UpdateRefreshToken(refresh.ID, updater); err != nil {
|
||||
s.logger.Errorf("failed to update refresh token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry)
|
||||
}
|
||||
|
||||
|
|
|
@ -971,3 +971,108 @@ func TestKeyCacher(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
type oauth2Client struct {
|
||||
config *oauth2.Config
|
||||
token *oauth2.Token
|
||||
server *httptest.Server
|
||||
}
|
||||
|
||||
// TestRefreshTokenFlow tests the refresh token code flow for oauth2. The test verifies
|
||||
// that only valid refresh tokens can be used to refresh an expired token.
|
||||
func TestRefreshTokenFlow(t *testing.T) {
|
||||
state := "state"
|
||||
now := func() time.Time { return time.Now() }
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
||||
c.Now = now
|
||||
})
|
||||
defer httpServer.Close()
|
||||
|
||||
p, err := oidc.NewProvider(ctx, httpServer.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get provider: %v", err)
|
||||
}
|
||||
|
||||
var oauth2Client oauth2Client
|
||||
|
||||
oauth2Client.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/callback" {
|
||||
// User is visiting app first time. Redirect to dex.
|
||||
http.Redirect(w, r, oauth2Client.config.AuthCodeURL(state), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
// User is at '/callback' so they were just redirected _from_ dex.
|
||||
q := r.URL.Query()
|
||||
|
||||
if errType := q.Get("error"); errType != "" {
|
||||
if desc := q.Get("error_description"); desc != "" {
|
||||
t.Errorf("got error from server %s: %s", errType, desc)
|
||||
} else {
|
||||
t.Errorf("got error from server %s", errType)
|
||||
}
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Grab code, exchange for token.
|
||||
if code := q.Get("code"); code != "" {
|
||||
token, err := oauth2Client.config.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
t.Errorf("failed to exchange code for token: %v", err)
|
||||
return
|
||||
}
|
||||
oauth2Client.token = token
|
||||
}
|
||||
|
||||
// Ensure state matches.
|
||||
if gotState := q.Get("state"); gotState != state {
|
||||
t.Errorf("state did not match, want=%q got=%q", state, gotState)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}))
|
||||
defer oauth2Client.server.Close()
|
||||
|
||||
// Register the client above with dex.
|
||||
redirectURL := oauth2Client.server.URL + "/callback"
|
||||
client := storage.Client{
|
||||
ID: "testclient",
|
||||
Secret: "testclientsecret",
|
||||
RedirectURIs: []string{redirectURL},
|
||||
}
|
||||
if err := s.storage.CreateClient(client); err != nil {
|
||||
t.Fatalf("failed to create client: %v", err)
|
||||
}
|
||||
|
||||
oauth2Client.config = &oauth2.Config{
|
||||
ClientID: client.ID,
|
||||
ClientSecret: client.Secret,
|
||||
Endpoint: p.Endpoint(),
|
||||
Scopes: []string{oidc.ScopeOpenID, "email", "offline_access"},
|
||||
RedirectURL: redirectURL,
|
||||
}
|
||||
|
||||
if _, err = http.Get(oauth2Client.server.URL + "/login"); err != nil {
|
||||
t.Fatalf("get failed: %v", err)
|
||||
}
|
||||
|
||||
tok := &oauth2.Token{
|
||||
RefreshToken: oauth2Client.token.RefreshToken,
|
||||
Expiry: time.Now().Add(-time.Hour),
|
||||
}
|
||||
|
||||
// Login in again to recieve a new token.
|
||||
if _, err = http.Get(oauth2Client.server.URL + "/login"); err != nil {
|
||||
t.Fatalf("get failed: %v", err)
|
||||
}
|
||||
|
||||
// try to refresh expired token with old refresh token.
|
||||
newToken, err := oauth2Client.config.TokenSource(ctx, tok).Token()
|
||||
if newToken != nil {
|
||||
t.Errorf("Token refreshed with invalid refresh token.")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -47,6 +47,7 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) {
|
|||
{"RefreshTokenCRUD", testRefreshTokenCRUD},
|
||||
{"PasswordCRUD", testPasswordCRUD},
|
||||
{"KeysCRUD", testKeysCRUD},
|
||||
{"OfflineSessionCRUD", testOfflineSessionCRUD},
|
||||
{"GarbageCollection", testGC},
|
||||
{"TimezoneSupport", testTimezones},
|
||||
})
|
||||
|
@ -340,6 +341,60 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
|
|||
|
||||
}
|
||||
|
||||
func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
|
||||
session := storage.OfflineSessions{
|
||||
UserID: "User",
|
||||
ConnID: "Conn",
|
||||
Refresh: make(map[string]*storage.RefreshTokenRef),
|
||||
}
|
||||
|
||||
// Creating an OfflineSession with an empty Refresh list to ensure that
|
||||
// an empty map is translated as expected by the storage.
|
||||
if err := s.CreateOfflineSessions(session); err != nil {
|
||||
t.Fatalf("create offline session: %v", err)
|
||||
}
|
||||
|
||||
getAndCompare := func(userID string, connID string, want storage.OfflineSessions) {
|
||||
gr, err := s.GetOfflineSessions(userID, connID)
|
||||
if err != nil {
|
||||
t.Errorf("get offline session: %v", err)
|
||||
return
|
||||
}
|
||||
if diff := pretty.Compare(want, gr); diff != "" {
|
||||
t.Errorf("offline session retrieved from storage did not match: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
getAndCompare("User", "Conn", session)
|
||||
|
||||
id := storage.NewID()
|
||||
tokenRef := storage.RefreshTokenRef{
|
||||
ID: id,
|
||||
ClientID: "client_id",
|
||||
CreatedAt: time.Now().UTC().Round(time.Millisecond),
|
||||
LastUsed: time.Now().UTC().Round(time.Millisecond),
|
||||
}
|
||||
session.Refresh[tokenRef.ClientID] = &tokenRef
|
||||
|
||||
if err := s.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
|
||||
old.Refresh[tokenRef.ClientID] = &tokenRef
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
t.Fatalf("failed to update offline session: %v", err)
|
||||
}
|
||||
|
||||
getAndCompare("User", "Conn", session)
|
||||
|
||||
if err := s.DeleteOfflineSessions(session.UserID, session.ConnID); err != nil {
|
||||
t.Fatalf("failed to delete offline session: %v", err)
|
||||
}
|
||||
|
||||
if _, err := s.GetOfflineSessions(session.UserID, session.ConnID); err != storage.ErrNotFound {
|
||||
t.Errorf("after deleting offline session expected storage.ErrNotFound, got %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func testKeysCRUD(t *testing.T, s storage.Storage) {
|
||||
updateAndCompare := func(k storage.Keys) {
|
||||
err := s.UpdateKeys(func(oldKeys storage.Keys) (storage.Keys, error) {
|
||||
|
|
|
@ -58,6 +58,12 @@ func (c *client) idToName(s string) string {
|
|||
return idToName(s, c.hash)
|
||||
}
|
||||
|
||||
// offlineTokenName maps two arbitrary IDs, to a single Kubernetes object name.
|
||||
// This is used when more than one field is used to uniquely identify the object.
|
||||
func (c *client) offlineTokenName(userID string, connID string) string {
|
||||
return offlineTokenName(userID, connID, c.hash)
|
||||
}
|
||||
|
||||
// Kubernetes names must match the regexp '[a-z0-9]([-a-z0-9]*[a-z0-9])?'.
|
||||
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")
|
||||
|
||||
|
@ -65,6 +71,12 @@ func idToName(s string, h func() hash.Hash) string {
|
|||
return strings.TrimRight(encoding.EncodeToString(h().Sum([]byte(s))), "=")
|
||||
}
|
||||
|
||||
func offlineTokenName(userID string, connID string, h func() hash.Hash) string {
|
||||
h().Write([]byte(userID))
|
||||
h().Write([]byte(connID))
|
||||
return strings.TrimRight(encoding.EncodeToString(h().Sum(nil)), "=")
|
||||
}
|
||||
|
||||
func (c *client) urlFor(apiVersion, namespace, resource, name string) string {
|
||||
basePath := "apis/"
|
||||
if apiVersion == "v1" {
|
||||
|
|
|
@ -21,6 +21,7 @@ const (
|
|||
kindRefreshToken = "RefreshToken"
|
||||
kindKeys = "SigningKey"
|
||||
kindPassword = "Password"
|
||||
kindOfflineSessions = "OfflineSessions"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -30,6 +31,7 @@ const (
|
|||
resourceRefreshToken = "refreshtokens"
|
||||
resourceKeys = "signingkeies" // Kubernetes attempts to pluralize.
|
||||
resourcePassword = "passwords"
|
||||
resourceOfflineSessions = "offlinesessions"
|
||||
)
|
||||
|
||||
// Config values for the Kubernetes storage type.
|
||||
|
@ -156,6 +158,10 @@ func (cli *client) CreateRefresh(r storage.RefreshToken) error {
|
|||
return cli.post(resourceRefreshToken, cli.fromStorageRefreshToken(r))
|
||||
}
|
||||
|
||||
func (cli *client) CreateOfflineSessions(o storage.OfflineSessions) error {
|
||||
return cli.post(resourceOfflineSessions, cli.fromStorageOfflineSessions(o))
|
||||
}
|
||||
|
||||
func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) {
|
||||
var req AuthRequest
|
||||
if err := cli.get(resourceAuthRequest, id, &req); err != nil {
|
||||
|
@ -235,6 +241,25 @@ func (cli *client) getRefreshToken(id string) (r RefreshToken, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (cli *client) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
|
||||
o, err := cli.getOfflineSessions(userID, connID)
|
||||
if err != nil {
|
||||
return storage.OfflineSessions{}, err
|
||||
}
|
||||
return toStorageOfflineSessions(o), nil
|
||||
}
|
||||
|
||||
func (cli *client) getOfflineSessions(userID string, connID string) (o OfflineSessions, err error) {
|
||||
name := cli.offlineTokenName(userID, connID)
|
||||
if err = cli.get(resourceOfflineSessions, name, &o); err != nil {
|
||||
return OfflineSessions{}, err
|
||||
}
|
||||
if userID != o.UserID || connID != o.ConnID {
|
||||
return OfflineSessions{}, fmt.Errorf("get offline session: wrong session retrieved")
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
func (cli *client) ListClients() ([]storage.Client, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
@ -292,6 +317,15 @@ func (cli *client) DeletePassword(email string) error {
|
|||
return cli.delete(resourcePassword, p.ObjectMeta.Name)
|
||||
}
|
||||
|
||||
func (cli *client) DeleteOfflineSessions(userID string, connID string) error {
|
||||
// Check for hash collition.
|
||||
o, err := cli.getOfflineSessions(userID, connID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return cli.delete(resourceOfflineSessions, o.ObjectMeta.Name)
|
||||
}
|
||||
|
||||
func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
|
||||
r, err := cli.getRefreshToken(id)
|
||||
if err != nil {
|
||||
|
@ -342,6 +376,22 @@ func (cli *client) UpdatePassword(email string, updater func(old storage.Passwor
|
|||
return cli.put(resourcePassword, p.ObjectMeta.Name, newPassword)
|
||||
}
|
||||
|
||||
func (cli *client) UpdateOfflineSessions(userID string, connID string, updater func(old storage.OfflineSessions) (storage.OfflineSessions, error)) error {
|
||||
o, err := cli.getOfflineSessions(userID, connID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updated, err := updater(toStorageOfflineSessions(o))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newOfflineSessions := cli.fromStorageOfflineSessions(updated)
|
||||
newOfflineSessions.ObjectMeta = o.ObjectMeta
|
||||
return cli.put(resourceOfflineSessions, o.ObjectMeta.Name, newOfflineSessions)
|
||||
}
|
||||
|
||||
func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
|
||||
firstUpdate := false
|
||||
var keys Keys
|
||||
|
|
|
@ -66,6 +66,14 @@ var thirdPartyResources = []k8sapi.ThirdPartyResource{
|
|||
Description: "Passwords managed by the OIDC server.",
|
||||
Versions: []k8sapi.APIVersion{{Name: "v1"}},
|
||||
},
|
||||
{
|
||||
ObjectMeta: k8sapi.ObjectMeta{
|
||||
Name: "offline-sessions.oidc.coreos.com",
|
||||
},
|
||||
TypeMeta: tprMeta,
|
||||
Description: "User sessions with an active refresh token.",
|
||||
Versions: []k8sapi.APIVersion{{Name: "v1"}},
|
||||
},
|
||||
}
|
||||
|
||||
// There will only ever be a single keys resource. Maintain this by setting a
|
||||
|
@ -465,3 +473,38 @@ func toStorageKeys(keys Keys) storage.Keys {
|
|||
NextRotation: keys.NextRotation,
|
||||
}
|
||||
}
|
||||
|
||||
// OfflineSessions is a mirrored struct from storage with JSON struct tags and Kubernetes
|
||||
// type metadata.
|
||||
type OfflineSessions struct {
|
||||
k8sapi.TypeMeta `json:",inline"`
|
||||
k8sapi.ObjectMeta `json:"metadata,omitempty"`
|
||||
|
||||
UserID string `json:"userID,omitempty"`
|
||||
ConnID string `json:"connID,omitempty"`
|
||||
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
|
||||
}
|
||||
|
||||
func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
|
||||
return OfflineSessions{
|
||||
TypeMeta: k8sapi.TypeMeta{
|
||||
Kind: kindOfflineSessions,
|
||||
APIVersion: cli.apiVersion,
|
||||
},
|
||||
ObjectMeta: k8sapi.ObjectMeta{
|
||||
Name: cli.offlineTokenName(o.UserID, o.ConnID),
|
||||
Namespace: cli.namespace,
|
||||
},
|
||||
UserID: o.UserID,
|
||||
ConnID: o.ConnID,
|
||||
Refresh: o.Refresh,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
|
||||
return storage.OfflineSessions{
|
||||
UserID: o.UserID,
|
||||
ConnID: o.ConnID,
|
||||
Refresh: o.Refresh,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ func New(logger logrus.FieldLogger) storage.Storage {
|
|||
refreshTokens: make(map[string]storage.RefreshToken),
|
||||
authReqs: make(map[string]storage.AuthRequest),
|
||||
passwords: make(map[string]storage.Password),
|
||||
offlineSessions: make(map[offlineSessionID]storage.OfflineSessions),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
@ -42,12 +43,18 @@ type memStorage struct {
|
|||
refreshTokens map[string]storage.RefreshToken
|
||||
authReqs map[string]storage.AuthRequest
|
||||
passwords map[string]storage.Password
|
||||
offlineSessions map[offlineSessionID]storage.OfflineSessions
|
||||
|
||||
keys storage.Keys
|
||||
|
||||
logger logrus.FieldLogger
|
||||
}
|
||||
|
||||
type offlineSessionID struct {
|
||||
userID string
|
||||
connID string
|
||||
}
|
||||
|
||||
func (s *memStorage) tx(f func()) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
@ -130,6 +137,32 @@ func (s *memStorage) CreatePassword(p storage.Password) (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) CreateOfflineSessions(o storage.OfflineSessions) (err error) {
|
||||
id := offlineSessionID{
|
||||
userID: o.UserID,
|
||||
connID: o.ConnID,
|
||||
}
|
||||
s.tx(func() {
|
||||
if _, ok := s.offlineSessions[id]; ok {
|
||||
err = storage.ErrAlreadyExists
|
||||
} else {
|
||||
s.offlineSessions[id] = o
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) {
|
||||
s.tx(func() {
|
||||
var ok bool
|
||||
if c, ok = s.authCodes[id]; !ok {
|
||||
err = storage.ErrNotFound
|
||||
return
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) GetPassword(email string) (p storage.Password, err error) {
|
||||
email = strings.ToLower(email)
|
||||
s.tx(func() {
|
||||
|
@ -156,10 +189,10 @@ func (s *memStorage) GetKeys() (keys storage.Keys, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) GetRefresh(token string) (tok storage.RefreshToken, err error) {
|
||||
func (s *memStorage) GetRefresh(id string) (tok storage.RefreshToken, err error) {
|
||||
s.tx(func() {
|
||||
var ok bool
|
||||
if tok, ok = s.refreshTokens[token]; !ok {
|
||||
if tok, ok = s.refreshTokens[id]; !ok {
|
||||
err = storage.ErrNotFound
|
||||
return
|
||||
}
|
||||
|
@ -178,6 +211,21 @@ func (s *memStorage) GetAuthRequest(id string) (req storage.AuthRequest, err err
|
|||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) GetOfflineSessions(userID string, connID string) (o storage.OfflineSessions, err error) {
|
||||
id := offlineSessionID{
|
||||
userID: userID,
|
||||
connID: connID,
|
||||
}
|
||||
s.tx(func() {
|
||||
var ok bool
|
||||
if o, ok = s.offlineSessions[id]; !ok {
|
||||
err = storage.ErrNotFound
|
||||
return
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) ListClients() (clients []storage.Client, err error) {
|
||||
s.tx(func() {
|
||||
for _, client := range s.clients {
|
||||
|
@ -228,13 +276,13 @@ func (s *memStorage) DeleteClient(id string) (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) DeleteRefresh(token string) (err error) {
|
||||
func (s *memStorage) DeleteRefresh(id string) (err error) {
|
||||
s.tx(func() {
|
||||
if _, ok := s.refreshTokens[token]; !ok {
|
||||
if _, ok := s.refreshTokens[id]; !ok {
|
||||
err = storage.ErrNotFound
|
||||
return
|
||||
}
|
||||
delete(s.refreshTokens, token)
|
||||
delete(s.refreshTokens, id)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
@ -261,13 +309,17 @@ func (s *memStorage) DeleteAuthRequest(id string) (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) {
|
||||
func (s *memStorage) DeleteOfflineSessions(userID string, connID string) (err error) {
|
||||
id := offlineSessionID{
|
||||
userID: userID,
|
||||
connID: connID,
|
||||
}
|
||||
s.tx(func() {
|
||||
var ok bool
|
||||
if c, ok = s.authCodes[id]; !ok {
|
||||
if _, ok := s.offlineSessions[id]; !ok {
|
||||
err = storage.ErrNotFound
|
||||
return
|
||||
}
|
||||
delete(s.offlineSessions, id)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
@ -338,3 +390,21 @@ func (s *memStorage) UpdateRefreshToken(id string, updater func(p storage.Refres
|
|||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) UpdateOfflineSessions(userID string, connID string, updater func(o storage.OfflineSessions) (storage.OfflineSessions, error)) (err error) {
|
||||
id := offlineSessionID{
|
||||
userID: userID,
|
||||
connID: connID,
|
||||
}
|
||||
s.tx(func() {
|
||||
r, ok := s.offlineSessions[id]
|
||||
if !ok {
|
||||
err = storage.ErrNotFound
|
||||
return
|
||||
}
|
||||
if r, err = updater(r); err == nil {
|
||||
s.offlineSessions[id] = r
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
|
|
@ -624,6 +624,75 @@ func scanPassword(s scanner) (p storage.Password, err error) {
|
|||
return p, nil
|
||||
}
|
||||
|
||||
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
|
||||
_, err := c.Exec(`
|
||||
insert into offline_session (
|
||||
user_id, conn_id, refresh
|
||||
)
|
||||
values (
|
||||
$1, $2, $3
|
||||
);
|
||||
`,
|
||||
s.UserID, s.ConnID, encoder(s.Refresh),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert offline session: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
s, err := getOfflineSessions(tx, userID, connID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newSession, err := updater(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`
|
||||
update offline_session
|
||||
set
|
||||
refresh = $1
|
||||
where user_id = $2 AND conn_id = $3;
|
||||
`,
|
||||
encoder(newSession.Refresh), s.UserID, s.ConnID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update offline session: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
|
||||
return getOfflineSessions(c, userID, connID)
|
||||
}
|
||||
|
||||
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
|
||||
return scanOfflineSessions(q.QueryRow(`
|
||||
select
|
||||
user_id, conn_id, refresh
|
||||
from offline_session
|
||||
where user_id = $1 AND conn_id = $2;
|
||||
`, userID, connID))
|
||||
}
|
||||
|
||||
func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
|
||||
err = s.Scan(
|
||||
&o.UserID, &o.ConnID, decoder(&o.Refresh),
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return o, storage.ErrNotFound
|
||||
}
|
||||
return o, fmt.Errorf("select offline session: %v", err)
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", "id", id) }
|
||||
func (c *conn) DeleteAuthCode(id string) error { return c.delete("auth_code", "id", id) }
|
||||
func (c *conn) DeleteClient(id string) error { return c.delete("client", "id", id) }
|
||||
|
@ -632,6 +701,24 @@ func (c *conn) DeletePassword(email string) error {
|
|||
return c.delete("password", "email", strings.ToLower(email))
|
||||
}
|
||||
|
||||
func (c *conn) DeleteOfflineSessions(userID string, connID string) error {
|
||||
result, err := c.Exec(`delete from offline_session where user_id = $1 AND conn_id = $2`, userID, connID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete offline_session: user_id = %s, conn_id = %s", userID, connID)
|
||||
}
|
||||
|
||||
// For now mandate that the driver implements RowsAffected. If we ever need to support
|
||||
// a driver that doesn't implement this, we can run this in a transaction with a get beforehand.
|
||||
n, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("rows affected: %v", err)
|
||||
}
|
||||
if n < 1 {
|
||||
return storage.ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do NOT call directly. Does not escape table.
|
||||
func (c *conn) delete(table, field, id string) error {
|
||||
result, err := c.Exec(`delete from `+table+` where `+field+` = $1`, id)
|
||||
|
|
|
@ -153,6 +153,7 @@ var migrations = []migration{
|
|||
signing_key_pub bytea not null, -- JSON object
|
||||
next_rotation timestamptz not null
|
||||
);
|
||||
|
||||
`,
|
||||
},
|
||||
{
|
||||
|
@ -165,4 +166,14 @@ var migrations = []migration{
|
|||
add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC';
|
||||
`,
|
||||
},
|
||||
{
|
||||
stmt: `
|
||||
create table offline_session (
|
||||
user_id text not null,
|
||||
conn_id text not null,
|
||||
refresh bytea not null,
|
||||
PRIMARY KEY (user_id, conn_id)
|
||||
);
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -52,6 +52,7 @@ type Storage interface {
|
|||
CreateAuthCode(c AuthCode) error
|
||||
CreateRefresh(r RefreshToken) error
|
||||
CreatePassword(p Password) error
|
||||
CreateOfflineSessions(s OfflineSessions) error
|
||||
|
||||
// TODO(ericchiang): return (T, bool, error) so we can indicate not found
|
||||
// requests that way instead of using ErrNotFound.
|
||||
|
@ -61,6 +62,7 @@ type Storage interface {
|
|||
GetKeys() (Keys, error)
|
||||
GetRefresh(id string) (RefreshToken, error)
|
||||
GetPassword(email string) (Password, error)
|
||||
GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
|
||||
|
||||
ListClients() ([]Client, error)
|
||||
ListRefreshTokens() ([]RefreshToken, error)
|
||||
|
@ -72,6 +74,7 @@ type Storage interface {
|
|||
DeleteClient(id string) error
|
||||
DeleteRefresh(id string) error
|
||||
DeletePassword(email string) error
|
||||
DeleteOfflineSessions(userID string, connID string) error
|
||||
|
||||
// Update methods take a function for updating an object then performs that update within
|
||||
// a transaction. "updater" functions may be called multiple times by a single update call.
|
||||
|
@ -92,6 +95,7 @@ type Storage interface {
|
|||
UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error
|
||||
UpdateRefreshToken(id string, updater func(r RefreshToken) (RefreshToken, error)) error
|
||||
UpdatePassword(email string, updater func(p Password) (Password, error)) error
|
||||
UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error
|
||||
|
||||
// GarbageCollect deletes all expired AuthCodes and AuthRequests.
|
||||
GarbageCollect(now time.Time) (GCResult, error)
|
||||
|
@ -241,6 +245,30 @@ type RefreshToken struct {
|
|||
Nonce string
|
||||
}
|
||||
|
||||
// RefreshTokenRef is a reference object that contains metadata about refresh tokens.
|
||||
type RefreshTokenRef struct {
|
||||
ID string
|
||||
|
||||
// Client the refresh token is valid for.
|
||||
ClientID string
|
||||
|
||||
CreatedAt time.Time
|
||||
LastUsed time.Time
|
||||
}
|
||||
|
||||
// OfflineSessions objects are sessions pertaining to users with refresh tokens.
|
||||
type OfflineSessions struct {
|
||||
// UserID of an end user who has logged in to the server.
|
||||
UserID string
|
||||
|
||||
// The ID of the connector used to login the user.
|
||||
ConnID string
|
||||
|
||||
// Refresh is a hash table of refresh token reference objects
|
||||
// indexed by the ClientID of the refresh token.
|
||||
Refresh map[string]*RefreshTokenRef
|
||||
}
|
||||
|
||||
// Password is an email to password mapping managed by the storage.
|
||||
type Password struct {
|
||||
// Email and identifying name of the password. Emails are assumed to be valid and
|
||||
|
|
Reference in a new issue