Merge pull request #510 from ericchiang/add-groups-scope-and-ldap-implementation
Add groups scope and LDAP implementation
This commit is contained in:
commit
edb010caa3
23 changed files with 340 additions and 132 deletions
|
@ -56,3 +56,19 @@ For situations in which an app does not have access to a browser, the out-of-ban
|
|||
\* In OpenID Connect a client is called a "Relying Party", but "client" seems to
|
||||
be the more common ter, has been around longer and is present in paramter names
|
||||
like "client_id" so we prefer it over "Relying Party" usually.
|
||||
|
||||
## Groups
|
||||
|
||||
Connectors that support groups (currently only the LDAP connector) can embed the groups a user belongs to in the ID Token. Using the scope "groups" during the initial redirect with a connector that supports groups will return an JWT with the following field.
|
||||
|
||||
```
|
||||
{
|
||||
"groups": [
|
||||
"cn=ipausers,cn=groups,cn=accounts,dc=example,dc=com,
|
||||
"cn=team-engineering,cn=groups,cn=accounts,dc=example,dc=com"
|
||||
],
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
If the client has also requested a refresh token, the groups field is updated during each refresh request.
|
||||
|
|
|
@ -153,6 +153,7 @@ In addition to `id` and `type`, the `ldap` connector takes the following additio
|
|||
* emailAttribute: a `string`. Required. Attribute to map to Email. Default: `mail`
|
||||
* searchBeforeAuth: a `boolean`. Perform search for entryDN to be used for bind.
|
||||
* searchFilter: a `string`. Filter to apply to search. Variable substititions: `%u` User supplied username/e-mail address. `%b` BaseDN. Searches that return multiple entries are considered ambiguous and will return an error.
|
||||
* searchGroupFilter: a `string`. A filter which should return group entry for a given user. The string is formatted the same as `searchFilter`, execpt `%u` is replaced by the fully qualified user entry. Groups are only searched if the client request the "groups" scope.
|
||||
* searchScope: a `string`. Scope of the search. `base|one|sub`. Default: `one`
|
||||
* searchBindDN: a `string`. DN to bind as for search operations.
|
||||
* searchBindPw: a `string`. Password for bind for search operations.
|
||||
|
@ -180,19 +181,20 @@ uid=janedoe,cn=users,cn=accounts,dc=auth,dc=example,dc=com
|
|||
|
||||
The connector then attempts to bind as this entry using the password provided by the end user.
|
||||
|
||||
### Example: Searching the directory
|
||||
### Example: Searching a FreeIPA server with groups
|
||||
|
||||
The following configuration will search a directory using an LDAP filter. With FreeIPA
|
||||
The following configuration will search a FreeIPA directory using an LDAP filter.
|
||||
|
||||
```
|
||||
{
|
||||
"type": "ldap",
|
||||
"id": "ldap",
|
||||
"host": "127.0.0.1:389",
|
||||
"baseDN": "cn=auth,dc=example,dc=com",
|
||||
"baseDN": "cn=accounts,dc=example,dc=com",
|
||||
|
||||
"searchBeforeAuth": true,
|
||||
"searchFilter": "(&(objectClass=person)(uid=%u))",
|
||||
"searchGroupFilter": "(&(objectClass=ipausergroup)(member=%u))",
|
||||
"searchScope": "sub",
|
||||
|
||||
"searchBindDN": "serviceAccountUser",
|
||||
|
@ -206,9 +208,15 @@ The following configuration will search a directory using an LDAP filter. With F
|
|||
(&(objectClass=person)(uid=janedoe))
|
||||
```
|
||||
|
||||
If the search finds an entry, it will attempt to use the provided password to bind as that entry.
|
||||
If the search finds an entry, it will attempt to use the provided password to bind as that entry. Searches that return multiple entries are considered ambiguous and will return an error.
|
||||
|
||||
__NOTE__: Searches that return multiple entries will return an error.
|
||||
"searchGroupFilter" is a format string similar to "searchFilter" except `%u` is replaced by the fully qualified user entry returned by "searchFilter". So if the initial search returns "uid=janedoe,cn=users,cn=accounts,dc=example,dc=com", the connector will use the search query:
|
||||
|
||||
```
|
||||
(&(objectClass=ipausergroup)(member=uid=janedoe,cn=users,cn=accounts,dc=example,dc=com))
|
||||
```
|
||||
|
||||
If the client requests the "groups" scope, the names of all returned entries are added to the ID Token "groups" claim.
|
||||
|
||||
## Setting the Configuration
|
||||
|
||||
|
|
|
@ -107,11 +107,12 @@ type LDAPConnector struct {
|
|||
nameAttribute string
|
||||
emailAttribute string
|
||||
|
||||
searchBeforeAuth bool
|
||||
searchFilter string
|
||||
searchScope int
|
||||
searchBindDN string
|
||||
searchBindPw string
|
||||
searchBeforeAuth bool
|
||||
searchFilter string
|
||||
searchScope int
|
||||
searchBindDN string
|
||||
searchBindPw string
|
||||
searchGroupFilter string
|
||||
|
||||
bindTemplate string
|
||||
|
||||
|
@ -203,19 +204,20 @@ func (cfg *LDAPConnectorConfig) Connector(ns url.URL, lf oidc.LoginFunc, tpls *t
|
|||
}
|
||||
|
||||
idpc := &LDAPConnector{
|
||||
id: cfg.ID,
|
||||
namespace: ns,
|
||||
loginFunc: lf,
|
||||
loginTpl: tpl,
|
||||
baseDN: cfg.BaseDN,
|
||||
nameAttribute: cfg.NameAttribute,
|
||||
emailAttribute: cfg.EmailAttribute,
|
||||
searchBeforeAuth: cfg.SearchBeforeAuth,
|
||||
searchFilter: cfg.SearchFilter,
|
||||
searchScope: searchScope,
|
||||
searchBindDN: cfg.SearchBindDN,
|
||||
searchBindPw: cfg.SearchBindPw,
|
||||
bindTemplate: cfg.BindTemplate,
|
||||
id: cfg.ID,
|
||||
namespace: ns,
|
||||
loginFunc: lf,
|
||||
loginTpl: tpl,
|
||||
baseDN: cfg.BaseDN,
|
||||
nameAttribute: cfg.NameAttribute,
|
||||
emailAttribute: cfg.EmailAttribute,
|
||||
searchBeforeAuth: cfg.SearchBeforeAuth,
|
||||
searchFilter: cfg.SearchFilter,
|
||||
searchGroupFilter: cfg.SearchGroupFilter,
|
||||
searchScope: searchScope,
|
||||
searchBindDN: cfg.SearchBindDN,
|
||||
searchBindPw: cfg.SearchBindPw,
|
||||
bindTemplate: cfg.BindTemplate,
|
||||
ldapPool: &LDAPPool{
|
||||
MaxIdleConn: cfg.MaxIdleConn,
|
||||
PoolCheckTimer: defaultPoolCheckTimer,
|
||||
|
@ -433,12 +435,47 @@ func invalidBindCredentials(err error) bool {
|
|||
|
||||
func (c *LDAPConnector) formatDN(template, username string) string {
|
||||
result := template
|
||||
result = strings.Replace(result, "%u", username, -1)
|
||||
result = strings.Replace(result, "%u", ldap.EscapeFilter(username), -1)
|
||||
result = strings.Replace(result, "%b", c.baseDN, -1)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *LDAPConnector) Groups(fullUserID string) ([]string, error) {
|
||||
if !c.searchBeforeAuth {
|
||||
return nil, fmt.Errorf("cannot search without service account")
|
||||
}
|
||||
if c.searchGroupFilter == "" {
|
||||
return nil, fmt.Errorf("no group filter specified")
|
||||
}
|
||||
|
||||
var groups []string
|
||||
err := c.ldapPool.Do(func(conn *ldap.Conn) error {
|
||||
if err := conn.Bind(c.searchBindDN, c.searchBindPw); err != nil {
|
||||
if !invalidBindCredentials(err) {
|
||||
log.Errorf("failed to connect to LDAP for search bind: %v", err)
|
||||
}
|
||||
return fmt.Errorf("failed to bind: %v", err)
|
||||
}
|
||||
|
||||
req := &ldap.SearchRequest{
|
||||
BaseDN: c.baseDN,
|
||||
Scope: c.searchScope,
|
||||
Filter: c.formatDN(c.searchGroupFilter, fullUserID),
|
||||
}
|
||||
resp, err := conn.Search(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("search failed: %v", err)
|
||||
}
|
||||
groups = make([]string, len(resp.Entries))
|
||||
for i, entry := range resp.Entries {
|
||||
groups[i] = entry.DN
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return groups, err
|
||||
}
|
||||
|
||||
func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, error) {
|
||||
var (
|
||||
identity *oidc.Identity
|
||||
|
@ -447,8 +484,10 @@ func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, err
|
|||
if c.searchBeforeAuth {
|
||||
err = c.ldapPool.Do(func(conn *ldap.Conn) error {
|
||||
if err := conn.Bind(c.searchBindDN, c.searchBindPw); err != nil {
|
||||
// Don't wrap error as it may be a specific LDAP error.
|
||||
return err
|
||||
if !invalidBindCredentials(err) {
|
||||
log.Errorf("failed to connect to LDAP for search bind: %v", err)
|
||||
}
|
||||
return fmt.Errorf("failed to bind: %v", err)
|
||||
}
|
||||
|
||||
filter := c.formatDN(c.searchFilter, username)
|
||||
|
@ -491,8 +530,10 @@ func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, err
|
|||
err = c.ldapPool.Do(func(conn *ldap.Conn) error {
|
||||
userBindDN := c.formatDN(c.bindTemplate, username)
|
||||
if err := conn.Bind(userBindDN, password); err != nil {
|
||||
// Don't wrap error as it may be a specific LDAP error.
|
||||
return err
|
||||
if !invalidBindCredentials(err) {
|
||||
log.Errorf("failed to connect to LDAP for search bind: %v", err)
|
||||
}
|
||||
return fmt.Errorf("failed to bind: %v", err)
|
||||
}
|
||||
|
||||
req := &ldap.SearchRequest{
|
||||
|
@ -522,11 +563,7 @@ func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, err
|
|||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if !invalidBindCredentials(err) {
|
||||
log.Errorf("failed to connect to LDAP for search bind: %v", err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return identity, nil
|
||||
|
|
|
@ -60,6 +60,12 @@ type ConnectorConfig interface {
|
|||
Connector(ns url.URL, loginFunc oidc.LoginFunc, tpls *template.Template) (Connector, error)
|
||||
}
|
||||
|
||||
// GroupsConnector is a strategy for mapping a user to a set of groups. This is optionally
|
||||
// implemented by some connectors.
|
||||
type GroupsConnector interface {
|
||||
Groups(fullUserID string) ([]string, error)
|
||||
}
|
||||
|
||||
type ConnectorConfigRepo interface {
|
||||
All() ([]ConnectorConfig, error)
|
||||
GetConnectorByID(repo.Transaction, string) (ConnectorConfig, error)
|
||||
|
|
|
@ -41,6 +41,7 @@ CREATE TABLE refresh_token (
|
|||
payload_hash blob,
|
||||
user_id text,
|
||||
client_id text,
|
||||
connector_id text,
|
||||
scopes text
|
||||
);
|
||||
|
||||
|
@ -63,7 +64,8 @@ CREATE TABLE session (
|
|||
user_id text,
|
||||
register integer,
|
||||
nonce text,
|
||||
scope text
|
||||
scope text,
|
||||
groups text
|
||||
);
|
||||
|
||||
CREATE TABLE session_key (
|
||||
|
|
3
db/migrations/0014_add_groups.sql
Normal file
3
db/migrations/0014_add_groups.sql
Normal file
|
@ -0,0 +1,3 @@
|
|||
-- +migrate Up
|
||||
ALTER TABLE refresh_token ADD COLUMN "connector_id" text;
|
||||
ALTER TABLE session ADD COLUMN "groups" text;
|
|
@ -90,5 +90,11 @@ var PostgresMigrations migrate.MigrationSource = &migrate.MemoryMigrationSource{
|
|||
"-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"scopes\" text;\n\nUPDATE refresh_token SET scopes = 'openid profile email offline_access';\n",
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "0014_add_groups.sql",
|
||||
Up: []string{
|
||||
"-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"connector_id\" text;\nALTER TABLE session ADD COLUMN \"groups\" text;\n",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
@ -41,6 +41,7 @@ type refreshTokenModel struct {
|
|||
PayloadHash []byte `db:"payload_hash"`
|
||||
UserID string `db:"user_id"`
|
||||
ClientID string `db:"client_id"`
|
||||
ConnectorID string `db:"connector_id"`
|
||||
Scopes string `db:"scopes"`
|
||||
}
|
||||
|
||||
|
@ -89,7 +90,7 @@ func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenG
|
|||
}
|
||||
}
|
||||
|
||||
func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (string, error) {
|
||||
func (r *refreshTokenRepo) Create(userID, clientID, connectorID string, scopes []string) (string, error) {
|
||||
if userID == "" {
|
||||
return "", refresh.ErrorInvalidUserID
|
||||
}
|
||||
|
@ -112,6 +113,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (str
|
|||
PayloadHash: payloadHash,
|
||||
UserID: userID,
|
||||
ClientID: clientID,
|
||||
ConnectorID: connectorID,
|
||||
Scopes: strings.Join(scopes, " "),
|
||||
}
|
||||
|
||||
|
@ -122,24 +124,24 @@ func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (str
|
|||
return buildToken(record.ID, tokenPayload), nil
|
||||
}
|
||||
|
||||
func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes, error) {
|
||||
func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) {
|
||||
tokenID, tokenPayload, err := parseToken(token)
|
||||
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
return
|
||||
}
|
||||
|
||||
record, err := r.get(nil, tokenID)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
return
|
||||
}
|
||||
|
||||
if record.ClientID != clientID {
|
||||
return "", nil, refresh.ErrorInvalidClientID
|
||||
return "", "", nil, refresh.ErrorInvalidClientID
|
||||
}
|
||||
|
||||
if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
|
||||
return "", nil, err
|
||||
if err = checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var scopes []string
|
||||
|
@ -147,7 +149,7 @@ func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes,
|
|||
scopes = strings.Split(record.Scopes, " ")
|
||||
}
|
||||
|
||||
return record.UserID, scopes, nil
|
||||
return record.UserID, record.ConnectorID, scopes, nil
|
||||
}
|
||||
|
||||
func (r *refreshTokenRepo) Revoke(userID, token string) error {
|
||||
|
|
|
@ -44,6 +44,7 @@ type sessionModel struct {
|
|||
Register bool `db:"register"`
|
||||
Nonce string `db:"nonce"`
|
||||
Scope string `db:"scope"`
|
||||
Groups string `db:"groups"`
|
||||
}
|
||||
|
||||
func (s *sessionModel) session() (*session.Session, error) {
|
||||
|
@ -75,6 +76,11 @@ func (s *sessionModel) session() (*session.Session, error) {
|
|||
Nonce: s.Nonce,
|
||||
Scope: strings.Fields(s.Scope),
|
||||
}
|
||||
if s.Groups != "" {
|
||||
if err := json.Unmarshal([]byte(s.Groups), &ses.Groups); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode groups in session: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if s.CreatedAt != 0 {
|
||||
ses.CreatedAt = time.Unix(s.CreatedAt, 0).UTC()
|
||||
|
@ -107,6 +113,14 @@ func newSessionModel(s *session.Session) (*sessionModel, error) {
|
|||
Scope: strings.Join(s.Scope, " "),
|
||||
}
|
||||
|
||||
if s.Groups != nil {
|
||||
data, err := json.Marshal(s.Groups)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal groups: %v", err)
|
||||
}
|
||||
sm.Groups = string(data)
|
||||
}
|
||||
|
||||
if !s.CreatedAt.IsZero() {
|
||||
sm.CreatedAt = s.CreatedAt.Unix()
|
||||
}
|
||||
|
|
|
@ -68,7 +68,7 @@ func (fi bindataFileInfo) Sys() interface{} {
|
|||
return nil
|
||||
}
|
||||
|
||||
var _dataIndexHtml = []byte("\x1f\x8b\x08\x00\x00\x09\x6e\x88\x00\xff\x94\x52\xcd\x4e\xc3\x30\x0c\xbe\xef\x29\xac\x9c\xe0\x30\x7a\x47\x6d\x25\x40\xdc\x90\x26\xf1\x02\x53\x9a\x78\x6d\xb4\xfc\x4c\x89\x8b\x36\x4d\x7b\x77\xdc\x96\xae\x5b\x81\x09\x6e\xfe\x14\xfb\xfb\x89\x9d\x37\xe4\x6c\xb9\x00\xc8\xab\xa0\x0f\xe5\x82\x2b\xae\x37\x21\x3a\x90\x8a\x4c\xf0\x85\xc8\x6c\xa8\x8d\x17\x65\xff\xc4\x8f\x24\x2b\x8b\x23\xea\x70\x9c\x40\x07\x75\x09\x4f\x2d\x35\xe8\xc9\x28\x49\x08\x4c\xf6\x78\xd1\xd0\x49\x5d\x4d\x00\xdc\xa9\xe0\x9c\x5c\x26\xdc\xc9\xc8\x13\x1a\xac\x49\x04\x61\x03\xca\x1a\xa6\x59\x1a\x9d\xee\x2f\x25\x32\xd6\x98\x4b\xe6\xc6\xef\x5a\x02\x3a\xec\xb0\x10\x84\x7b\x12\xe0\xa5\xe3\x5a\xc5\x90\xd2\x7a\x60\x12\x50\xce\xa6\x19\x9d\xcd\x70\x3d\x44\x3b\x1e\xc1\x6c\xe0\x61\xb5\x7a\x86\xd3\x69\x6a\xbd\x54\x48\x6d\xe5\x0c\xf3\x7d\x48\xdb\x32\x7c\xeb\xbf\xa8\x8b\xea\x48\xc6\x1a\xa9\x10\xeb\xca\x4a\xbf\x15\x3d\x1b\xda\x84\xff\xa4\x1a\xe6\xbc\x1e\xc7\xf2\xac\x23\xe7\x05\x7d\x37\x37\x5b\x97\x92\xd6\x56\x52\x6d\x05\x38\xa4\x26\xe8\x42\xb0\x9f\x8e\x70\xd0\x7e\x09\x1a\x17\x3f\xd8\xb8\xfa\x33\xee\x39\x1b\x9a\x36\x3f\xed\xed\x56\x80\xd7\xbd\x6a\xa4\xaf\xb1\x57\x1a\x75\x47\xfb\xd7\xa1\xbe\xc2\xf8\x40\xb7\x02\x45\xac\xf9\x1e\x30\x8a\xbf\xa8\xbf\x8f\xcd\x00\xd9\xef\xd2\x79\x36\x9c\x7b\x9e\x0d\xf7\xff\x19\x00\x00\xff\xff\xaf\x0b\xca\x75\x07\x03\x00\x00")
|
||||
var _dataIndexHtml = []byte("\x1f\x8b\x08\x00\x00\x09\x6e\x88\x00\xff\x94\x93\xcf\x8a\xe3\x30\x0c\xc6\xef\x7d\x0a\xe1\x7b\x37\xf7\xc5\x29\xec\x0e\xbd\x0d\x14\xe6\x05\x8a\x63\xab\x89\xa9\xff\x61\x2b\x43\x4b\xe9\xbb\x8f\x53\x37\x61\x52\xd2\xa1\x73\x93\xd1\x27\x7d\x3f\x49\x98\x77\x64\xcd\x66\x05\xc0\x1b\xaf\xce\x43\x90\xc3\x83\x8f\x16\x84\x24\xed\x5d\xcd\x2a\xe3\x5b\xed\x58\x49\x0d\xd9\x30\x85\x00\xff\x7a\xea\xd0\x91\x96\x82\x10\x72\xd9\x5f\xae\x5d\xe8\x09\xe8\x1c\xb0\x66\x84\x27\x62\xe0\x84\xcd\xb1\x8c\x3e\xa5\xbd\x34\x3a\xcb\x19\x04\x23\x24\x76\xde\x28\x8c\x39\xe5\xad\x15\xeb\x84\x41\xc4\xdc\x46\x81\xd1\x89\xc0\x1f\xa0\x88\xd7\x5a\xa5\x6f\xee\x55\x58\x26\xd9\x9e\x28\x0a\x48\xd2\x07\x4c\xcf\x29\x70\x50\xed\x8b\xea\x45\x8a\xbb\x78\x4e\x70\xb9\x80\x3e\xc0\x9f\xdd\xee\x3f\x5c\xaf\x13\xc4\xcc\x36\xf5\x8d\xd5\xd9\xf8\x53\x98\x3e\x3f\xdf\x6f\x5b\x1c\x76\x64\x49\xc4\x16\xa9\x66\xfb\xc6\x08\x77\x64\xb7\x6e\x68\x12\xfe\xb2\x55\xa9\x73\x6a\x2c\xe3\xd5\xd0\x7c\xb3\x5a\x80\x7b\xb8\xa8\x14\xc6\x34\x42\x1e\x19\x58\xa4\xce\xab\x9a\x65\x9e\xa1\x61\xf1\x7e\xf3\x0a\x57\x0b\x18\xb3\x73\x66\xcd\x04\x34\x2d\x87\x37\x71\xb3\x54\xf9\x30\xc0\xf6\x24\x3b\xe1\x5a\xbc\x39\x8d\xbe\x23\xfe\x7c\xa8\xfb\x30\xce\xd3\x4f\x03\x45\x6c\xf3\xb5\x30\xb2\x57\xdc\x3f\x46\x31\x40\xf5\xdc\x9a\x57\xe5\x43\xf0\xaa\xfc\x90\xaf\x00\x00\x00\xff\xff\x9c\x89\xe2\x28\x29\x03\x00\x00")
|
||||
|
||||
func dataIndexHtmlBytes() ([]byte, error) {
|
||||
return bindataRead(
|
||||
|
@ -83,7 +83,7 @@ func dataIndexHtml() (*asset, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
info := bindataFileInfo{name: "data/index.html", size: 775, mode: os.FileMode(420), modTime: time.Unix(1466378108, 0)}
|
||||
info := bindataFileInfo{name: "data/index.html", size: 809, mode: os.FileMode(436), modTime: time.Unix(1468620773, 0)}
|
||||
a := &asset{bytes: bytes, info: info}
|
||||
return a, nil
|
||||
}
|
||||
|
|
|
@ -1,16 +1,12 @@
|
|||
<html>
|
||||
<body>
|
||||
|
||||
<form action="/login">
|
||||
<table>
|
||||
<tr>
|
||||
<td> Authenticate for:
|
||||
<br>
|
||||
(comma-separated list of client-ids)
|
||||
</td>
|
||||
<td> <input type="text" name="cross_client" > </td>
|
||||
</tr>
|
||||
</table>
|
||||
<p>
|
||||
Authenticate for:<input type="text" name="cross_client" placeholder="comma-separated list of client-ids">
|
||||
</p>
|
||||
<p>
|
||||
Extra scopes:<input type="text" name="extra_scopes" placeholder="comma-separated list of scopes">
|
||||
</p>
|
||||
{{ if .OOB }}
|
||||
<input type="submit" value="Login" formtarget="_blank">
|
||||
{{ else }}
|
||||
|
|
|
@ -218,18 +218,25 @@ func handleLoginFunc(c *oidc.Client) http.HandlerFunc {
|
|||
panic("unable to proceed")
|
||||
}
|
||||
|
||||
xClient := r.Form.Get("cross_client")
|
||||
if xClient != "" {
|
||||
var scopes []string
|
||||
q := u.Query()
|
||||
if scope := q.Get("scope"); scope != "" {
|
||||
scopes = strings.Split(scope, " ")
|
||||
}
|
||||
|
||||
if xClient := r.Form.Get("cross_client"); xClient != "" {
|
||||
xClients := strings.Split(xClient, ",")
|
||||
for i, x := range xClients {
|
||||
xClients[i] = scope.ScopeGoogleCrossClient + x
|
||||
for _, x := range xClients {
|
||||
scopes = append(scopes, scope.ScopeGoogleCrossClient+x)
|
||||
}
|
||||
q := u.Query()
|
||||
scope := q.Get("scope")
|
||||
scopes := strings.Split(scope, " ")
|
||||
scopes = append(scopes, xClients...)
|
||||
scope = strings.Join(scopes, " ")
|
||||
q.Set("scope", scope)
|
||||
}
|
||||
|
||||
if extraScopes := r.Form.Get("extra_scopes"); extraScopes != "" {
|
||||
scopes = append(scopes, strings.Split(extraScopes, ",")...)
|
||||
}
|
||||
|
||||
if scopes != nil {
|
||||
q.Set("scope", strings.Join(scopes, " "))
|
||||
u.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
|
@ -292,57 +299,69 @@ func handleResendFunc(c *oidc.Client, issuerURL, resendURL, cbURL url.URL) http.
|
|||
|
||||
func handleCallbackFunc(c *oidc.Client) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
refreshToken := r.URL.Query().Get("refresh_token")
|
||||
code := r.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
|
||||
oac, err := c.OAuthClient()
|
||||
if err != nil {
|
||||
phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to create OAuth2 client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
var token oauth2.TokenResponse
|
||||
|
||||
switch {
|
||||
case code != "":
|
||||
if token, err = oac.RequestToken(oauth2.GrantTypeAuthCode, code); err != nil {
|
||||
phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to verify auth code with issuer: %v", err))
|
||||
return
|
||||
}
|
||||
case refreshToken != "":
|
||||
if token, err = oac.RequestToken(oauth2.GrantTypeRefreshToken, refreshToken); err != nil {
|
||||
phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to refresh token: %v", err))
|
||||
return
|
||||
}
|
||||
if token.RefreshToken == "" {
|
||||
token.RefreshToken = refreshToken
|
||||
}
|
||||
default:
|
||||
phttp.WriteError(w, http.StatusBadRequest, "code query param must be set")
|
||||
return
|
||||
}
|
||||
|
||||
tokens, err := exchangeAuthCode(c, code)
|
||||
tok, err := jose.ParseJWT(token.IDToken)
|
||||
if err != nil {
|
||||
phttp.WriteError(w, http.StatusBadRequest,
|
||||
fmt.Sprintf("unable to verify auth code with issuer: %v", err))
|
||||
phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to parse JWT: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
tok, err := jose.ParseJWT(tokens.IDToken)
|
||||
if err != nil {
|
||||
phttp.WriteError(w, http.StatusBadRequest,
|
||||
fmt.Sprintf("unable to parse JWT: %v", err))
|
||||
claims := new(bytes.Buffer)
|
||||
if err := json.Indent(claims, tok.Payload, "", " "); err != nil {
|
||||
phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to construct claims: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := tok.Claims()
|
||||
if err != nil {
|
||||
phttp.WriteError(w, http.StatusBadRequest,
|
||||
fmt.Sprintf("unable to construct claims: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
s := fmt.Sprintf(`
|
||||
<html>
|
||||
<head>
|
||||
<style>
|
||||
/* make pre wrap */
|
||||
pre {
|
||||
white-space: pre-wrap; /* css-3 */
|
||||
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
|
||||
white-space: -pre-wrap; /* Opera 4-6 */
|
||||
white-space: -o-pre-wrap; /* Opera 7 */
|
||||
word-wrap: break-word; /* Internet Explorer 5.5+ */
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<p> Token: %v</p>
|
||||
<p> Claims: %v </p>
|
||||
<a href="/resend?jwt=%s">Resend Verification Email</a>
|
||||
<p> Refresh Token: %v </p>
|
||||
<p> Token: <pre><code>%v</code></pre></p>
|
||||
<p> Claims: <pre><code>%v</code></pre></p>
|
||||
<p> Refresh Token: <pre><code>%v</code></pre></p>
|
||||
<p><a href="%s?refresh_token=%s">Redeem refresh token</a><p>
|
||||
<p><a href="/resend?jwt=%s">Resend Verification Email</a></p>
|
||||
</body>
|
||||
</html>`, tok.Encode(), claims, tok.Encode(), tokens.RefreshToken)
|
||||
</html>`, tok.Encode(), claims.String(), token.RefreshToken, r.URL.Path, token.RefreshToken, tok.Encode())
|
||||
w.Write([]byte(s))
|
||||
}
|
||||
}
|
||||
|
||||
func exchangeAuthCode(c *oidc.Client, code string) (oauth2.TokenResponse, error) {
|
||||
oac, err := c.OAuthClient()
|
||||
if err != nil {
|
||||
return oauth2.TokenResponse{}, err
|
||||
}
|
||||
|
||||
t, err := oac.RequestToken(oauth2.GrantTypeAuthCode, code)
|
||||
if err != nil {
|
||||
return oauth2.TokenResponse{}, err
|
||||
}
|
||||
|
||||
return t, nil
|
||||
|
||||
}
|
||||
|
|
|
@ -20,7 +20,10 @@ import (
|
|||
var (
|
||||
testRefreshClientID = "client1"
|
||||
testRefreshClientID2 = "client2"
|
||||
testRefreshClients = []client.LoadableClient{
|
||||
|
||||
testRefreshConnectorID = "IDPC-1"
|
||||
|
||||
testRefreshClients = []client.LoadableClient{
|
||||
{
|
||||
Client: client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
|
@ -59,7 +62,7 @@ var (
|
|||
},
|
||||
RemoteIdentities: []user.RemoteIdentity{
|
||||
{
|
||||
ConnectorID: "IDPC-1",
|
||||
ConnectorID: testRefreshConnectorID,
|
||||
ID: "RID-1",
|
||||
},
|
||||
},
|
||||
|
@ -103,12 +106,12 @@ func TestRefreshTokenRepoCreateVerify(t *testing.T) {
|
|||
|
||||
for i, tt := range tests {
|
||||
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
|
||||
tok, err := repo.Create(testRefreshUserID, testRefreshClientID, tt.createScopes)
|
||||
tok, err := repo.Create(testRefreshUserID, testRefreshClientID, testRefreshConnectorID, tt.createScopes)
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: failed to create refresh token: %v", i, err)
|
||||
}
|
||||
|
||||
tokUserID, gotScopes, err := repo.Verify(tt.verifyClientID, tok)
|
||||
tokUserID, gotConnectorID, gotScopes, err := repo.Verify(tt.verifyClientID, tok)
|
||||
if tt.wantVerifyErr {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want non-nil error.", i)
|
||||
|
@ -126,6 +129,10 @@ func TestRefreshTokenRepoCreateVerify(t *testing.T) {
|
|||
t.Errorf("case %d: Verified token returned wrong user id, want=%s, got=%s", i,
|
||||
testRefreshUserID, tokUserID)
|
||||
}
|
||||
|
||||
if gotConnectorID != testRefreshConnectorID {
|
||||
t.Errorf("case %d: wanted connector_id=%q got=%q", i, testRefreshConnectorID, gotConnectorID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -138,7 +145,7 @@ func buildRefreshToken(tokenID int64, tokenPayload []byte) string {
|
|||
func TestRefreshRepoVerifyInvalidTokens(t *testing.T) {
|
||||
r := db.NewRefreshTokenRepo(connect(t))
|
||||
|
||||
token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope)
|
||||
token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
@ -209,7 +216,7 @@ func TestRefreshRepoVerifyInvalidTokens(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
result, _, err := r.Verify(tt.creds.ID, tt.token)
|
||||
result, _, _, err := r.Verify(tt.creds.ID, tt.token)
|
||||
if err != tt.err {
|
||||
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
|
||||
}
|
||||
|
@ -232,7 +239,7 @@ func TestRefreshTokenRepoClientsWithRefreshTokens(t *testing.T) {
|
|||
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
|
||||
|
||||
for _, clientID := range tt.clientIDs {
|
||||
_, err := repo.Create(testRefreshUserID, clientID, []string{"openid"})
|
||||
_, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
|
||||
}
|
||||
|
@ -281,7 +288,7 @@ func TestRefreshTokenRepoRevokeForClient(t *testing.T) {
|
|||
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
|
||||
|
||||
for _, clientID := range tt.createIDs {
|
||||
_, err := repo.Create(testRefreshUserID, clientID, []string{"openid"})
|
||||
_, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
|
||||
}
|
||||
|
@ -318,7 +325,7 @@ func TestRefreshTokenRepoRevokeForClient(t *testing.T) {
|
|||
func TestRefreshRepoRevoke(t *testing.T) {
|
||||
r := db.NewRefreshTokenRepo(connect(t))
|
||||
|
||||
token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope)
|
||||
token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
|
|
@ -104,6 +104,13 @@ func TestSessionRepoCreateGet(t *testing.T) {
|
|||
ExpiresAt: time.Unix(789, 0).UTC(),
|
||||
Nonce: "oncenay",
|
||||
},
|
||||
session.Session{
|
||||
ID: "anID",
|
||||
ClientState: "blargh",
|
||||
ExpiresAt: time.Unix(789, 0).UTC(),
|
||||
Nonce: "oncenay",
|
||||
Groups: []string{"group1", "group2"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
|
|
|
@ -149,7 +149,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
|
|||
refreshRepo := db.NewRefreshTokenRepo(dbMap)
|
||||
for _, user := range userUsers {
|
||||
if _, err := refreshRepo.Create(user.User.ID, testClientID,
|
||||
append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil {
|
||||
"", append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil {
|
||||
panic("Failed to create refresh token: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,12 +44,12 @@ type RefreshTokenRepo interface {
|
|||
// The scopes will be stored with the refresh token, and used to verify
|
||||
// against future OIDC refresh requests' scopes.
|
||||
// On success the token will be returned.
|
||||
Create(userID, clientID string, scope []string) (string, error)
|
||||
Create(userID, clientID, connectorID string, scope []string) (string, error)
|
||||
|
||||
// Verify verifies that a token belongs to the client.
|
||||
// It returns the user ID to which the token belongs, and the scopes stored
|
||||
// with token.
|
||||
Verify(clientID, token string) (string, scope.Scopes, error)
|
||||
Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error)
|
||||
|
||||
// Revoke deletes the refresh token if the token belongs to the given userID.
|
||||
Revoke(userID, token string) error
|
||||
|
|
|
@ -6,6 +6,9 @@ const (
|
|||
// Scope prefix which indicates initiation of a cross-client authentication flow.
|
||||
// See https://developers.google.com/identity/protocols/CrossClientAuth
|
||||
ScopeGoogleCrossClient = "audience:server:client_id:"
|
||||
|
||||
// ScopeGroups indicates that groups should be added to the ID Token.
|
||||
ScopeGroups = "groups"
|
||||
)
|
||||
|
||||
type Scopes []string
|
||||
|
|
|
@ -421,6 +421,7 @@ func validateScopes(srv OIDCServer, clientID string, scopes []string) error {
|
|||
foundOpenIDScope = true
|
||||
case curScope == "profile":
|
||||
case curScope == "email":
|
||||
case curScope == scope.ScopeGroups:
|
||||
case curScope == "offline_access":
|
||||
// According to the spec, for offline_access scope, the client must
|
||||
// use a response_type value that would result in an Authorization
|
||||
|
|
100
server/server.go
100
server/server.go
|
@ -75,7 +75,8 @@ type Server struct {
|
|||
OOBTemplate *template.Template
|
||||
|
||||
HealthChecks []health.Checkable
|
||||
Connectors []connector.Connector
|
||||
// TODO(ericchiang): Make this a map of ID to connector.
|
||||
Connectors []connector.Connector
|
||||
|
||||
ClientRepo client.ClientRepo
|
||||
ConnectorConfigRepo connector.ConnectorConfigRepo
|
||||
|
@ -306,6 +307,15 @@ func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL ur
|
|||
return s.SessionManager.NewSessionKey(sessionID)
|
||||
}
|
||||
|
||||
func (s *Server) connector(id string) (connector.Connector, bool) {
|
||||
for _, c := range s.Connectors {
|
||||
if c.ID() == id {
|
||||
return c, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
|
||||
sessionID, err := s.SessionManager.ExchangeKey(key)
|
||||
if err != nil {
|
||||
|
@ -318,6 +328,29 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
|
|||
}
|
||||
log.Infof("Session %s remote identity attached: clientID=%s identity=%#v", sessionID, ses.ClientID, ident)
|
||||
|
||||
// Get the connector used to log the user in.
|
||||
conn, ok := s.connector(ses.ConnectorID)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("session contained invalid connector ID (%s)", ses.ConnectorID)
|
||||
}
|
||||
|
||||
// If the client has requested access to groups, add them here.
|
||||
if ses.Scope.HasScope(scope.ScopeGroups) {
|
||||
grouper, ok := conn.(connector.GroupsConnector)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("scope %q provided but connector does not support groups", scope.ScopeGroups)
|
||||
}
|
||||
groups, err := grouper.Groups(ident.ID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to retrieve user groups for %q %v", ident.ID, err)
|
||||
}
|
||||
|
||||
// Update the session.
|
||||
if ses, err = s.SessionManager.AttachGroups(sessionID, groups); err != nil {
|
||||
return "", fmt.Errorf("failed save groups")
|
||||
}
|
||||
}
|
||||
|
||||
if ses.Register {
|
||||
code, err := s.SessionManager.NewSessionKey(sessionID)
|
||||
if err != nil {
|
||||
|
@ -334,18 +367,6 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
|
|||
|
||||
remoteIdentity := user.RemoteIdentity{ConnectorID: ses.ConnectorID, ID: ses.Identity.ID}
|
||||
|
||||
// Get the connector used to log the user in.
|
||||
var conn connector.Connector
|
||||
for _, c := range s.Connectors {
|
||||
if c.ID() == ses.ConnectorID {
|
||||
conn = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if conn == nil {
|
||||
return "", fmt.Errorf("session contained invalid connector ID (%s)", ses.ConnectorID)
|
||||
}
|
||||
|
||||
usr, err := s.UserRepo.GetByRemoteIdentity(nil, remoteIdentity)
|
||||
if err == user.ErrorNotFound {
|
||||
if ses.Identity.Email == "" {
|
||||
|
@ -508,7 +529,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
|
|||
if scope == "offline_access" {
|
||||
log.Infof("Session %s requests offline access, will generate refresh token", sessionID)
|
||||
|
||||
refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.Scope)
|
||||
refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.ConnectorID, ses.Scope)
|
||||
switch err {
|
||||
case nil:
|
||||
break
|
||||
|
@ -535,7 +556,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
|
|||
return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
|
||||
}
|
||||
|
||||
userID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
|
||||
userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
|
||||
switch err {
|
||||
case nil:
|
||||
break
|
||||
|
@ -555,7 +576,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
|
|||
}
|
||||
}
|
||||
|
||||
user, err := s.UserRepo.Get(nil, userID)
|
||||
usr, err := s.UserRepo.Get(nil, userID)
|
||||
if err != nil {
|
||||
// The error can be user.ErrorNotFound, but we are not deleting
|
||||
// user at this moment, so this shouldn't happen.
|
||||
|
@ -563,6 +584,43 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
|
|||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||
}
|
||||
|
||||
var groups []string
|
||||
if rtScopes.HasScope(scope.ScopeGroups) {
|
||||
conn, ok := s.connector(connectorID)
|
||||
if !ok {
|
||||
log.Errorf("refresh token contained invalid connector ID (%s)", connectorID)
|
||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||
}
|
||||
|
||||
grouper, ok := conn.(connector.GroupsConnector)
|
||||
if !ok {
|
||||
log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID)
|
||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||
}
|
||||
|
||||
remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get remote identities: %v", err)
|
||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||
}
|
||||
remoteIdentity, ok := func() (user.RemoteIdentity, bool) {
|
||||
for _, ri := range remoteIdentities {
|
||||
if ri.ConnectorID == connectorID {
|
||||
return ri, true
|
||||
}
|
||||
}
|
||||
return user.RemoteIdentity{}, false
|
||||
}()
|
||||
if !ok {
|
||||
log.Errorf("failed to get remote identity for connector %s", connectorID)
|
||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||
}
|
||||
if groups, err = grouper.Groups(remoteIdentity.ID); err != nil {
|
||||
log.Errorf("failed to get groups for refresh token: %v", connectorID)
|
||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||
}
|
||||
}
|
||||
|
||||
signer, err := s.KeyManager.Signer()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to refresh ID token: %v", err)
|
||||
|
@ -572,8 +630,14 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
|
|||
now := time.Now()
|
||||
expireAt := now.Add(session.DefaultSessionValidityWindow)
|
||||
|
||||
claims := oidc.NewClaims(s.IssuerURL.String(), user.ID, creds.ID, now, expireAt)
|
||||
user.AddToClaims(claims)
|
||||
claims := oidc.NewClaims(s.IssuerURL.String(), usr.ID, creds.ID, now, expireAt)
|
||||
usr.AddToClaims(claims)
|
||||
if rtScopes.HasScope(scope.ScopeGroups) {
|
||||
if groups == nil {
|
||||
groups = []string{}
|
||||
}
|
||||
claims["groups"] = groups
|
||||
}
|
||||
|
||||
s.addClaimsFromScope(claims, scope.Scopes(scopes), creds.ID)
|
||||
|
||||
|
|
|
@ -785,8 +785,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
t.Errorf("case %d: error creating other client: %v", i, err)
|
||||
}
|
||||
|
||||
if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID,
|
||||
tt.createScopes); err != nil {
|
||||
if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID, "", tt.createScopes); err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -144,6 +144,18 @@ func (m *SessionManager) AttachUser(sessionID string, userID string) (*session.S
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) AttachGroups(sessionID string, groups []string) (*session.Session, error) {
|
||||
s, err := m.sessions.Get(sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.Groups = groups
|
||||
if err = m.sessions.Update(*s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) Kill(sessionID string) (*session.Session, error) {
|
||||
s, err := m.sessions.Get(sessionID)
|
||||
if err != nil {
|
||||
|
|
|
@ -55,6 +55,9 @@ type Session struct {
|
|||
// Scope is the 'scope' field in the authentication request. Example scopes
|
||||
// are 'openid', 'email', 'offline', etc.
|
||||
Scope scope.Scopes
|
||||
|
||||
// Groups the user belongs to.
|
||||
Groups []string
|
||||
}
|
||||
|
||||
// Claims returns a new set of Claims for the current session.
|
||||
|
@ -65,5 +68,8 @@ func (s *Session) Claims(issuerURL string) jose.Claims {
|
|||
if s.Nonce != "" {
|
||||
claims["nonce"] = s.Nonce
|
||||
}
|
||||
if s.Scope.HasScope(scope.ScopeGroups) {
|
||||
claims["groups"] = s.Groups
|
||||
}
|
||||
return claims
|
||||
}
|
||||
|
|
|
@ -192,7 +192,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
|||
}
|
||||
refreshRepo := db.NewRefreshTokenRepo(dbMap)
|
||||
for _, token := range refreshTokens {
|
||||
if _, err := refreshRepo.Create(token.userID, token.clientID, []string{"openid"}); err != nil {
|
||||
if _, err := refreshRepo.Create(token.userID, token.clientID, "local", []string{"openid"}); err != nil {
|
||||
panic("Failed to create refresh token: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
|
Reference in a new issue