*: add "groups" scope
This commit is contained in:
parent
731dadb29d
commit
b02a3a3163
16 changed files with 168 additions and 42 deletions
|
@ -41,6 +41,7 @@ CREATE TABLE refresh_token (
|
||||||
payload_hash blob,
|
payload_hash blob,
|
||||||
user_id text,
|
user_id text,
|
||||||
client_id text,
|
client_id text,
|
||||||
|
connector_id text,
|
||||||
scopes text
|
scopes text
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -63,7 +64,8 @@ CREATE TABLE session (
|
||||||
user_id text,
|
user_id text,
|
||||||
register integer,
|
register integer,
|
||||||
nonce text,
|
nonce text,
|
||||||
scope text
|
scope text,
|
||||||
|
groups text
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE session_key (
|
CREATE TABLE session_key (
|
||||||
|
|
3
db/migrations/0014_add_groups.sql
Normal file
3
db/migrations/0014_add_groups.sql
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
-- +migrate Up
|
||||||
|
ALTER TABLE refresh_token ADD COLUMN "connector_id" text;
|
||||||
|
ALTER TABLE session ADD COLUMN "groups" text;
|
|
@ -90,5 +90,11 @@ var PostgresMigrations migrate.MigrationSource = &migrate.MemoryMigrationSource{
|
||||||
"-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"scopes\" text;\n\nUPDATE refresh_token SET scopes = 'openid profile email offline_access';\n",
|
"-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"scopes\" text;\n\nUPDATE refresh_token SET scopes = 'openid profile email offline_access';\n",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Id: "0014_add_groups.sql",
|
||||||
|
Up: []string{
|
||||||
|
"-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"connector_id\" text;\nALTER TABLE session ADD COLUMN \"groups\" text;\n",
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,6 +41,7 @@ type refreshTokenModel struct {
|
||||||
PayloadHash []byte `db:"payload_hash"`
|
PayloadHash []byte `db:"payload_hash"`
|
||||||
UserID string `db:"user_id"`
|
UserID string `db:"user_id"`
|
||||||
ClientID string `db:"client_id"`
|
ClientID string `db:"client_id"`
|
||||||
|
ConnectorID string `db:"connector_id"`
|
||||||
Scopes string `db:"scopes"`
|
Scopes string `db:"scopes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,7 +90,7 @@ func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenG
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (string, error) {
|
func (r *refreshTokenRepo) Create(userID, clientID, connectorID string, scopes []string) (string, error) {
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
return "", refresh.ErrorInvalidUserID
|
return "", refresh.ErrorInvalidUserID
|
||||||
}
|
}
|
||||||
|
@ -112,6 +113,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (str
|
||||||
PayloadHash: payloadHash,
|
PayloadHash: payloadHash,
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
|
ConnectorID: connectorID,
|
||||||
Scopes: strings.Join(scopes, " "),
|
Scopes: strings.Join(scopes, " "),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -122,24 +124,24 @@ func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (str
|
||||||
return buildToken(record.ID, tokenPayload), nil
|
return buildToken(record.ID, tokenPayload), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes, error) {
|
func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) {
|
||||||
tokenID, tokenPayload, err := parseToken(token)
|
tokenID, tokenPayload, err := parseToken(token)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
record, err := r.get(nil, tokenID)
|
record, err := r.get(nil, tokenID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if record.ClientID != clientID {
|
if record.ClientID != clientID {
|
||||||
return "", nil, refresh.ErrorInvalidClientID
|
return "", "", nil, refresh.ErrorInvalidClientID
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
|
if err = checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
|
||||||
return "", nil, err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var scopes []string
|
var scopes []string
|
||||||
|
@ -147,7 +149,7 @@ func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes,
|
||||||
scopes = strings.Split(record.Scopes, " ")
|
scopes = strings.Split(record.Scopes, " ")
|
||||||
}
|
}
|
||||||
|
|
||||||
return record.UserID, scopes, nil
|
return record.UserID, record.ConnectorID, scopes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *refreshTokenRepo) Revoke(userID, token string) error {
|
func (r *refreshTokenRepo) Revoke(userID, token string) error {
|
||||||
|
|
|
@ -44,6 +44,7 @@ type sessionModel struct {
|
||||||
Register bool `db:"register"`
|
Register bool `db:"register"`
|
||||||
Nonce string `db:"nonce"`
|
Nonce string `db:"nonce"`
|
||||||
Scope string `db:"scope"`
|
Scope string `db:"scope"`
|
||||||
|
Groups string `db:"groups"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sessionModel) session() (*session.Session, error) {
|
func (s *sessionModel) session() (*session.Session, error) {
|
||||||
|
@ -75,6 +76,11 @@ func (s *sessionModel) session() (*session.Session, error) {
|
||||||
Nonce: s.Nonce,
|
Nonce: s.Nonce,
|
||||||
Scope: strings.Fields(s.Scope),
|
Scope: strings.Fields(s.Scope),
|
||||||
}
|
}
|
||||||
|
if s.Groups != "" {
|
||||||
|
if err := json.Unmarshal([]byte(s.Groups), &ses.Groups); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode groups in session: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if s.CreatedAt != 0 {
|
if s.CreatedAt != 0 {
|
||||||
ses.CreatedAt = time.Unix(s.CreatedAt, 0).UTC()
|
ses.CreatedAt = time.Unix(s.CreatedAt, 0).UTC()
|
||||||
|
@ -107,6 +113,14 @@ func newSessionModel(s *session.Session) (*sessionModel, error) {
|
||||||
Scope: strings.Join(s.Scope, " "),
|
Scope: strings.Join(s.Scope, " "),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.Groups != nil {
|
||||||
|
data, err := json.Marshal(s.Groups)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal groups: %v", err)
|
||||||
|
}
|
||||||
|
sm.Groups = string(data)
|
||||||
|
}
|
||||||
|
|
||||||
if !s.CreatedAt.IsZero() {
|
if !s.CreatedAt.IsZero() {
|
||||||
sm.CreatedAt = s.CreatedAt.Unix()
|
sm.CreatedAt = s.CreatedAt.Unix()
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,9 @@ import (
|
||||||
var (
|
var (
|
||||||
testRefreshClientID = "client1"
|
testRefreshClientID = "client1"
|
||||||
testRefreshClientID2 = "client2"
|
testRefreshClientID2 = "client2"
|
||||||
|
|
||||||
|
testRefreshConnectorID = "IDPC-1"
|
||||||
|
|
||||||
testRefreshClients = []client.LoadableClient{
|
testRefreshClients = []client.LoadableClient{
|
||||||
{
|
{
|
||||||
Client: client.Client{
|
Client: client.Client{
|
||||||
|
@ -59,7 +62,7 @@ var (
|
||||||
},
|
},
|
||||||
RemoteIdentities: []user.RemoteIdentity{
|
RemoteIdentities: []user.RemoteIdentity{
|
||||||
{
|
{
|
||||||
ConnectorID: "IDPC-1",
|
ConnectorID: testRefreshConnectorID,
|
||||||
ID: "RID-1",
|
ID: "RID-1",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -103,12 +106,12 @@ func TestRefreshTokenRepoCreateVerify(t *testing.T) {
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
|
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
|
||||||
tok, err := repo.Create(testRefreshUserID, testRefreshClientID, tt.createScopes)
|
tok, err := repo.Create(testRefreshUserID, testRefreshClientID, testRefreshConnectorID, tt.createScopes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("case %d: failed to create refresh token: %v", i, err)
|
t.Fatalf("case %d: failed to create refresh token: %v", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tokUserID, gotScopes, err := repo.Verify(tt.verifyClientID, tok)
|
tokUserID, gotConnectorID, gotScopes, err := repo.Verify(tt.verifyClientID, tok)
|
||||||
if tt.wantVerifyErr {
|
if tt.wantVerifyErr {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("case %d: want non-nil error.", i)
|
t.Errorf("case %d: want non-nil error.", i)
|
||||||
|
@ -126,6 +129,10 @@ func TestRefreshTokenRepoCreateVerify(t *testing.T) {
|
||||||
t.Errorf("case %d: Verified token returned wrong user id, want=%s, got=%s", i,
|
t.Errorf("case %d: Verified token returned wrong user id, want=%s, got=%s", i,
|
||||||
testRefreshUserID, tokUserID)
|
testRefreshUserID, tokUserID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if gotConnectorID != testRefreshConnectorID {
|
||||||
|
t.Errorf("case %d: wanted connector_id=%q got=%q", i, testRefreshConnectorID, gotConnectorID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -138,7 +145,7 @@ func buildRefreshToken(tokenID int64, tokenPayload []byte) string {
|
||||||
func TestRefreshRepoVerifyInvalidTokens(t *testing.T) {
|
func TestRefreshRepoVerifyInvalidTokens(t *testing.T) {
|
||||||
r := db.NewRefreshTokenRepo(connect(t))
|
r := db.NewRefreshTokenRepo(connect(t))
|
||||||
|
|
||||||
token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope)
|
token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -209,7 +216,7 @@ func TestRefreshRepoVerifyInvalidTokens(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
result, _, err := r.Verify(tt.creds.ID, tt.token)
|
result, _, _, err := r.Verify(tt.creds.ID, tt.token)
|
||||||
if err != tt.err {
|
if err != tt.err {
|
||||||
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
|
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
|
||||||
}
|
}
|
||||||
|
@ -232,7 +239,7 @@ func TestRefreshTokenRepoClientsWithRefreshTokens(t *testing.T) {
|
||||||
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
|
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
|
||||||
|
|
||||||
for _, clientID := range tt.clientIDs {
|
for _, clientID := range tt.clientIDs {
|
||||||
_, err := repo.Create(testRefreshUserID, clientID, []string{"openid"})
|
_, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
|
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
|
||||||
}
|
}
|
||||||
|
@ -281,7 +288,7 @@ func TestRefreshTokenRepoRevokeForClient(t *testing.T) {
|
||||||
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
|
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
|
||||||
|
|
||||||
for _, clientID := range tt.createIDs {
|
for _, clientID := range tt.createIDs {
|
||||||
_, err := repo.Create(testRefreshUserID, clientID, []string{"openid"})
|
_, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
|
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
|
||||||
}
|
}
|
||||||
|
@ -318,7 +325,7 @@ func TestRefreshTokenRepoRevokeForClient(t *testing.T) {
|
||||||
func TestRefreshRepoRevoke(t *testing.T) {
|
func TestRefreshRepoRevoke(t *testing.T) {
|
||||||
r := db.NewRefreshTokenRepo(connect(t))
|
r := db.NewRefreshTokenRepo(connect(t))
|
||||||
|
|
||||||
token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope)
|
token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -104,6 +104,13 @@ func TestSessionRepoCreateGet(t *testing.T) {
|
||||||
ExpiresAt: time.Unix(789, 0).UTC(),
|
ExpiresAt: time.Unix(789, 0).UTC(),
|
||||||
Nonce: "oncenay",
|
Nonce: "oncenay",
|
||||||
},
|
},
|
||||||
|
session.Session{
|
||||||
|
ID: "anID",
|
||||||
|
ClientState: "blargh",
|
||||||
|
ExpiresAt: time.Unix(789, 0).UTC(),
|
||||||
|
Nonce: "oncenay",
|
||||||
|
Groups: []string{"group1", "group2"},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
|
|
|
@ -149,7 +149,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
|
||||||
refreshRepo := db.NewRefreshTokenRepo(dbMap)
|
refreshRepo := db.NewRefreshTokenRepo(dbMap)
|
||||||
for _, user := range userUsers {
|
for _, user := range userUsers {
|
||||||
if _, err := refreshRepo.Create(user.User.ID, testClientID,
|
if _, err := refreshRepo.Create(user.User.ID, testClientID,
|
||||||
append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil {
|
"", append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil {
|
||||||
panic("Failed to create refresh token: " + err.Error())
|
panic("Failed to create refresh token: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,12 +44,12 @@ type RefreshTokenRepo interface {
|
||||||
// The scopes will be stored with the refresh token, and used to verify
|
// The scopes will be stored with the refresh token, and used to verify
|
||||||
// against future OIDC refresh requests' scopes.
|
// against future OIDC refresh requests' scopes.
|
||||||
// On success the token will be returned.
|
// On success the token will be returned.
|
||||||
Create(userID, clientID string, scope []string) (string, error)
|
Create(userID, clientID, connectorID string, scope []string) (string, error)
|
||||||
|
|
||||||
// Verify verifies that a token belongs to the client.
|
// Verify verifies that a token belongs to the client.
|
||||||
// It returns the user ID to which the token belongs, and the scopes stored
|
// It returns the user ID to which the token belongs, and the scopes stored
|
||||||
// with token.
|
// with token.
|
||||||
Verify(clientID, token string) (string, scope.Scopes, error)
|
Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error)
|
||||||
|
|
||||||
// Revoke deletes the refresh token if the token belongs to the given userID.
|
// Revoke deletes the refresh token if the token belongs to the given userID.
|
||||||
Revoke(userID, token string) error
|
Revoke(userID, token string) error
|
||||||
|
|
|
@ -6,6 +6,9 @@ const (
|
||||||
// Scope prefix which indicates initiation of a cross-client authentication flow.
|
// Scope prefix which indicates initiation of a cross-client authentication flow.
|
||||||
// See https://developers.google.com/identity/protocols/CrossClientAuth
|
// See https://developers.google.com/identity/protocols/CrossClientAuth
|
||||||
ScopeGoogleCrossClient = "audience:server:client_id:"
|
ScopeGoogleCrossClient = "audience:server:client_id:"
|
||||||
|
|
||||||
|
// ScopeGroups indicates that groups should be added to the ID Token.
|
||||||
|
ScopeGroups = "groups"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Scopes []string
|
type Scopes []string
|
||||||
|
|
|
@ -421,6 +421,7 @@ func validateScopes(srv OIDCServer, clientID string, scopes []string) error {
|
||||||
foundOpenIDScope = true
|
foundOpenIDScope = true
|
||||||
case curScope == "profile":
|
case curScope == "profile":
|
||||||
case curScope == "email":
|
case curScope == "email":
|
||||||
|
case curScope == scope.ScopeGroups:
|
||||||
case curScope == "offline_access":
|
case curScope == "offline_access":
|
||||||
// According to the spec, for offline_access scope, the client must
|
// According to the spec, for offline_access scope, the client must
|
||||||
// use a response_type value that would result in an Authorization
|
// use a response_type value that would result in an Authorization
|
||||||
|
|
|
@ -75,6 +75,7 @@ type Server struct {
|
||||||
OOBTemplate *template.Template
|
OOBTemplate *template.Template
|
||||||
|
|
||||||
HealthChecks []health.Checkable
|
HealthChecks []health.Checkable
|
||||||
|
// TODO(ericchiang): Make this a map of ID to connector.
|
||||||
Connectors []connector.Connector
|
Connectors []connector.Connector
|
||||||
|
|
||||||
ClientRepo client.ClientRepo
|
ClientRepo client.ClientRepo
|
||||||
|
@ -306,6 +307,15 @@ func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL ur
|
||||||
return s.SessionManager.NewSessionKey(sessionID)
|
return s.SessionManager.NewSessionKey(sessionID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) connector(id string) (connector.Connector, bool) {
|
||||||
|
for _, c := range s.Connectors {
|
||||||
|
if c.ID() == id {
|
||||||
|
return c, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
|
func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
|
||||||
sessionID, err := s.SessionManager.ExchangeKey(key)
|
sessionID, err := s.SessionManager.ExchangeKey(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -318,6 +328,29 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
|
||||||
}
|
}
|
||||||
log.Infof("Session %s remote identity attached: clientID=%s identity=%#v", sessionID, ses.ClientID, ident)
|
log.Infof("Session %s remote identity attached: clientID=%s identity=%#v", sessionID, ses.ClientID, ident)
|
||||||
|
|
||||||
|
// Get the connector used to log the user in.
|
||||||
|
conn, ok := s.connector(ses.ConnectorID)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("session contained invalid connector ID (%s)", ses.ConnectorID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the client has requested access to groups, add them here.
|
||||||
|
if ses.Scope.HasScope(scope.ScopeGroups) {
|
||||||
|
grouper, ok := conn.(connector.GroupsConnector)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("scope %q provided but connector does not support groups", scope.ScopeGroups)
|
||||||
|
}
|
||||||
|
groups, err := grouper.Groups(ident.ID)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to retrieve user groups for %q %v", ident.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the session.
|
||||||
|
if ses, err = s.SessionManager.AttachGroups(sessionID, groups); err != nil {
|
||||||
|
return "", fmt.Errorf("failed save groups")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if ses.Register {
|
if ses.Register {
|
||||||
code, err := s.SessionManager.NewSessionKey(sessionID)
|
code, err := s.SessionManager.NewSessionKey(sessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -334,18 +367,6 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
|
||||||
|
|
||||||
remoteIdentity := user.RemoteIdentity{ConnectorID: ses.ConnectorID, ID: ses.Identity.ID}
|
remoteIdentity := user.RemoteIdentity{ConnectorID: ses.ConnectorID, ID: ses.Identity.ID}
|
||||||
|
|
||||||
// Get the connector used to log the user in.
|
|
||||||
var conn connector.Connector
|
|
||||||
for _, c := range s.Connectors {
|
|
||||||
if c.ID() == ses.ConnectorID {
|
|
||||||
conn = c
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if conn == nil {
|
|
||||||
return "", fmt.Errorf("session contained invalid connector ID (%s)", ses.ConnectorID)
|
|
||||||
}
|
|
||||||
|
|
||||||
usr, err := s.UserRepo.GetByRemoteIdentity(nil, remoteIdentity)
|
usr, err := s.UserRepo.GetByRemoteIdentity(nil, remoteIdentity)
|
||||||
if err == user.ErrorNotFound {
|
if err == user.ErrorNotFound {
|
||||||
if ses.Identity.Email == "" {
|
if ses.Identity.Email == "" {
|
||||||
|
@ -508,7 +529,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
|
||||||
if scope == "offline_access" {
|
if scope == "offline_access" {
|
||||||
log.Infof("Session %s requests offline access, will generate refresh token", sessionID)
|
log.Infof("Session %s requests offline access, will generate refresh token", sessionID)
|
||||||
|
|
||||||
refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.Scope)
|
refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.ConnectorID, ses.Scope)
|
||||||
switch err {
|
switch err {
|
||||||
case nil:
|
case nil:
|
||||||
break
|
break
|
||||||
|
@ -535,7 +556,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
|
||||||
return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
|
return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
|
userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
|
||||||
switch err {
|
switch err {
|
||||||
case nil:
|
case nil:
|
||||||
break
|
break
|
||||||
|
@ -555,7 +576,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := s.UserRepo.Get(nil, userID)
|
usr, err := s.UserRepo.Get(nil, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// The error can be user.ErrorNotFound, but we are not deleting
|
// The error can be user.ErrorNotFound, but we are not deleting
|
||||||
// user at this moment, so this shouldn't happen.
|
// user at this moment, so this shouldn't happen.
|
||||||
|
@ -563,6 +584,43 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
|
||||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var groups []string
|
||||||
|
if rtScopes.HasScope(scope.ScopeGroups) {
|
||||||
|
conn, ok := s.connector(connectorID)
|
||||||
|
if !ok {
|
||||||
|
log.Errorf("refresh token contained invalid connector ID (%s)", connectorID)
|
||||||
|
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
grouper, ok := conn.(connector.GroupsConnector)
|
||||||
|
if !ok {
|
||||||
|
log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID)
|
||||||
|
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to get remote identities: %v", err)
|
||||||
|
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
|
}
|
||||||
|
remoteIdentity, ok := func() (user.RemoteIdentity, bool) {
|
||||||
|
for _, ri := range remoteIdentities {
|
||||||
|
if ri.ConnectorID == connectorID {
|
||||||
|
return ri, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return user.RemoteIdentity{}, false
|
||||||
|
}()
|
||||||
|
if !ok {
|
||||||
|
log.Errorf("failed to get remote identity for connector %s", connectorID)
|
||||||
|
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
|
}
|
||||||
|
if groups, err = grouper.Groups(remoteIdentity.ID); err != nil {
|
||||||
|
log.Errorf("failed to get groups for refresh token: %v", connectorID)
|
||||||
|
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
signer, err := s.KeyManager.Signer()
|
signer, err := s.KeyManager.Signer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to refresh ID token: %v", err)
|
log.Errorf("Failed to refresh ID token: %v", err)
|
||||||
|
@ -572,8 +630,14 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
expireAt := now.Add(session.DefaultSessionValidityWindow)
|
expireAt := now.Add(session.DefaultSessionValidityWindow)
|
||||||
|
|
||||||
claims := oidc.NewClaims(s.IssuerURL.String(), user.ID, creds.ID, now, expireAt)
|
claims := oidc.NewClaims(s.IssuerURL.String(), usr.ID, creds.ID, now, expireAt)
|
||||||
user.AddToClaims(claims)
|
usr.AddToClaims(claims)
|
||||||
|
if rtScopes.HasScope(scope.ScopeGroups) {
|
||||||
|
if groups == nil {
|
||||||
|
groups = []string{}
|
||||||
|
}
|
||||||
|
claims["groups"] = groups
|
||||||
|
}
|
||||||
|
|
||||||
s.addClaimsFromScope(claims, scope.Scopes(scopes), creds.ID)
|
s.addClaimsFromScope(claims, scope.Scopes(scopes), creds.ID)
|
||||||
|
|
||||||
|
|
|
@ -785,8 +785,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
t.Errorf("case %d: error creating other client: %v", i, err)
|
t.Errorf("case %d: error creating other client: %v", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID,
|
if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID, "", tt.createScopes); err != nil {
|
||||||
tt.createScopes); err != nil {
|
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -144,6 +144,18 @@ func (m *SessionManager) AttachUser(sessionID string, userID string) (*session.S
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *SessionManager) AttachGroups(sessionID string, groups []string) (*session.Session, error) {
|
||||||
|
s, err := m.sessions.Get(sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.Groups = groups
|
||||||
|
if err = m.sessions.Update(*s); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *SessionManager) Kill(sessionID string) (*session.Session, error) {
|
func (m *SessionManager) Kill(sessionID string) (*session.Session, error) {
|
||||||
s, err := m.sessions.Get(sessionID)
|
s, err := m.sessions.Get(sessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -55,6 +55,9 @@ type Session struct {
|
||||||
// Scope is the 'scope' field in the authentication request. Example scopes
|
// Scope is the 'scope' field in the authentication request. Example scopes
|
||||||
// are 'openid', 'email', 'offline', etc.
|
// are 'openid', 'email', 'offline', etc.
|
||||||
Scope scope.Scopes
|
Scope scope.Scopes
|
||||||
|
|
||||||
|
// Groups the user belongs to.
|
||||||
|
Groups []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Claims returns a new set of Claims for the current session.
|
// Claims returns a new set of Claims for the current session.
|
||||||
|
@ -65,5 +68,8 @@ func (s *Session) Claims(issuerURL string) jose.Claims {
|
||||||
if s.Nonce != "" {
|
if s.Nonce != "" {
|
||||||
claims["nonce"] = s.Nonce
|
claims["nonce"] = s.Nonce
|
||||||
}
|
}
|
||||||
|
if s.Scope.HasScope(scope.ScopeGroups) {
|
||||||
|
claims["groups"] = s.Groups
|
||||||
|
}
|
||||||
return claims
|
return claims
|
||||||
}
|
}
|
||||||
|
|
|
@ -192,7 +192,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
||||||
}
|
}
|
||||||
refreshRepo := db.NewRefreshTokenRepo(dbMap)
|
refreshRepo := db.NewRefreshTokenRepo(dbMap)
|
||||||
for _, token := range refreshTokens {
|
for _, token := range refreshTokens {
|
||||||
if _, err := refreshRepo.Create(token.userID, token.clientID, []string{"openid"}); err != nil {
|
if _, err := refreshRepo.Create(token.userID, token.clientID, "local", []string{"openid"}); err != nil {
|
||||||
panic("Failed to create refresh token: " + err.Error())
|
panic("Failed to create refresh token: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Reference in a new issue