*: add "groups" scope

This commit is contained in:
Eric Chiang 2016-07-15 16:00:59 -07:00
parent 731dadb29d
commit b02a3a3163
16 changed files with 168 additions and 42 deletions

View file

@ -41,6 +41,7 @@ CREATE TABLE refresh_token (
payload_hash blob, payload_hash blob,
user_id text, user_id text,
client_id text, client_id text,
connector_id text,
scopes text scopes text
); );
@ -63,7 +64,8 @@ CREATE TABLE session (
user_id text, user_id text,
register integer, register integer,
nonce text, nonce text,
scope text scope text,
groups text
); );
CREATE TABLE session_key ( CREATE TABLE session_key (

View file

@ -0,0 +1,3 @@
-- +migrate Up
ALTER TABLE refresh_token ADD COLUMN "connector_id" text;
ALTER TABLE session ADD COLUMN "groups" text;

View file

@ -90,5 +90,11 @@ var PostgresMigrations migrate.MigrationSource = &migrate.MemoryMigrationSource{
"-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"scopes\" text;\n\nUPDATE refresh_token SET scopes = 'openid profile email offline_access';\n", "-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"scopes\" text;\n\nUPDATE refresh_token SET scopes = 'openid profile email offline_access';\n",
}, },
}, },
{
Id: "0014_add_groups.sql",
Up: []string{
"-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"connector_id\" text;\nALTER TABLE session ADD COLUMN \"groups\" text;\n",
},
},
}, },
} }

View file

