From b02a3a3163b122d7d7aac610467b53b35f84bc13 Mon Sep 17 00:00:00 2001 From: Eric Chiang Date: Fri, 15 Jul 2016 16:00:59 -0700 Subject: [PATCH] *: add "groups" scope --- db/migrate_sqlite3.go | 4 +- db/migrations/0014_add_groups.sql | 3 + db/migrations/assets.go | 6 ++ db/refresh.go | 18 ++--- db/session.go | 14 ++++ functional/repo/refresh_repo_test.go | 25 ++++--- functional/repo/session_repo_test.go | 7 ++ integration/user_api_test.go | 2 +- refresh/repo.go | 4 +- scope/scope.go | 3 + server/http.go | 1 + server/server.go | 100 ++++++++++++++++++++++----- server/server_test.go | 3 +- session/manager/manager.go | 12 ++++ session/session.go | 6 ++ user/api/api_test.go | 2 +- 16 files changed, 168 insertions(+), 42 deletions(-) create mode 100644 db/migrations/0014_add_groups.sql diff --git a/db/migrate_sqlite3.go b/db/migrate_sqlite3.go index 07c64546..13163725 100644 --- a/db/migrate_sqlite3.go +++ b/db/migrate_sqlite3.go @@ -41,6 +41,7 @@ CREATE TABLE refresh_token ( payload_hash blob, user_id text, client_id text, + connector_id text, scopes text ); @@ -63,7 +64,8 @@ CREATE TABLE session ( user_id text, register integer, nonce text, - scope text + scope text, + groups text ); CREATE TABLE session_key ( diff --git a/db/migrations/0014_add_groups.sql b/db/migrations/0014_add_groups.sql new file mode 100644 index 00000000..e63b8f0d --- /dev/null +++ b/db/migrations/0014_add_groups.sql @@ -0,0 +1,3 @@ +-- +migrate Up +ALTER TABLE refresh_token ADD COLUMN "connector_id" text; +ALTER TABLE session ADD COLUMN "groups" text; diff --git a/db/migrations/assets.go b/db/migrations/assets.go index 1a4b5f89..e9351ba1 100644 --- a/db/migrations/assets.go +++ b/db/migrations/assets.go @@ -90,5 +90,11 @@ var PostgresMigrations migrate.MigrationSource = &migrate.MemoryMigrationSource{ "-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"scopes\" text;\n\nUPDATE refresh_token SET scopes = 'openid profile email offline_access';\n", }, }, + { + Id: "0014_add_groups.sql", + Up: []string{ + "-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"connector_id\" text;\nALTER TABLE session ADD COLUMN \"groups\" text;\n", + }, + }, }, } diff --git a/db/refresh.go b/db/refresh.go index d16f313f..0baa655f 100644 --- a/db/refresh.go +++ b/db/refresh.go @@ -41,6 +41,7 @@ type refreshTokenModel struct { PayloadHash []byte `db:"payload_hash"` UserID string `db:"user_id"` ClientID string `db:"client_id"` + ConnectorID string `db:"connector_id"` Scopes string `db:"scopes"` } @@ -89,7 +90,7 @@ func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenG } } -func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (string, error) { +func (r *refreshTokenRepo) Create(userID, clientID, connectorID string, scopes []string) (string, error) { if userID == "" { return "", refresh.ErrorInvalidUserID } @@ -112,6 +113,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (str PayloadHash: payloadHash, UserID: userID, ClientID: clientID, + ConnectorID: connectorID, Scopes: strings.Join(scopes, " "), } @@ -122,24 +124,24 @@ func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (str return buildToken(record.ID, tokenPayload), nil } -func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes, error) { +func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) { tokenID, tokenPayload, err := parseToken(token) if err != nil { - return "", nil, err + return } record, err := r.get(nil, tokenID) if err != nil { - return "", nil, err + return } if record.ClientID != clientID { - return "", nil, refresh.ErrorInvalidClientID + return "", "", nil, refresh.ErrorInvalidClientID } - if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil { - return "", nil, err + if err = checkTokenPayload(record.PayloadHash, tokenPayload); err != nil { + return } var scopes []string @@ -147,7 +149,7 @@ func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes, scopes = strings.Split(record.Scopes, " ") } - return record.UserID, scopes, nil + return record.UserID, record.ConnectorID, scopes, nil } func (r *refreshTokenRepo) Revoke(userID, token string) error { diff --git a/db/session.go b/db/session.go index 1eb05cfe..5fb296d3 100644 --- a/db/session.go +++ b/db/session.go @@ -44,6 +44,7 @@ type sessionModel struct { Register bool `db:"register"` Nonce string `db:"nonce"` Scope string `db:"scope"` + Groups string `db:"groups"` } func (s *sessionModel) session() (*session.Session, error) { @@ -75,6 +76,11 @@ func (s *sessionModel) session() (*session.Session, error) { Nonce: s.Nonce, Scope: strings.Fields(s.Scope), } + if s.Groups != "" { + if err := json.Unmarshal([]byte(s.Groups), &ses.Groups); err != nil { + return nil, fmt.Errorf("failed to decode groups in session: %v", err) + } + } if s.CreatedAt != 0 { ses.CreatedAt = time.Unix(s.CreatedAt, 0).UTC() @@ -107,6 +113,14 @@ func newSessionModel(s *session.Session) (*sessionModel, error) { Scope: strings.Join(s.Scope, " "), } + if s.Groups != nil { + data, err := json.Marshal(s.Groups) + if err != nil { + return nil, fmt.Errorf("failed to marshal groups: %v", err) + } + sm.Groups = string(data) + } + if !s.CreatedAt.IsZero() { sm.CreatedAt = s.CreatedAt.Unix() } diff --git a/functional/repo/refresh_repo_test.go b/functional/repo/refresh_repo_test.go index f18d0f09..2de50117 100644 --- a/functional/repo/refresh_repo_test.go +++ b/functional/repo/refresh_repo_test.go @@ -20,7 +20,10 @@ import ( var ( testRefreshClientID = "client1" testRefreshClientID2 = "client2" - testRefreshClients = []client.LoadableClient{ + + testRefreshConnectorID = "IDPC-1" + + testRefreshClients = []client.LoadableClient{ { Client: client.Client{ Credentials: oidc.ClientCredentials{ @@ -59,7 +62,7 @@ var ( }, RemoteIdentities: []user.RemoteIdentity{ { - ConnectorID: "IDPC-1", + ConnectorID: testRefreshConnectorID, ID: "RID-1", }, }, @@ -103,12 +106,12 @@ func TestRefreshTokenRepoCreateVerify(t *testing.T) { for i, tt := range tests { repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients) - tok, err := repo.Create(testRefreshUserID, testRefreshClientID, tt.createScopes) + tok, err := repo.Create(testRefreshUserID, testRefreshClientID, testRefreshConnectorID, tt.createScopes) if err != nil { t.Fatalf("case %d: failed to create refresh token: %v", i, err) } - tokUserID, gotScopes, err := repo.Verify(tt.verifyClientID, tok) + tokUserID, gotConnectorID, gotScopes, err := repo.Verify(tt.verifyClientID, tok) if tt.wantVerifyErr { if err == nil { t.Errorf("case %d: want non-nil error.", i) @@ -126,6 +129,10 @@ func TestRefreshTokenRepoCreateVerify(t *testing.T) { t.Errorf("case %d: Verified token returned wrong user id, want=%s, got=%s", i, testRefreshUserID, tokUserID) } + + if gotConnectorID != testRefreshConnectorID { + t.Errorf("case %d: wanted connector_id=%q got=%q", i, testRefreshConnectorID, gotConnectorID) + } } } @@ -138,7 +145,7 @@ func buildRefreshToken(tokenID int64, tokenPayload []byte) string { func TestRefreshRepoVerifyInvalidTokens(t *testing.T) { r := db.NewRefreshTokenRepo(connect(t)) - token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope) + token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -209,7 +216,7 @@ func TestRefreshRepoVerifyInvalidTokens(t *testing.T) { } for i, tt := range tests { - result, _, err := r.Verify(tt.creds.ID, tt.token) + result, _, _, err := r.Verify(tt.creds.ID, tt.token) if err != tt.err { t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err) } @@ -232,7 +239,7 @@ func TestRefreshTokenRepoClientsWithRefreshTokens(t *testing.T) { repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients) for _, clientID := range tt.clientIDs { - _, err := repo.Create(testRefreshUserID, clientID, []string{"openid"}) + _, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"}) if err != nil { t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err) } @@ -281,7 +288,7 @@ func TestRefreshTokenRepoRevokeForClient(t *testing.T) { repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients) for _, clientID := range tt.createIDs { - _, err := repo.Create(testRefreshUserID, clientID, []string{"openid"}) + _, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"}) if err != nil { t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err) } @@ -318,7 +325,7 @@ func TestRefreshTokenRepoRevokeForClient(t *testing.T) { func TestRefreshRepoRevoke(t *testing.T) { r := db.NewRefreshTokenRepo(connect(t)) - token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope) + token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope) if err != nil { t.Fatalf("Unexpected error: %v", err) } diff --git a/functional/repo/session_repo_test.go b/functional/repo/session_repo_test.go index 4f939e52..2ec49d97 100644 --- a/functional/repo/session_repo_test.go +++ b/functional/repo/session_repo_test.go @@ -104,6 +104,13 @@ func TestSessionRepoCreateGet(t *testing.T) { ExpiresAt: time.Unix(789, 0).UTC(), Nonce: "oncenay", }, + session.Session{ + ID: "anID", + ClientState: "blargh", + ExpiresAt: time.Unix(789, 0).UTC(), + Nonce: "oncenay", + Groups: []string{"group1", "group2"}, + }, } for i, tt := range tests { diff --git a/integration/user_api_test.go b/integration/user_api_test.go index 9584f2fd..e08dc103 100644 --- a/integration/user_api_test.go +++ b/integration/user_api_test.go @@ -149,7 +149,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures { refreshRepo := db.NewRefreshTokenRepo(dbMap) for _, user := range userUsers { if _, err := refreshRepo.Create(user.User.ID, testClientID, - append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil { + "", append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil { panic("Failed to create refresh token: " + err.Error()) } } diff --git a/refresh/repo.go b/refresh/repo.go index df81a426..23ccda85 100644 --- a/refresh/repo.go +++ b/refresh/repo.go @@ -44,12 +44,12 @@ type RefreshTokenRepo interface { // The scopes will be stored with the refresh token, and used to verify // against future OIDC refresh requests' scopes. // On success the token will be returned. - Create(userID, clientID string, scope []string) (string, error) + Create(userID, clientID, connectorID string, scope []string) (string, error) // Verify verifies that a token belongs to the client. // It returns the user ID to which the token belongs, and the scopes stored // with token. - Verify(clientID, token string) (string, scope.Scopes, error) + Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) // Revoke deletes the refresh token if the token belongs to the given userID. Revoke(userID, token string) error diff --git a/scope/scope.go b/scope/scope.go index f13236db..d13b76fe 100644 --- a/scope/scope.go +++ b/scope/scope.go @@ -6,6 +6,9 @@ const ( // Scope prefix which indicates initiation of a cross-client authentication flow. // See https://developers.google.com/identity/protocols/CrossClientAuth ScopeGoogleCrossClient = "audience:server:client_id:" + + // ScopeGroups indicates that groups should be added to the ID Token. + ScopeGroups = "groups" ) type Scopes []string diff --git a/server/http.go b/server/http.go index 35e0d54e..4616389a 100644 --- a/server/http.go +++ b/server/http.go @@ -421,6 +421,7 @@ func validateScopes(srv OIDCServer, clientID string, scopes []string) error { foundOpenIDScope = true case curScope == "profile": case curScope == "email": + case curScope == scope.ScopeGroups: case curScope == "offline_access": // According to the spec, for offline_access scope, the client must // use a response_type value that would result in an Authorization diff --git a/server/server.go b/server/server.go index c993431f..32c1f0ec 100644 --- a/server/server.go +++ b/server/server.go @@ -75,7 +75,8 @@ type Server struct { OOBTemplate *template.Template HealthChecks []health.Checkable - Connectors []connector.Connector + // TODO(ericchiang): Make this a map of ID to connector. + Connectors []connector.Connector ClientRepo client.ClientRepo ConnectorConfigRepo connector.ConnectorConfigRepo @@ -306,6 +307,15 @@ func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL ur return s.SessionManager.NewSessionKey(sessionID) } +func (s *Server) connector(id string) (connector.Connector, bool) { + for _, c := range s.Connectors { + if c.ID() == id { + return c, true + } + } + return nil, false +} + func (s *Server) Login(ident oidc.Identity, key string) (string, error) { sessionID, err := s.SessionManager.ExchangeKey(key) if err != nil { @@ -318,6 +328,29 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) { } log.Infof("Session %s remote identity attached: clientID=%s identity=%#v", sessionID, ses.ClientID, ident) + // Get the connector used to log the user in. + conn, ok := s.connector(ses.ConnectorID) + if !ok { + return "", fmt.Errorf("session contained invalid connector ID (%s)", ses.ConnectorID) + } + + // If the client has requested access to groups, add them here. + if ses.Scope.HasScope(scope.ScopeGroups) { + grouper, ok := conn.(connector.GroupsConnector) + if !ok { + return "", fmt.Errorf("scope %q provided but connector does not support groups", scope.ScopeGroups) + } + groups, err := grouper.Groups(ident.ID) + if err != nil { + return "", fmt.Errorf("failed to retrieve user groups for %q %v", ident.ID, err) + } + + // Update the session. + if ses, err = s.SessionManager.AttachGroups(sessionID, groups); err != nil { + return "", fmt.Errorf("failed save groups") + } + } + if ses.Register { code, err := s.SessionManager.NewSessionKey(sessionID) if err != nil { @@ -334,18 +367,6 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) { remoteIdentity := user.RemoteIdentity{ConnectorID: ses.ConnectorID, ID: ses.Identity.ID} - // Get the connector used to log the user in. - var conn connector.Connector - for _, c := range s.Connectors { - if c.ID() == ses.ConnectorID { - conn = c - break - } - } - if conn == nil { - return "", fmt.Errorf("session contained invalid connector ID (%s)", ses.ConnectorID) - } - usr, err := s.UserRepo.GetByRemoteIdentity(nil, remoteIdentity) if err == user.ErrorNotFound { if ses.Identity.Email == "" { @@ -508,7 +529,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo if scope == "offline_access" { log.Infof("Session %s requests offline access, will generate refresh token", sessionID) - refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.Scope) + refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.ConnectorID, ses.Scope) switch err { case nil: break @@ -535,7 +556,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, return nil, oauth2.NewError(oauth2.ErrorInvalidClient) } - userID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token) + userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token) switch err { case nil: break @@ -555,7 +576,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, } } - user, err := s.UserRepo.Get(nil, userID) + usr, err := s.UserRepo.Get(nil, userID) if err != nil { // The error can be user.ErrorNotFound, but we are not deleting // user at this moment, so this shouldn't happen. @@ -563,6 +584,43 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, return nil, oauth2.NewError(oauth2.ErrorServerError) } + var groups []string + if rtScopes.HasScope(scope.ScopeGroups) { + conn, ok := s.connector(connectorID) + if !ok { + log.Errorf("refresh token contained invalid connector ID (%s)", connectorID) + return nil, oauth2.NewError(oauth2.ErrorServerError) + } + + grouper, ok := conn.(connector.GroupsConnector) + if !ok { + log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID) + return nil, oauth2.NewError(oauth2.ErrorServerError) + } + + remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID) + if err != nil { + log.Errorf("failed to get remote identities: %v", err) + return nil, oauth2.NewError(oauth2.ErrorServerError) + } + remoteIdentity, ok := func() (user.RemoteIdentity, bool) { + for _, ri := range remoteIdentities { + if ri.ConnectorID == connectorID { + return ri, true + } + } + return user.RemoteIdentity{}, false + }() + if !ok { + log.Errorf("failed to get remote identity for connector %s", connectorID) + return nil, oauth2.NewError(oauth2.ErrorServerError) + } + if groups, err = grouper.Groups(remoteIdentity.ID); err != nil { + log.Errorf("failed to get groups for refresh token: %v", connectorID) + return nil, oauth2.NewError(oauth2.ErrorServerError) + } + } + signer, err := s.KeyManager.Signer() if err != nil { log.Errorf("Failed to refresh ID token: %v", err) @@ -572,8 +630,14 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, now := time.Now() expireAt := now.Add(session.DefaultSessionValidityWindow) - claims := oidc.NewClaims(s.IssuerURL.String(), user.ID, creds.ID, now, expireAt) - user.AddToClaims(claims) + claims := oidc.NewClaims(s.IssuerURL.String(), usr.ID, creds.ID, now, expireAt) + usr.AddToClaims(claims) + if rtScopes.HasScope(scope.ScopeGroups) { + if groups == nil { + groups = []string{} + } + claims["groups"] = groups + } s.addClaimsFromScope(claims, scope.Scopes(scopes), creds.ID) diff --git a/server/server_test.go b/server/server_test.go index 8834ad99..1cad4b9c 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -785,8 +785,7 @@ func TestServerRefreshToken(t *testing.T) { t.Errorf("case %d: error creating other client: %v", i, err) } - if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID, - tt.createScopes); err != nil { + if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID, "", tt.createScopes); err != nil { t.Fatalf("Unexpected error: %v", err) } diff --git a/session/manager/manager.go b/session/manager/manager.go index c0ac6d4a..cbf5af95 100644 --- a/session/manager/manager.go +++ b/session/manager/manager.go @@ -144,6 +144,18 @@ func (m *SessionManager) AttachUser(sessionID string, userID string) (*session.S return s, nil } +func (m *SessionManager) AttachGroups(sessionID string, groups []string) (*session.Session, error) { + s, err := m.sessions.Get(sessionID) + if err != nil { + return nil, err + } + s.Groups = groups + if err = m.sessions.Update(*s); err != nil { + return nil, err + } + return s, nil +} + func (m *SessionManager) Kill(sessionID string) (*session.Session, error) { s, err := m.sessions.Get(sessionID) if err != nil { diff --git a/session/session.go b/session/session.go index 050d60a8..b8c80145 100644 --- a/session/session.go +++ b/session/session.go @@ -55,6 +55,9 @@ type Session struct { // Scope is the 'scope' field in the authentication request. Example scopes // are 'openid', 'email', 'offline', etc. Scope scope.Scopes + + // Groups the user belongs to. + Groups []string } // Claims returns a new set of Claims for the current session. @@ -65,5 +68,8 @@ func (s *Session) Claims(issuerURL string) jose.Claims { if s.Nonce != "" { claims["nonce"] = s.Nonce } + if s.Scope.HasScope(scope.ScopeGroups) { + claims["groups"] = s.Groups + } return claims } diff --git a/user/api/api_test.go b/user/api/api_test.go index 90d68687..12a1f64e 100644 --- a/user/api/api_test.go +++ b/user/api/api_test.go @@ -192,7 +192,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) { } refreshRepo := db.NewRefreshTokenRepo(dbMap) for _, token := range refreshTokens { - if _, err := refreshRepo.Create(token.userID, token.clientID, []string{"openid"}); err != nil { + if _, err := refreshRepo.Create(token.userID, token.clientID, "local", []string{"openid"}); err != nil { panic("Failed to create refresh token: " + err.Error()) } }