Merge pull request #1180 from JoelSpeed/refresh-tokens

Implement refreshing with Google
This commit is contained in:
Nándor István Krácser 2019-11-19 17:39:23 +01:00 committed by GitHub
commit b1e98d8590
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 172 additions and 58 deletions

View file

@ -3,12 +3,14 @@ package oidc
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"sync" "sync"
"time"
"github.com/coreos/go-oidc" "github.com/coreos/go-oidc"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -60,6 +62,11 @@ var brokenAuthHeaderDomains = []string{
"oktapreview.com", "oktapreview.com",
} }
// connectorData stores information for sessions authenticated by this connector
type connectorData struct {
RefreshToken []byte
}
// Detect auth header provider issues for known providers. This lets users // Detect auth header provider issues for known providers. This lets users
// avoid having to explicitly set "basicAuthUnsupported" in their config. // avoid having to explicitly set "basicAuthUnsupported" in their config.
// //
@ -167,14 +174,19 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string)
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
} }
var opts []oauth2.AuthCodeOption
if len(c.hostedDomains) > 0 { if len(c.hostedDomains) > 0 {
preferredDomain := c.hostedDomains[0] preferredDomain := c.hostedDomains[0]
if len(c.hostedDomains) > 1 { if len(c.hostedDomains) > 1 {
preferredDomain = "*" preferredDomain = "*"
} }
return c.oauth2Config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", preferredDomain)), nil opts = append(opts, oauth2.SetAuthURLParam("hd", preferredDomain))
} }
return c.oauth2Config.AuthCodeURL(state), nil
if s.OfflineAccess {
opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
}
return c.oauth2Config.AuthCodeURL(state, opts...), nil
} }
type oauth2Error struct { type oauth2Error struct {
@ -199,11 +211,35 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
return identity, fmt.Errorf("oidc: failed to get token: %v", err) return identity, fmt.Errorf("oidc: failed to get token: %v", err)
} }
return c.createIdentity(r.Context(), identity, token)
}
// Refresh is used to refresh a session with the refresh token provided by the IdP
func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) {
cd := connectorData{}
err := json.Unmarshal(identity.ConnectorData, &cd)
if err != nil {
return identity, fmt.Errorf("oidc: failed to unmarshal connector data: %v", err)
}
t := &oauth2.Token{
RefreshToken: string(cd.RefreshToken),
Expiry: time.Now().Add(-time.Hour),
}
token, err := c.oauth2Config.TokenSource(ctx, t).Token()
if err != nil {
return identity, fmt.Errorf("oidc: failed to get refresh token: %v", err)
}
return c.createIdentity(ctx, identity, token)
}
func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token) (connector.Identity, error) {
rawIDToken, ok := token.Extra("id_token").(string) rawIDToken, ok := token.Extra("id_token").(string)
if !ok { if !ok {
return identity, errors.New("oidc: no id_token in token response") return identity, errors.New("oidc: no id_token in token response")
} }
idToken, err := c.verifier.Verify(r.Context(), rawIDToken) idToken, err := c.verifier.Verify(ctx, rawIDToken)
if err != nil { if err != nil {
return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err) return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err)
} }
@ -215,7 +251,7 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
// We immediately want to run getUserInfo if configured before we validate the claims // We immediately want to run getUserInfo if configured before we validate the claims
if c.getUserInfo { if c.getUserInfo {
userInfo, err := c.provider.UserInfo(r.Context(), oauth2.StaticTokenSource(token)) userInfo, err := c.provider.UserInfo(ctx, oauth2.StaticTokenSource(token))
if err != nil { if err != nil {
return identity, fmt.Errorf("oidc: error loading userinfo: %v", err) return identity, fmt.Errorf("oidc: error loading userinfo: %v", err)
} }
@ -260,11 +296,21 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
} }
} }
cd := connectorData{
RefreshToken: []byte(token.RefreshToken),
}
connData, err := json.Marshal(&cd)
if err != nil {
return identity, fmt.Errorf("oidc: failed to encode connector data: %v", err)
}
identity = connector.Identity{ identity = connector.Identity{
UserID: idToken.Subject, UserID: idToken.Subject,
Username: name, Username: name,
Email: email, Email: email,
EmailVerified: emailVerified, EmailVerified: emailVerified,
ConnectorData: connData,
} }
if c.userIDKey != "" { if c.userIDKey != "" {
@ -277,8 +323,3 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
return identity, nil return identity, nil
} }
// Refresh is implemented for backwards compatibility, even though it's a no-op.
func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) {
return identity, nil
}

View file