@ -41,6 +41,7 @@ type refreshTokenModel struct {
PayloadHash []byte `db:"payload_hash"` PayloadHash []byte `db:"payload_hash"`
UserID string `db:"user_id"` UserID string `db:"user_id"`
ClientID string `db:"client_id"` ClientID string `db:"client_id"`
ConnectorID string `db:"connector_id"`
Scopes string `db:"scopes"` Scopes string `db:"scopes"`
} }
@ -89,7 +90,7 @@ func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenG
} }
} }
func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (string, error) { func (r *refreshTokenRepo) Create(userID, clientID, connectorID string, scopes []string) (string, error) {
if userID == "" { if userID == "" {
return "", refresh.ErrorInvalidUserID return "", refresh.ErrorInvalidUserID
} }
@ -112,6 +113,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (str
PayloadHash: payloadHash, PayloadHash: payloadHash,
UserID: userID, UserID: userID,
ClientID: clientID, ClientID: clientID,
ConnectorID: connectorID,
Scopes: strings.Join(scopes, " "), Scopes: strings.Join(scopes, " "),
} }
@ -122,24 +124,24 @@ func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (str
return buildToken(record.ID, tokenPayload), nil return buildToken(record.ID, tokenPayload), nil
} }
func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes, error) { func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) {
tokenID, tokenPayload, err := parseToken(token) tokenID, tokenPayload, err := parseToken(token)
if err != nil { if err != nil {
return "", nil, err return
} }
record, err := r.get(nil, tokenID) record, err := r.get(nil, tokenID)
if err != nil { if err != nil {
return "", nil, err return
} }
if record.ClientID != clientID { if record.ClientID != clientID {
return "", nil, refresh.ErrorInvalidClientID return "", "", nil, refresh.ErrorInvalidClientID
} }
if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil { if err = checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
return "", nil, err return
} }
var scopes []string var scopes []string
@ -147,7 +149,7 @@ func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes,
scopes = strings.Split(record.Scopes, " ") scopes = strings.Split(record.Scopes, " ")
} }
return record.UserID, scopes, nil return record.UserID, record.ConnectorID, scopes, nil
} }
func (r *refreshTokenRepo) Revoke(userID, token string) error { func (r *refreshTokenRepo) Revoke(userID, token string) error {

View file

@ -44,6 +44,7 @@ type sessionModel struct {
Register bool `db:"register"` Register bool `db:"register"`
Nonce string `db:"nonce"` Nonce string `db:"nonce"`
Scope string `db:"scope"` Scope string `db:"scope"`
Groups string `db:"groups"`
} }
func (s *sessionModel) session() (*session.Session, error) { func (s *sessionModel) session() (*session.Session, error) {
@ -75,6 +76,11 @@ func (s *sessionModel) session() (*session.Session, error) {
Nonce: s.Nonce, Nonce: s.Nonce,
Scope: strings.Fields(s.Scope), Scope: strings.Fields(s.Scope),
} }
if s.Groups != "" {
if err := json.Unmarshal([]byte(s.Groups), &ses.Groups); err != nil {
return nil, fmt.Errorf("failed to decode groups in session: %v", err)
}
}
if s.CreatedAt != 0 { if s.CreatedAt != 0 {
ses.CreatedAt = time.Unix(s.CreatedAt, 0).UTC() ses.CreatedAt = time.Unix(s.CreatedAt, 0).UTC()
@ -107,6 +113,14 @@ func newSessionModel(s *session.Session) (*sessionModel, error) {
Scope: strings.Join(s.Scope, " "), Scope: strings.Join(s.Scope, " "),
} }
if s.Groups != nil {
data, err := json.Marshal(s.Groups)
if err != nil {
return nil, fmt.Errorf("failed to marshal groups: %v", err)
}
sm.Groups = string(data)
}
if !s.CreatedAt.IsZero() { if !s.CreatedAt.IsZero() {
sm.CreatedAt = s.CreatedAt.Unix() sm.CreatedAt = s.CreatedAt.Unix()
} }

View file

@ -20,7 +20,10 @@ import (
var ( var (
testRefreshClientID = "client1" testRefreshClientID = "client1"
testRefreshClientID2 = "client2" testRefreshClientID2 = "client2"
testRefreshClients = []client.LoadableClient{
testRefreshConnectorID = "IDPC-1"
testRefreshClients = []client.LoadableClient{
{ {
Client: client.Client{ Client: client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
@ -59,7 +62,7 @@ var (
}, },
RemoteIdentities: []user.RemoteIdentity{ RemoteIdentities: []user.RemoteIdentity{
{ {
ConnectorID: "IDPC-1", ConnectorID: testRefreshConnectorID,
ID: "RID-1", ID: "RID-1",
}, },
}, },
@ -103,12 +106,12 @@ func TestRefreshTokenRepoCreateVerify(t *testing.T) {
for i, tt := range tests { for i, tt := range tests {
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients) repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
tok, err := repo.Create(testRefreshUserID, testRefreshClientID, tt.createScopes) tok, err := repo.Create(testRefreshUserID, testRefreshClientID, testRefreshConnectorID, tt.createScopes)
if err != nil { if err != nil {
t.Fatalf("case %d: failed to create refresh token: %v", i, err) t.Fatalf("case %d: failed to create refresh token: %v", i, err)
} }
tokUserID, gotScopes, err := repo.Verify(tt.verifyClientID, tok) tokUserID, gotConnectorID, gotScopes, err := repo.Verify(tt.verifyClientID, tok)
if tt.wantVerifyErr { if tt.wantVerifyErr {
if err == nil { if err == nil {
t.Errorf("case %d: want non-nil error.", i) t.Errorf("case %d: want non-nil error.", i)
@ -126,6 +129,10 @@ func TestRefreshTokenRepoCreateVerify(t *testing.T) {
t.Errorf("case %d: Verified token returned wrong user id, want=%s, got=%s", i, t.Errorf("case %d: Verified token returned wrong user id, want=%s, got=%s", i,
testRefreshUserID, tokUserID) testRefreshUserID, tokUserID)
} }
if gotConnectorID != testRefreshConnectorID {
t.Errorf("case %d: wanted connector_id=%q got=%q", i, testRefreshConnectorID, gotConnectorID)
}
} }
} }
@ -138,7 +145,7 @@ func buildRefreshToken(tokenID int64, tokenPayload []byte) string {
func TestRefreshRepoVerifyInvalidTokens(t *testing.T) { func TestRefreshRepoVerifyInvalidTokens(t *testing.T) {
r := db.NewRefreshTokenRepo(connect(t)) r := db.NewRefreshTokenRepo(connect(t))
token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope) token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
@ -209,7 +216,7 @@ func TestRefreshRepoVerifyInvalidTokens(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
result, _, err := r.Verify(tt.creds.ID, tt.token) result, _, _, err := r.Verify(tt.creds.ID, tt.token)
if err != tt.err { if err != tt.err {
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err) t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
} }
@ -232,7 +239,7 @@ func TestRefreshTokenRepoClientsWithRefreshTokens(t *testing.T) {
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients) repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
for _, clientID := range tt.clientIDs { for _, clientID := range tt.clientIDs {
_, err := repo.Create(testRefreshUserID, clientID, []string{"openid"}) _, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"})
if err != nil { if err != nil {
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err) t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
} }
@ -281,7 +288,7 @@ func TestRefreshTokenRepoRevokeForClient(t *testing.T) {
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients) repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
for _, clientID := range tt.createIDs { for _, clientID := range tt.createIDs {
_, err := repo.Create(testRefreshUserID, clientID, []string{"openid"}) _, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"})
if err != nil { if err != nil {
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err) t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
} }
@ -318,7 +325,7 @@ func TestRefreshTokenRepoRevokeForClient(t *testing.T) {
func TestRefreshRepoRevoke(t *testing.T) { func TestRefreshRepoRevoke(t *testing.T) {
r := db.NewRefreshTokenRepo(connect(t)) r := db.NewRefreshTokenRepo(connect(t))
token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope) token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }

View file

@ -104,6 +104,13 @@ func TestSessionRepoCreateGet(t *testing.T) {
ExpiresAt: time.Unix(789, 0).UTC(), ExpiresAt: time.Unix(789, 0).UTC(),
Nonce: "oncenay", Nonce: "oncenay",
}, },
session.Session{
ID: "anID",
ClientState: "blargh",
ExpiresAt: time.Unix(789, 0).UTC(),
Nonce: "oncenay",
Groups: []string{"group1", "group2"},
},
} }
for i, tt := range tests { for i, tt := range tests {

View file

@ -149,7 +149,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
refreshRepo := db.NewRefreshTokenRepo(dbMap) refreshRepo := db.NewRefreshTokenRepo(dbMap)
for _, user := range userUsers { for _, user := range userUsers {
if _, err := refreshRepo.Create(user.User.ID, testClientID, if _, err := refreshRepo.Create(user.User.ID, testClientID,
append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil { "", append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil {
panic("Failed to create refresh token: " + err.Error()) panic("Failed to create refresh token: " + err.Error())
} }
} }

View file

@ -44,12 +44,12 @@ type RefreshTokenRepo interface {
// The scopes will be stored with the refresh token, and used to verify // The scopes will be stored with the refresh token, and used to verify
// against future OIDC refresh requests' scopes. // against future OIDC refresh requests' scopes.
// On success the token will be returned. // On success the token will be returned.
Create(userID, clientID string, scope []string) (string, error) Create(userID, clientID, connectorID string, scope []string) (string, error)
// Verify verifies that a token belongs to the client. // Verify verifies that a token belongs to the client.
// It returns the user ID to which the token belongs, and the scopes stored // It returns the user ID to which the token belongs, and the scopes stored
// with token. // with token.
Verify(clientID, token string) (string, scope.Scopes, error) Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error)
// Revoke deletes the refresh token if the token belongs to the given userID. // Revoke deletes the refresh token if the token belongs to the given userID.
Revoke(userID, token string) error Revoke(userID, token string) error

View file

@ -6,6 +6,9 @@ const (
// Scope prefix which indicates initiation of a cross-client authentication flow. // Scope prefix which indicates initiation of a cross-client authentication flow.
// See https://developers.google.com/identity/protocols/CrossClientAuth // See https://developers.google.com/identity/protocols/CrossClientAuth
ScopeGoogleCrossClient = "audience:server:client_id:" ScopeGoogleCrossClient = "audience:server:client_id:"
// ScopeGroups indicates that groups should be added to the ID Token.
ScopeGroups = "groups"
) )
type Scopes []string type Scopes []string

View file

@ -421,6 +421,7 @@ func validateScopes(srv OIDCServer, clientID string, scopes []string) error {
foundOpenIDScope = true foundOpenIDScope = true
case curScope == "profile": case curScope == "profile":
case curScope == "email": case curScope == "email":
case curScope == scope.ScopeGroups:
case curScope == "offline_access": case curScope == "offline_access":
// According to the spec, for offline_access scope, the client must // According to the spec, for offline_access scope, the client must
// use a response_type value that would result in an Authorization // use a response_type value that would result in an Authorization

View file

@ -75,7 +75,8 @@ type Server struct {
OOBTemplate *template.Template OOBTemplate *template.Template
HealthChecks []health.Checkable HealthChecks []health.Checkable
Connectors []connector.Connector // TODO(ericchiang): Make this a map of ID to connector.
Connectors []connector.Connector
ClientRepo client.ClientRepo ClientRepo client.ClientRepo
ConnectorConfigRepo connector.ConnectorConfigRepo ConnectorConfigRepo connector.ConnectorConfigRepo
@ -306,6 +307,15 @@ func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL ur
return s.SessionManager.NewSessionKey(sessionID) return s.SessionManager.NewSessionKey(sessionID)
} }
func (s *Server) connector(id string) (connector.Connector, bool) {
for _, c := range s.Connectors {
if c.ID() == id {
return c, true
}
}
return nil, false
}
func (s *Server) Login(ident oidc.Identity, key string) (string, error) { func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
sessionID, err := s.SessionManager.ExchangeKey(key) sessionID, err := s.SessionManager.ExchangeKey(key)
if err != nil { if err != nil {
@ -318,6 +328,29 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
} }
log.Infof("Session %s remote identity attached: clientID=%s identity=%#v", sessionID, ses.ClientID, ident) log.Infof("Session %s remote identity attached: clientID=%s identity=%#v", sessionID, ses.ClientID, ident)
// Get the connector used to log the user in.
conn, ok := s.connector(ses.ConnectorID)
if !ok {
return "", fmt.Errorf("session contained invalid connector ID (%s)", ses.ConnectorID)
}
// If the client has requested access to groups, add them here.
if ses.Scope.HasScope(scope.ScopeGroups) {
grouper, ok := conn.(connector.GroupsConnector)
if !ok {
return "", fmt.Errorf("scope %q provided but connector does not support groups", scope.ScopeGroups)
}
groups, err := grouper.Groups(ident.ID)
if err != nil {
return "", fmt.Errorf("failed to retrieve user groups for %q %v", ident.ID, err)
}
// Update the session.
if ses, err = s.SessionManager.AttachGroups(sessionID, groups); err != nil {
return "", fmt.Errorf("failed save groups")
}
}
if ses.Register { if ses.Register {
code, err := s.SessionManager.NewSessionKey(sessionID) code, err := s.SessionManager.NewSessionKey(sessionID)
if err != nil { if err != nil {
@ -334,18 +367,6 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
remoteIdentity := user.RemoteIdentity{ConnectorID: ses.ConnectorID, ID: ses.Identity.ID} remoteIdentity := user.RemoteIdentity{ConnectorID: ses.ConnectorID, ID: ses.Identity.ID}
// Get the connector used to log the user in.
var conn connector.Connector
for _, c := range s.Connectors {
if c.ID() == ses.ConnectorID {
conn = c
break
}
}
if conn == nil {
return "", fmt.Errorf("session contained invalid connector ID (%s)", ses.ConnectorID)
}
usr, err := s.UserRepo.GetByRemoteIdentity(nil, remoteIdentity) usr, err := s.UserRepo.GetByRemoteIdentity(nil, remoteIdentity)
if err == user.ErrorNotFound { if err == user.ErrorNotFound {
if ses.Identity.Email == "" { if ses.Identity.Email == "" {
@ -508,7 +529,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
if scope == "offline_access" { if scope == "offline_access" {
log.Infof("Session %s requests offline access, will generate refresh token", sessionID) log.Infof("Session %s requests offline access, will generate refresh token", sessionID)
refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.Scope) refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.ConnectorID, ses.Scope)
switch err { switch err {
case nil: case nil:
break break
@ -535,7 +556,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
return nil, oauth2.NewError(oauth2.ErrorInvalidClient) return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
} }
userID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token) userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
switch err { switch err {
case nil: case nil:
break break
@ -555,7 +576,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
} }
} }
user, err := s.UserRepo.Get(nil, userID) usr, err := s.UserRepo.Get(nil, userID)
if err != nil { if err != nil {
// The error can be user.ErrorNotFound, but we are not deleting // The error can be user.ErrorNotFound, but we are not deleting
// user at this moment, so this shouldn't happen. // user at this moment, so this shouldn't happen.
@ -563,6 +584,43 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
return nil, oauth2.NewError(oauth2.ErrorServerError) return nil, oauth2.NewError(oauth2.ErrorServerError)
} }
var groups []string
if rtScopes.HasScope(scope.ScopeGroups) {
conn, ok := s.connector(connectorID)
if !ok {
log.Errorf("refresh token contained invalid connector ID (%s)", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
grouper, ok := conn.(connector.GroupsConnector)
if !ok {
log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID)
if err != nil {
log.Errorf("failed to get remote identities: %v", err)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
remoteIdentity, ok := func() (user.RemoteIdentity, bool) {
for _, ri := range remoteIdentities {
if ri.ConnectorID == connectorID {
return ri, true
}
}
return user.RemoteIdentity{}, false
}()
if !ok {
log.Errorf("failed to get remote identity for connector %s", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
if groups, err = grouper.Groups(remoteIdentity.ID); err != nil {
log.Errorf("failed to get groups for refresh token: %v", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
}
signer, err := s.KeyManager.Signer() signer, err := s.KeyManager.Signer()
if err != nil { if err != nil {
log.Errorf("Failed to refresh ID token: %v", err) log.Errorf("Failed to refresh ID token: %v", err)
@ -572,8 +630,14 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
now := time.Now() now := time.Now()
expireAt := now.Add(session.DefaultSessionValidityWindow) expireAt := now.Add(session.DefaultSessionValidityWindow)
claims := oidc.NewClaims(s.IssuerURL.String(), user.ID, creds.ID, now, expireAt) claims := oidc.NewClaims(s.IssuerURL.String(), usr.ID, creds.ID, now, expireAt)
user.AddToClaims(claims) usr.AddToClaims(claims)
if rtScopes.HasScope(scope.ScopeGroups) {
if groups == nil {
groups = []string{}
}
claims["groups"] = groups
}
s.addClaimsFromScope(claims, scope.Scopes(scopes), creds.ID) s.addClaimsFromScope(claims, scope.Scopes(scopes), creds.ID)

View file

@ -785,8 +785,7 @@ func TestServerRefreshToken(t *testing.T) {
t.Errorf("case %d: error creating other client: %v", i, err) t.Errorf("case %d: error creating other client: %v", i, err)
} }
if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID, if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID, "", tt.createScopes); err != nil {
tt.createScopes); err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }

View file

@ -144,6 +144,18 @@ func (m *SessionManager) AttachUser(sessionID string, userID string) (*session.S
return s, nil return s, nil
} }
func (m *SessionManager) AttachGroups(sessionID string, groups []string) (*session.Session, error) {
s, err := m.sessions.Get(sessionID)
if err != nil {
return nil, err
}
s.Groups = groups
if err = m.sessions.Update(*s); err != nil {
return nil, err
}
return s, nil
}
func (m *SessionManager) Kill(sessionID string) (*session.Session, error) { func (m *SessionManager) Kill(sessionID string) (*session.Session, error) {
s, err := m.sessions.Get(sessionID) s, err := m.sessions.Get(sessionID)
if err != nil { if err != nil {

View file

@ -55,6 +55,9 @@ type Session struct {
// Scope is the 'scope' field in the authentication request. Example scopes // Scope is the 'scope' field in the authentication request. Example scopes
// are 'openid', 'email', 'offline', etc. // are 'openid', 'email', 'offline', etc.
Scope scope.Scopes Scope scope.Scopes
// Groups the user belongs to.
Groups []string
} }
// Claims returns a new set of Claims for the current session. // Claims returns a new set of Claims for the current session.
@ -65,5 +68,8 @@ func (s *Session) Claims(issuerURL string) jose.Claims {
if s.Nonce != "" { if s.Nonce != "" {
claims["nonce"] = s.Nonce claims["nonce"] = s.Nonce
} }
if s.Scope.HasScope(scope.ScopeGroups) {
claims["groups"] = s.Groups
}
return claims return claims
} }

View file

@ -192,7 +192,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
} }
refreshRepo := db.NewRefreshTokenRepo(dbMap) refreshRepo := db.NewRefreshTokenRepo(dbMap)
for _, token := range refreshTokens { for _, token := range refreshTokens {
if _, err := refreshRepo.Create(token.userID, token.clientID, []string{"openid"}); err != nil { if _, err := refreshRepo.Create(token.userID, token.clientID, "local", []string{"openid"}); err != nil {
panic("Failed to create refresh token: " + err.Error()) panic("Failed to create refresh token: " + err.Error())
} }
} }