Merge pull request #1180 from JoelSpeed/refresh-tokens
Implement refreshing with Google
This commit is contained in:
commit
b1e98d8590
8 changed files with 172 additions and 58 deletions
|
@ -3,12 +3,14 @@ package oidc
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
@ -60,6 +62,11 @@ var brokenAuthHeaderDomains = []string{
|
|||
"oktapreview.com",
|
||||
}
|
||||
|
||||
// connectorData stores information for sessions authenticated by this connector
|
||||
type connectorData struct {
|
||||
RefreshToken []byte
|
||||
}
|
||||
|
||||
// Detect auth header provider issues for known providers. This lets users
|
||||
// avoid having to explicitly set "basicAuthUnsupported" in their config.
|
||||
//
|
||||
|
@ -167,14 +174,19 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string)
|
|||
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
|
||||
}
|
||||
|
||||
var opts []oauth2.AuthCodeOption
|
||||
if len(c.hostedDomains) > 0 {
|
||||
preferredDomain := c.hostedDomains[0]
|
||||
if len(c.hostedDomains) > 1 {
|
||||
preferredDomain = "*"
|
||||
}
|
||||
return c.oauth2Config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", preferredDomain)), nil
|
||||
opts = append(opts, oauth2.SetAuthURLParam("hd", preferredDomain))
|
||||
}
|
||||
return c.oauth2Config.AuthCodeURL(state), nil
|
||||
|
||||
if s.OfflineAccess {
|
||||
opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
||||
}
|
||||
return c.oauth2Config.AuthCodeURL(state, opts...), nil
|
||||
}
|
||||
|
||||
type oauth2Error struct {
|
||||
|
@ -199,11 +211,35 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
|
|||
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
|
||||
}
|
||||
|
||||
return c.createIdentity(r.Context(), identity, token)
|
||||
}
|
||||
|
||||
// Refresh is used to refresh a session with the refresh token provided by the IdP
|
||||
func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) {
|
||||
cd := connectorData{}
|
||||
err := json.Unmarshal(identity.ConnectorData, &cd)
|
||||
if err != nil {
|
||||
return identity, fmt.Errorf("oidc: failed to unmarshal connector data: %v", err)
|
||||
}
|
||||
|
||||
t := &oauth2.Token{
|
||||
RefreshToken: string(cd.RefreshToken),
|
||||
Expiry: time.Now().Add(-time.Hour),
|
||||
}
|
||||
token, err := c.oauth2Config.TokenSource(ctx, t).Token()
|
||||
if err != nil {
|
||||
return identity, fmt.Errorf("oidc: failed to get refresh token: %v", err)
|
||||
}
|
||||
|
||||
return c.createIdentity(ctx, identity, token)
|
||||
}
|
||||
|
||||
func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token) (connector.Identity, error) {
|
||||
rawIDToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return identity, errors.New("oidc: no id_token in token response")
|
||||
}
|
||||
idToken, err := c.verifier.Verify(r.Context(), rawIDToken)
|
||||
idToken, err := c.verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err)
|
||||
}
|
||||
|
@ -215,7 +251,7 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
|
|||
|
||||
// We immediately want to run getUserInfo if configured before we validate the claims
|
||||
if c.getUserInfo {
|
||||
userInfo, err := c.provider.UserInfo(r.Context(), oauth2.StaticTokenSource(token))
|
||||
userInfo, err := c.provider.UserInfo(ctx, oauth2.StaticTokenSource(token))
|
||||
if err != nil {
|
||||
return identity, fmt.Errorf("oidc: error loading userinfo: %v", err)
|
||||
}
|
||||
|
@ -260,11 +296,21 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
|
|||
}
|
||||
}
|
||||
|
||||
cd := connectorData{
|
||||
RefreshToken: []byte(token.RefreshToken),
|
||||
}
|
||||
|
||||
connData, err := json.Marshal(&cd)
|
||||
if err != nil {
|
||||
return identity, fmt.Errorf("oidc: failed to encode connector data: %v", err)
|
||||
}
|
||||
|
||||
identity = connector.Identity{
|
||||
UserID: idToken.Subject,
|
||||
Username: name,
|
||||
Email: email,
|
||||
EmailVerified: emailVerified,
|
||||
ConnectorData: connData,
|
||||
}
|
||||
|
||||
if c.userIDKey != "" {
|
||||
|
@ -277,8 +323,3 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
|
|||
|
||||
return identity, nil
|
||||
}
|
||||
|
||||
// Refresh is implemented for backwards compatibility, even though it's a no-op.
|
||||
func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) {
|
||||
return identity, nil
|
||||
}
|
||||
|
|
|
@ -505,7 +505,45 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.Auth
|
|||
s.logger.Infof("login successful: connector %q, username=%q, preferred_username=%q, email=%q, groups=%q",
|
||||
authReq.ConnectorID, claims.Username, claims.PreferredUsername, email, claims.Groups)
|
||||
|
||||
return path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID, nil
|
||||
returnURL := path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID
|
||||
_, ok := conn.(connector.RefreshConnector)
|
||||
if !ok {
|
||||
return returnURL, nil
|
||||
}
|
||||
|
||||
// Try to retrieve an existing OfflineSession object for the corresponding user.
|
||||
if session, err := s.storage.GetOfflineSessions(identity.UserID, authReq.ConnectorID); err != nil {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get offline session: %v", err)
|
||||
return "", err
|
||||
}
|
||||
offlineSessions := storage.OfflineSessions{
|
||||
UserID: identity.UserID,
|
||||
ConnID: authReq.ConnectorID,
|
||||
Refresh: make(map[string]*storage.RefreshTokenRef),
|
||||
ConnectorData: identity.ConnectorData,
|
||||
}
|
||||
|
||||
// Create a new OfflineSession object for the user and add a reference object for
|
||||
// the newly received refreshtoken.
|
||||
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
|
||||
s.logger.Errorf("failed to create offline session: %v", err)
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
// Update existing OfflineSession obj with new RefreshTokenRef.
|
||||
if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
|
||||
if len(identity.ConnectorData) > 0 {
|
||||
old.ConnectorData = identity.ConnectorData
|
||||
}
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
s.logger.Errorf("failed to update offline session: %v", err)
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return returnURL, nil
|
||||
}
|
||||
|
||||
func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -962,6 +1000,19 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||
scopes = requestedScopes
|
||||
}
|
||||
|
||||
var connectorData []byte
|
||||
if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get offline session: %v", err)
|
||||
return
|
||||
}
|
||||
} else if len(refresh.ConnectorData) > 0 {
|
||||
// Use the old connector data if it exists, should be deleted once used
|
||||
connectorData = refresh.ConnectorData
|
||||
} else {
|
||||
connectorData = session.ConnectorData
|
||||
}
|
||||
|
||||
conn, err := s.getConnector(refresh.ConnectorID)
|
||||
if err != nil {
|
||||
s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err)
|
||||
|
@ -975,7 +1026,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||
Email: refresh.Claims.Email,
|
||||
EmailVerified: refresh.Claims.EmailVerified,
|
||||
Groups: refresh.Claims.Groups,
|
||||
ConnectorData: refresh.ConnectorData,
|
||||
ConnectorData: connectorData,
|
||||
}
|
||||
|
||||
// Can the connector refresh the identity? If so, attempt to refresh the data
|
||||
|
@ -1041,8 +1092,10 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||
old.Claims.Email = ident.Email
|
||||
old.Claims.EmailVerified = ident.EmailVerified
|
||||
old.Claims.Groups = ident.Groups
|
||||
old.ConnectorData = ident.ConnectorData
|
||||
old.LastUsed = lastUsed
|
||||
|
||||
// ConnectorData has been moved to OfflineSession
|
||||
old.ConnectorData = []byte{}
|
||||
return old, nil
|
||||
}
|
||||
|
||||
|
@ -1053,6 +1106,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||
return old, errors.New("refresh token invalid")
|
||||
}
|
||||
old.Refresh[refresh.ClientID].LastUsed = lastUsed
|
||||
old.ConnectorData = ident.ConnectorData
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
s.logger.Errorf("failed to update offline session: %v", err)
|
||||
|
|
|
@ -518,6 +518,7 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
|
|||
UserID: userID1,
|
||||
ConnID: "Conn1",
|
||||
Refresh: make(map[string]*storage.RefreshTokenRef),
|
||||
ConnectorData: []byte(`{"some":"data"}`),
|
||||
}
|
||||
|
||||
// Creating an OfflineSession with an empty Refresh list to ensure that
|
||||
|
@ -535,6 +536,7 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
|
|||
UserID: userID2,
|
||||
ConnID: "Conn2",
|
||||
Refresh: make(map[string]*storage.RefreshTokenRef),
|
||||
ConnectorData: []byte(`{"some":"data"}`),
|
||||
}
|
||||
|
||||
if err := s.CreateOfflineSessions(session2); err != nil {
|
||||
|
|
|
@ -191,6 +191,7 @@ type OfflineSessions struct {
|
|||
UserID string `json:"user_id,omitempty"`
|
||||
ConnID string `json:"conn_id,omitempty"`
|
||||
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
|
||||
ConnectorData []byte `json:"connectorData,omitempty"`
|
||||
}
|
||||
|
||||
func fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
|
||||
|
@ -198,6 +199,7 @@ func fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
|
|||
UserID: o.UserID,
|
||||
ConnID: o.ConnID,
|
||||
Refresh: o.Refresh,
|
||||
ConnectorData: o.ConnectorData,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -206,6 +208,7 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
|
|||
UserID: o.UserID,
|
||||
ConnID: o.ConnID,
|
||||
Refresh: o.Refresh,
|
||||
ConnectorData: o.ConnectorData,
|
||||
}
|
||||
if s.Refresh == nil {
|
||||
// Server code assumes this will be non-nil.
|
||||
|
|
|
@ -555,6 +555,7 @@ type OfflineSessions struct {
|
|||
UserID string `json:"userID,omitempty"`
|
||||
ConnID string `json:"connID,omitempty"`
|
||||
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
|
||||
ConnectorData []byte `json:"connectorData,omitempty"`
|
||||
}
|
||||
|
||||
func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
|
||||
|
@ -570,6 +571,7 @@ func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) Offline
|
|||
UserID: o.UserID,
|
||||
ConnID: o.ConnID,
|
||||
Refresh: o.Refresh,
|
||||
ConnectorData: o.ConnectorData,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -578,6 +580,7 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
|
|||
UserID: o.UserID,
|
||||
ConnID: o.ConnID,
|
||||
Refresh: o.Refresh,
|
||||
ConnectorData: o.ConnectorData,
|
||||
}
|
||||
if s.Refresh == nil {
|
||||
// Server code assumes this will be non-nil.
|
||||
|
|
|
@ -655,13 +655,13 @@ func scanPassword(s scanner) (p storage.Password, err error) {
|
|||
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
|
||||
_, err := c.Exec(`
|
||||
insert into offline_session (
|
||||
user_id, conn_id, refresh
|
||||
user_id, conn_id, refresh, connector_data
|
||||
)
|
||||
values (
|
||||
$1, $2, $3
|
||||
$1, $2, $3, $4
|
||||
);
|
||||
`,
|
||||
s.UserID, s.ConnID, encoder(s.Refresh),
|
||||
s.UserID, s.ConnID, encoder(s.Refresh), s.ConnectorData,
|
||||
)
|
||||
if err != nil {
|
||||
if c.alreadyExistsCheck(err) {
|
||||
|
@ -686,10 +686,11 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
|
|||
_, err = tx.Exec(`
|
||||
update offline_session
|
||||
set
|
||||
refresh = $1
|
||||
where user_id = $2 AND conn_id = $3;
|
||||
refresh = $1,
|
||||
connector_data = $2
|
||||
where user_id = $3 AND conn_id = $4;
|
||||
`,
|
||||
encoder(newSession.Refresh), s.UserID, s.ConnID,
|
||||
encoder(newSession.Refresh), newSession.ConnectorData, s.UserID, s.ConnID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update offline session: %v", err)
|
||||
|
@ -705,7 +706,7 @@ func (c *conn) GetOfflineSessions(userID string, connID string) (storage.Offline
|
|||
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
|
||||
return scanOfflineSessions(q.QueryRow(`
|
||||
select
|
||||
user_id, conn_id, refresh
|
||||
user_id, conn_id, refresh, connector_data
|
||||
from offline_session
|
||||
where user_id = $1 AND conn_id = $2;
|
||||
`, userID, connID))
|
||||
|
@ -713,7 +714,7 @@ func getOfflineSessions(q querier, userID string, connID string) (storage.Offlin
|
|||
|
||||
func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
|
||||
err = s.Scan(
|
||||
&o.UserID, &o.ConnID, decoder(&o.Refresh),
|
||||
&o.UserID, &o.ConnID, decoder(&o.Refresh), &o.ConnectorData,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
|
|
|
@ -202,4 +202,11 @@ var migrations = []migration{
|
|||
add column claims_preferred_username text not null default '';`,
|
||||
},
|
||||
},
|
||||
{
|
||||
stmts: []string{`
|
||||
alter table offline_session
|
||||
add column connector_data bytea;
|
||||
`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
@ -273,6 +273,9 @@ type OfflineSessions struct {
|
|||
// Refresh is a hash table of refresh token reference objects
|
||||
// indexed by the ClientID of the refresh token.
|
||||
Refresh map[string]*RefreshTokenRef
|
||||
|
||||
// Authentication data provided by an upstream source.
|
||||
ConnectorData []byte
|
||||
}
|
||||
|
||||
// Password is an email to password mapping managed by the storage.
|
||||
|
|
Reference in a new issue