@ -505,7 +505,45 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.Auth
s.logger.Infof("login successful: connector %q, username=%q, preferred_username=%q, email=%q, groups=%q", s.logger.Infof("login successful: connector %q, username=%q, preferred_username=%q, email=%q, groups=%q",
authReq.ConnectorID, claims.Username, claims.PreferredUsername, email, claims.Groups) authReq.ConnectorID, claims.Username, claims.PreferredUsername, email, claims.Groups)
return path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID, nil returnURL := path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID
_, ok := conn.(connector.RefreshConnector)
if !ok {
return returnURL, nil
}
// Try to retrieve an existing OfflineSession object for the corresponding user.
if session, err := s.storage.GetOfflineSessions(identity.UserID, authReq.ConnectorID); err != nil {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err)
return "", err
}
offlineSessions := storage.OfflineSessions{
UserID: identity.UserID,
ConnID: authReq.ConnectorID,
Refresh: make(map[string]*storage.RefreshTokenRef),
ConnectorData: identity.ConnectorData,
}
// Create a new OfflineSession object for the user and add a reference object for
// the newly received refreshtoken.
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err)
return "", err
}
} else {
// Update existing OfflineSession obj with new RefreshTokenRef.
if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
if len(identity.ConnectorData) > 0 {
old.ConnectorData = identity.ConnectorData
}
return old, nil
}); err != nil {
s.logger.Errorf("failed to update offline session: %v", err)
return "", err
}
}
return returnURL, nil
} }
func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
@ -962,6 +1000,19 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
scopes = requestedScopes scopes = requestedScopes
} }
var connectorData []byte
if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err)
return
}
} else if len(refresh.ConnectorData) > 0 {
// Use the old connector data if it exists, should be deleted once used
connectorData = refresh.ConnectorData
} else {
connectorData = session.ConnectorData
}
conn, err := s.getConnector(refresh.ConnectorID) conn, err := s.getConnector(refresh.ConnectorID)
if err != nil { if err != nil {
s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err)
@ -975,7 +1026,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
Email: refresh.Claims.Email, Email: refresh.Claims.Email,
EmailVerified: refresh.Claims.EmailVerified, EmailVerified: refresh.Claims.EmailVerified,
Groups: refresh.Claims.Groups, Groups: refresh.Claims.Groups,
ConnectorData: refresh.ConnectorData, ConnectorData: connectorData,
} }
// Can the connector refresh the identity? If so, attempt to refresh the data // Can the connector refresh the identity? If so, attempt to refresh the data
@ -1041,8 +1092,10 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
old.Claims.Email = ident.Email old.Claims.Email = ident.Email
old.Claims.EmailVerified = ident.EmailVerified old.Claims.EmailVerified = ident.EmailVerified
old.Claims.Groups = ident.Groups old.Claims.Groups = ident.Groups
old.ConnectorData = ident.ConnectorData
old.LastUsed = lastUsed old.LastUsed = lastUsed
// ConnectorData has been moved to OfflineSession
old.ConnectorData = []byte{}
return old, nil return old, nil
} }
@ -1053,6 +1106,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return old, errors.New("refresh token invalid") return old, errors.New("refresh token invalid")
} }
old.Refresh[refresh.ClientID].LastUsed = lastUsed old.Refresh[refresh.ClientID].LastUsed = lastUsed
old.ConnectorData = ident.ConnectorData
return old, nil return old, nil
}); err != nil { }); err != nil {
s.logger.Errorf("failed to update offline session: %v", err) s.logger.Errorf("failed to update offline session: %v", err)

View file

@ -518,6 +518,7 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
UserID: userID1, UserID: userID1,
ConnID: "Conn1", ConnID: "Conn1",
Refresh: make(map[string]*storage.RefreshTokenRef), Refresh: make(map[string]*storage.RefreshTokenRef),
ConnectorData: []byte(`{"some":"data"}`),
} }
// Creating an OfflineSession with an empty Refresh list to ensure that // Creating an OfflineSession with an empty Refresh list to ensure that
@ -535,6 +536,7 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
UserID: userID2, UserID: userID2,
ConnID: "Conn2", ConnID: "Conn2",
Refresh: make(map[string]*storage.RefreshTokenRef), Refresh: make(map[string]*storage.RefreshTokenRef),
ConnectorData: []byte(`{"some":"data"}`),
} }
if err := s.CreateOfflineSessions(session2); err != nil { if err := s.CreateOfflineSessions(session2); err != nil {

View file

@ -191,6 +191,7 @@ type OfflineSessions struct {
UserID string `json:"user_id,omitempty"` UserID string `json:"user_id,omitempty"`
ConnID string `json:"conn_id,omitempty"` ConnID string `json:"conn_id,omitempty"`
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"` Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
ConnectorData []byte `json:"connectorData,omitempty"`
} }
func fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions { func fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
@ -198,6 +199,7 @@ func fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
UserID: o.UserID, UserID: o.UserID,
ConnID: o.ConnID, ConnID: o.ConnID,
Refresh: o.Refresh, Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
} }
} }
@ -206,6 +208,7 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
UserID: o.UserID, UserID: o.UserID,
ConnID: o.ConnID, ConnID: o.ConnID,
Refresh: o.Refresh, Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
} }
if s.Refresh == nil { if s.Refresh == nil {
// Server code assumes this will be non-nil. // Server code assumes this will be non-nil.

View file

@ -555,6 +555,7 @@ type OfflineSessions struct {
UserID string `json:"userID,omitempty"` UserID string `json:"userID,omitempty"`
ConnID string `json:"connID,omitempty"` ConnID string `json:"connID,omitempty"`
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"` Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
ConnectorData []byte `json:"connectorData,omitempty"`
} }
func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions { func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
@ -570,6 +571,7 @@ func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) Offline
UserID: o.UserID, UserID: o.UserID,
ConnID: o.ConnID, ConnID: o.ConnID,
Refresh: o.Refresh, Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
} }
} }
@ -578,6 +580,7 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
UserID: o.UserID, UserID: o.UserID,
ConnID: o.ConnID, ConnID: o.ConnID,
Refresh: o.Refresh, Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
} }
if s.Refresh == nil { if s.Refresh == nil {
// Server code assumes this will be non-nil. // Server code assumes this will be non-nil.

View file

@ -655,13 +655,13 @@ func scanPassword(s scanner) (p storage.Password, err error) {
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error { func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
_, err := c.Exec(` _, err := c.Exec(`
insert into offline_session ( insert into offline_session (
user_id, conn_id, refresh user_id, conn_id, refresh, connector_data
) )
values ( values (
$1, $2, $3 $1, $2, $3, $4
); );
`, `,
s.UserID, s.ConnID, encoder(s.Refresh), s.UserID, s.ConnID, encoder(s.Refresh), s.ConnectorData,
) )
if err != nil { if err != nil {
if c.alreadyExistsCheck(err) { if c.alreadyExistsCheck(err) {
@ -686,10 +686,11 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
_, err = tx.Exec(` _, err = tx.Exec(`
update offline_session update offline_session
set set
refresh = $1 refresh = $1,
where user_id = $2 AND conn_id = $3; connector_data = $2
where user_id = $3 AND conn_id = $4;
`, `,
encoder(newSession.Refresh), s.UserID, s.ConnID, encoder(newSession.Refresh), newSession.ConnectorData, s.UserID, s.ConnID,
) )
if err != nil { if err != nil {
return fmt.Errorf("update offline session: %v", err) return fmt.Errorf("update offline session: %v", err)
@ -705,7 +706,7 @@ func (c *conn) GetOfflineSessions(userID string, connID string) (storage.Offline
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) { func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
return scanOfflineSessions(q.QueryRow(` return scanOfflineSessions(q.QueryRow(`
select select
user_id, conn_id, refresh user_id, conn_id, refresh, connector_data
from offline_session from offline_session
where user_id = $1 AND conn_id = $2; where user_id = $1 AND conn_id = $2;
`, userID, connID)) `, userID, connID))
@ -713,7 +714,7 @@ func getOfflineSessions(q querier, userID string, connID string) (storage.Offlin
func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) { func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
err = s.Scan( err = s.Scan(
&o.UserID, &o.ConnID, decoder(&o.Refresh), &o.UserID, &o.ConnID, decoder(&o.Refresh), &o.ConnectorData,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {

View file

@ -202,4 +202,11 @@ var migrations = []migration{
add column claims_preferred_username text not null default '';`, add column claims_preferred_username text not null default '';`,
}, },
}, },
{
stmts: []string{`
alter table offline_session
add column connector_data bytea;
`,
},
},
} }

View file

@ -273,6 +273,9 @@ type OfflineSessions struct {
// Refresh is a hash table of refresh token reference objects // Refresh is a hash table of refresh token reference objects
// indexed by the ClientID of the refresh token. // indexed by the ClientID of the refresh token.
Refresh map[string]*RefreshTokenRef Refresh map[string]*RefreshTokenRef
// Authentication data provided by an upstream source.
ConnectorData []byte
} }
// Password is an email to password mapping managed by the storage. // Password is an email to password mapping managed by the storage.