Merge pull request #386 from ericchiang/revoke_refresh_2

add ability to revoke refresh tokens in user API
This commit is contained in:
Eric Chiang 2016-04-06 13:45:23 -07:00
commit cd7d3fff85
15 changed files with 780 additions and 31 deletions

View file

@ -14,6 +14,7 @@ import (
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh" "github.com/coreos/dex/refresh"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
"github.com/coreos/go-oidc/oidc"
) )
const ( const (
@ -179,6 +180,35 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error {
return tx.Commit() return tx.Commit()
} }
func (r *refreshTokenRepo) RevokeTokensForClient(userID, clientID string) error {
q := fmt.Sprintf("DELETE FROM %s WHERE user_id = $1 AND client_id = $2", r.quote(refreshTokenTableName))
_, err := r.executor(nil).Exec(q, userID, clientID)
return err
}
func (r *refreshTokenRepo) ClientsWithRefreshTokens(userID string) ([]oidc.ClientIdentity, error) {
q := `SELECT c.* FROM %s as c
INNER JOIN %s as r ON c.id = r.client_id WHERE r.user_id = $1;`
q = fmt.Sprintf(q, r.quote(clientIdentityTableName), r.quote(refreshTokenTableName))
var clients []clientIdentityModel
if _, err := r.executor(nil).Select(&clients, q, userID); err != nil {
return nil, err
}
c := make([]oidc.ClientIdentity, len(clients))
for i, client := range clients {
ident, err := client.ClientIdentity()
if err != nil {
return nil, err
}
c[i] = *ident
// Do not share the secret.
c[i].Credentials.Secret = ""
}
return c, nil
}
func (r *refreshTokenRepo) get(tx repo.Transaction, tokenID int64) (*refreshTokenModel, error) { func (r *refreshTokenRepo) get(tx repo.Transaction, tokenID int64) (*refreshTokenModel, error) {
ex := r.executor(tx) ex := r.executor(tx)
result, err := ex.Get(refreshTokenModel{}, tokenID) result, err := ex.Get(refreshTokenModel{}, tokenID)

View file

@ -24,7 +24,8 @@ var (
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
url.URL{ url.URL{
Scheme: "https", Scheme: "https",
Host: "client1.example.com/callback", Host: "client1.example.com",
Path: "/callback",
}, },
}, },
}, },
@ -38,7 +39,8 @@ var (
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
url.URL{ url.URL{
Scheme: "https", Scheme: "https",
Host: "client2.example.com/callback", Host: "client2.example.com",
Path: "/callback",
}, },
}, },
}, },

View file

@ -0,0 +1,93 @@
package repo
import (
"encoding/base64"
"net/url"
"os"
"testing"
"time"
"github.com/coreos/go-oidc/oidc"
"github.com/go-gorp/gorp"
"github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/db"
"github.com/coreos/dex/refresh"
"github.com/coreos/dex/user"
)
func newRefreshRepo(t *testing.T, users []user.UserWithRemoteIdentities, clients []oidc.ClientIdentity) refresh.RefreshTokenRepo {
var dbMap *gorp.DbMap
if dsn := os.Getenv("DEX_TEST_DSN"); dsn == "" {
dbMap = db.NewMemDB()
} else {
dbMap = connect(t)
}
if _, err := db.NewUserRepoFromUsers(dbMap, users); err != nil {
t.Fatalf("Unable to add users: %v", err)
}
if _, err := db.NewClientIdentityRepoFromClients(dbMap, clients); err != nil {
t.Fatalf("Unable to add clients: %v", err)
}
return db.NewRefreshTokenRepo(dbMap)
}
func TestRefreshTokenRepo(t *testing.T) {
clientID := "client1"
userID := "user1"
clients := []oidc.ClientIdentity{
{
Credentials: oidc.ClientCredentials{
ID: clientID,
Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{Scheme: "https", Host: "client1.example.com", Path: "/callback"},
},
},
},
}
users := []user.UserWithRemoteIdentities{
{
User: user.User{
ID: userID,
Email: "Email-1@example.com",
CreatedAt: time.Now().Truncate(time.Second),
},
RemoteIdentities: []user.RemoteIdentity{
{
ConnectorID: "IDPC-1",
ID: "RID-1",
},
},
},
}
repo := newRefreshRepo(t, users, clients)
tok, err := repo.Create(userID, clientID)
if err != nil {
t.Fatalf("failed to create refresh token: %v", err)
}
if tokUserID, err := repo.Verify(clientID, tok); err != nil {
t.Errorf("Could not verify token: %v", err)
} else if tokUserID != userID {
t.Errorf("Verified token returned wrong user id, want=%s, got=%s", userID, tokUserID)
}
if userClients, err := repo.ClientsWithRefreshTokens(userID); err != nil {
t.Errorf("Failed to get the list of clients the user was logged into: %v", err)
} else {
if diff := pretty.Compare(userClients, clients); diff == "" {
t.Errorf("Clients user logged into: want did not equal got %s", diff)
}
}
if err := repo.RevokeTokensForClient(userID, clientID); err != nil {
t.Errorf("Failed to revoke refresh token: %v", err)
}
if _, err := repo.Verify(clientID, tok); err == nil {
t.Errorf("Token which should have been revoked was verified")
}
}

View file

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"sort"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -99,10 +100,9 @@ var (
func makeUserAPITestFixtures() *userAPITestFixtures { func makeUserAPITestFixtures() *userAPITestFixtures {
f := &userAPITestFixtures{} f := &userAPITestFixtures{}
_, _, _, um := makeUserObjects(userUsers, userPasswords) dbMap, _, _, um := makeUserObjects(userUsers, userPasswords)
cir := func() client.ClientIdentityRepo { cir := func() client.ClientIdentityRepo {
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ repo, err := db.NewClientIdentityRepoFromClients(dbMap, []oidc.ClientIdentity{
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: testClientID, ID: testClientID,
@ -144,8 +144,16 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
return oidc.NewJWTVerifier(testIssuerURL.String(), clientID, noop, keysFunc) return oidc.NewJWTVerifier(testIssuerURL.String(), clientID, noop, keysFunc)
} }
refreshRepo := db.NewRefreshTokenRepo(dbMap)
for _, user := range userUsers {
if _, err := refreshRepo.Create(user.User.ID, testClientID); err != nil {
panic("Failed to create refresh token: " + err.Error())
}
}
f.emailer = &testEmailer{} f.emailer = &testEmailer{}
api := api.NewUsersAPI(um, cir, f.emailer, "local") um.Clock = clock
api := api.NewUsersAPI(dbMap, um, f.emailer, "local")
usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, cir) usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, cir)
f.hSrv = httptest.NewServer(usrSrv.HTTPHandler()) f.hSrv = httptest.NewServer(usrSrv.HTTPHandler())
@ -584,6 +592,48 @@ func TestDisableUser(t *testing.T) {
} }
} }
func TestRefreshTokenEndpoints(t *testing.T) {
tests := []struct {
userID string
clients []string
}{
{"ID-1", []string{testClientID}},
{"ID-2", []string{testClientID}},
}
for i, tt := range tests {
f := makeUserAPITestFixtures()
list, err := f.client.RefreshClient.List(tt.userID).Do()
if err != nil {
t.Errorf("case %d: list clients: %v", i, err)
continue
}
var ids []string
for _, client := range list.Clients {
ids = append(ids, client.ClientID)
}
sort.Strings(ids)
sort.Strings(tt.clients)
if diff := pretty.Compare(tt.clients, ids); diff != "" {
t.Errorf("case %d: expected client ids did not match actual: %s", i, diff)
}
for _, clientID := range ids {
if err := f.client.Clients.Revoke(tt.userID, clientID).Do(); err != nil {
t.Errorf("case %d: failed to revoke client: %v", i, err)
}
}
list, err = f.client.RefreshClient.List(tt.userID).Do()
if err != nil {
t.Errorf("case %d: list clients after revocation: %v", i, err)
continue
}
if n := len(list.Clients); n != 0 {
t.Errorf("case %d: expected no refresh tokens after revocation, got %d", i, n)
}
}
}
func TestResendEmailInvitation(t *testing.T) { func TestResendEmailInvitation(t *testing.T) {
tests := []struct { tests := []struct {
req schema.ResendEmailInvitationRequest req schema.ResendEmailInvitationRequest

View file

@ -3,6 +3,8 @@ package refresh
import ( import (
"crypto/rand" "crypto/rand"
"errors" "errors"
"github.com/coreos/go-oidc/oidc"
) )
const ( const (
@ -47,4 +49,10 @@ type RefreshTokenRepo interface {
// Revoke deletes the refresh token if the token belongs to the given userID. // Revoke deletes the refresh token if the token belongs to the given userID.
Revoke(userID, token string) error Revoke(userID, token string) error
// RevokeTokensForClient revokes all tokens issued for the userID for the provided client.
RevokeTokensForClient(userID, clientID string) error
// ClientsWithRefreshTokens returns a list of all clients the user has an outstanding client with.
ClientsWithRefreshTokens(userID string) ([]oidc.ClientIdentity, error)
} }

View file

@ -33,8 +33,15 @@ fi
$GENDOC --f $IN --o $DOC $GENDOC --f $IN --o $DOC
# See schema/generator_import.go for instructions on updating the dependency # Though google-api-go-generator is a main, dex vendors the app using the same
PKG="google.golang.org/api/google-api-go-generator" # tool it uses to vendor third party packages. Hence, it can be found in the
# "vendor" directory.
#
# This vendoring is currently done with godep, but may change if/when we move to
# another tool.
#
# See schema/generator_import.go for instructions on updating the dependency.
PKG="github.com/coreos/dex/vendor/google.golang.org/api/google-api-go-generator"
# First, write the discovery document into a go file so it can be served statically by the API # First, write the discovery document into a go file so it can be served statically by the API
cat << EOF > "${OUT}" cat << EOF > "${OUT}"
@ -53,7 +60,7 @@ echo -n '`' >> "${OUT}"
# Now build google-api-go-generator - we vendor so this is consistently reproducible # Now build google-api-go-generator - we vendor so this is consistently reproducible
GEN_PATH="bin/google-api-go-generator" GEN_PATH="bin/google-api-go-generator"
if [ ! -f ${GEN_PATH} ]; then if [ ! -f ${GEN_PATH} ]; then
GOPATH="${PWD}/Godeps/_workspace" go build -o ${GEN_PATH} ${PKG} go build -o ${GEN_PATH} ${PKG}
fi fi
# Build the bindings # Build the bindings

View file

@ -59,6 +59,31 @@ __Version:__ v1
} }
``` ```
### RefreshClient
A client with associated public metadata.
```
{
clientID: string,
clientName: string,
clientURI: string,
logoURI: string
}
```
### RefreshClientList
```
{
clients: [
RefreshClient
]
}
```
### ResendEmailInvitationRequest ### ResendEmailInvitationRequest
@ -166,6 +191,58 @@ __Version:__ v1
## Paths ## Paths
### GET /account/{userid}/refresh
> __Summary__
> List RefreshClient
> __Description__
> List all clients that hold refresh tokens for the authenticated user.
> __Parameters__
> |Name|Located in|Description|Required|Type|
|:-----|:-----|:-----|:-----|:-----|
| userid | path | | Yes | string |
> __Responses__
> |Code|Description|Type|
|:-----|:-----|:-----|
| 200 | | [RefreshClientList](#refreshclientlist) |
| default | Unexpected error | |
### DELETE /account/{userid}/refresh/{clientid}
> __Summary__
> Revoke Clients
> __Description__
> Revoke all refresh tokens issues to the client for the authenticated user.
> __Parameters__
> |Name|Located in|Description|Required|Type|
|:-----|:-----|:-----|:-----|:-----|
| clientid | path | | Yes | string |
| userid | path | | Yes | string |
> __Responses__
> |Code|Description|Type|
|:-----|:-----|:-----|
| default | Unexpected error | |
### GET /clients ### GET /clients
> __Summary__ > __Summary__

View file

@ -14,12 +14,13 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"google.golang.org/api/googleapi"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
"google.golang.org/api/googleapi"
) )
// Always reference these packages, just in case the auto-generated code // Always reference these packages, just in case the auto-generated code
@ -45,6 +46,7 @@ func New(client *http.Client) (*Service, error) {
} }
s := &Service{client: client, BasePath: basePath} s := &Service{client: client, BasePath: basePath}
s.Clients = NewClientsService(s) s.Clients = NewClientsService(s)
s.RefreshClient = NewRefreshClientService(s)
s.Users = NewUsersService(s) s.Users = NewUsersService(s)
return s, nil return s, nil
} }
@ -55,6 +57,8 @@ type Service struct {
Clients *ClientsService Clients *ClientsService
RefreshClient *RefreshClientService
Users *UsersService Users *UsersService
} }
@ -67,6 +71,15 @@ type ClientsService struct {
s *Service s *Service
} }
func NewRefreshClientService(s *Service) *RefreshClientService {
rs := &RefreshClientService{s: s}
return rs
}
type RefreshClientService struct {
s *Service
}
func NewUsersService(s *Service) *UsersService { func NewUsersService(s *Service) *UsersService {
rs := &UsersService{s: s} rs := &UsersService{s: s}
return rs return rs
@ -102,6 +115,20 @@ type Error struct {
Error_description string `json:"error_description,omitempty"` Error_description string `json:"error_description,omitempty"`
} }
type RefreshClient struct {
ClientID string `json:"clientID,omitempty"`
ClientName string `json:"clientName,omitempty"`
ClientURI string `json:"clientURI,omitempty"`
LogoURI string `json:"logoURI,omitempty"`
}
type RefreshClientList struct {
Clients []*RefreshClient `json:"clients,omitempty"`
}
type ResendEmailInvitationRequest struct { type ResendEmailInvitationRequest struct {
RedirectURL string `json:"redirectURL,omitempty"` RedirectURL string `json:"redirectURL,omitempty"`
} }
@ -307,6 +334,154 @@ func (c *ClientsListCall) Do() (*ClientPage, error) {
} }
// method id "dex.Client.Revoke":
type ClientsRevokeCall struct {
s *Service
userid string
clientid string
opt_ map[string]interface{}
}
// Revoke: Revoke all refresh tokens issues to the client for the
// authenticated user.
func (r *ClientsService) Revoke(userid string, clientid string) *ClientsRevokeCall {
c := &ClientsRevokeCall{s: r.s, opt_: make(map[string]interface{})}
c.userid = userid
c.clientid = clientid
return c
}
// Fields allows partial responses to be retrieved.
// See https://developers.google.com/gdata/docs/2.0/basics#PartialResponse
// for more information.
func (c *ClientsRevokeCall) Fields(s ...googleapi.Field) *ClientsRevokeCall {
c.opt_["fields"] = googleapi.CombineFields(s)
return c
}
func (c *ClientsRevokeCall) Do() error {
var body io.Reader = nil
params := make(url.Values)
params.Set("alt", "json")
if v, ok := c.opt_["fields"]; ok {
params.Set("fields", fmt.Sprintf("%v", v))
}
urls := googleapi.ResolveRelative(c.s.BasePath, "account/{userid}/refresh/{clientid}")
urls += "?" + params.Encode()
req, _ := http.NewRequest("DELETE", urls, body)
googleapi.Expand(req.URL, map[string]string{
"userid": c.userid,
"clientid": c.clientid,
})
req.Header.Set("User-Agent", "google-api-go-client/0.5")
res, err := c.s.client.Do(req)
if err != nil {
return err
}
defer googleapi.CloseBody(res)
if err := googleapi.CheckResponse(res); err != nil {
return err
}
return nil
// {
// "description": "Revoke all refresh tokens issues to the client for the authenticated user.",
// "httpMethod": "DELETE",
// "id": "dex.Client.Revoke",
// "parameterOrder": [
// "userid",
// "clientid"
// ],
// "parameters": {
// "clientid": {
// "location": "path",
// "required": true,
// "type": "string"
// },
// "userid": {
// "location": "path",
// "required": true,
// "type": "string"
// }
// },
// "path": "account/{userid}/refresh/{clientid}"
// }
}
// method id "dex.Client.List":
type RefreshClientListCall struct {
s *Service
userid string
opt_ map[string]interface{}
}
// List: List all clients that hold refresh tokens for the authenticated
// user.
func (r *RefreshClientService) List(userid string) *RefreshClientListCall {
c := &RefreshClientListCall{s: r.s, opt_: make(map[string]interface{})}
c.userid = userid
return c
}
// Fields allows partial responses to be retrieved.
// See https://developers.google.com/gdata/docs/2.0/basics#PartialResponse
// for more information.
func (c *RefreshClientListCall) Fields(s ...googleapi.Field) *RefreshClientListCall {
c.opt_["fields"] = googleapi.CombineFields(s)
return c
}
func (c *RefreshClientListCall) Do() (*RefreshClientList, error) {
var body io.Reader = nil
params := make(url.Values)
params.Set("alt", "json")
if v, ok := c.opt_["fields"]; ok {
params.Set("fields", fmt.Sprintf("%v", v))
}
urls := googleapi.ResolveRelative(c.s.BasePath, "account/{userid}/refresh")
urls += "?" + params.Encode()
req, _ := http.NewRequest("GET", urls, body)
googleapi.Expand(req.URL, map[string]string{
"userid": c.userid,
})
req.Header.Set("User-Agent", "google-api-go-client/0.5")
res, err := c.s.client.Do(req)
if err != nil {
return nil, err
}
defer googleapi.CloseBody(res)
if err := googleapi.CheckResponse(res); err != nil {
return nil, err
}
var ret *RefreshClientList
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil {
return nil, err
}
return ret, nil
// {
// "description": "List all clients that hold refresh tokens for the authenticated user.",
// "httpMethod": "GET",
// "id": "dex.Client.List",
// "parameterOrder": [
// "userid"
// ],
// "parameters": {
// "userid": {
// "location": "path",
// "required": true,
// "type": "string"
// }
// },
// "path": "account/{userid}/refresh",
// "response": {
// "$ref": "RefreshClientList"
// }
// }
}
// method id "dex.User.Create": // method id "dex.User.Create":
type UsersCreateCall struct { type UsersCreateCall struct {

View file

@ -55,6 +55,37 @@ const DiscoveryJSON = `{
} }
} }
}, },
"RefreshClient": {
"id": "Client",
"type": "object",
"description": "A client with associated public metadata.",
"properties": {
"clientID": {
"type": "string"
},
"clientName": {
"type": "string"
},
"logoURI": {
"type": "string"
},
"clientURI": {
"type": "string"
}
}
},
"RefreshClientList": {
"id": "RefreshClientList",
"type": "object",
"properties": {
"clients": {
"type": "array",
"items": {
"$ref": "RefreshClient"
}
}
}
},
"ClientWithSecret": { "ClientWithSecret": {
"id": "Client", "id": "Client",
"type": "object", "type": "object",
@ -241,6 +272,27 @@ const DiscoveryJSON = `{
"response": { "response": {
"$ref": "ClientWithSecret" "$ref": "ClientWithSecret"
} }
},
"Revoke": {
"id": "dex.Client.Revoke",
"description": "Revoke all refresh tokens issues to the client for the authenticated user.",
"httpMethod": "DELETE",
"path": "account/{userid}/refresh/{clientid}",
"parameterOrder": [
"userid","clientid"
],
"parameters": {
"clientid": {
"type": "string",
"required": true,
"location": "path"
},
"userid": {
"type": "string",
"required": true,
"location": "path"
}
}
} }
} }
}, },
@ -341,6 +393,29 @@ const DiscoveryJSON = `{
} }
} }
} }
},
"RefreshClient": {
"methods": {
"List": {
"id": "dex.Client.List",
"description": "List all clients that hold refresh tokens for the authenticated user.",
"httpMethod": "GET",
"path": "account/{userid}/refresh",
"parameters": {
"userid": {
"type": "string",
"required": true,
"location": "path"
}
},
"parameterOrder": [
"userid"
],
"response": {
"$ref": "RefreshClientList"
}
}
}
} }
} }
} }

View file

@ -49,6 +49,37 @@
} }
} }
}, },
"RefreshClient": {
"id": "Client",
"type": "object",
"description": "A client with associated public metadata.",
"properties": {
"clientID": {
"type": "string"
},
"clientName": {
"type": "string"
},
"logoURI": {
"type": "string"
},
"clientURI": {
"type": "string"
}
}
},
"RefreshClientList": {
"id": "RefreshClientList",
"type": "object",
"properties": {
"clients": {
"type": "array",
"items": {
"$ref": "RefreshClient"
}
}
}
},
"ClientWithSecret": { "ClientWithSecret": {
"id": "Client", "id": "Client",
"type": "object", "type": "object",
@ -235,6 +266,27 @@
"response": { "response": {
"$ref": "ClientWithSecret" "$ref": "ClientWithSecret"
} }
},
"Revoke": {
"id": "dex.Client.Revoke",
"description": "Revoke all refresh tokens issues to the client for the authenticated user.",
"httpMethod": "DELETE",
"path": "account/{userid}/refresh/{clientid}",
"parameterOrder": [
"userid","clientid"
],
"parameters": {
"clientid": {
"type": "string",
"required": true,
"location": "path"
},
"userid": {
"type": "string",
"required": true,
"location": "path"
}
}
} }
} }
}, },
@ -335,6 +387,29 @@
} }
} }
} }
},
"RefreshClient": {
"methods": {
"List": {
"id": "dex.Client.List",
"description": "List all clients that hold refresh tokens for the authenticated user.",
"httpMethod": "GET",
"path": "account/{userid}/refresh",
"parameters": {
"userid": {
"type": "string",
"required": true,
"location": "path"
}
},
"parameterOrder": [
"userid"
],
"response": {
"$ref": "RefreshClientList"
}
}
}
} }
} }
} }

View file

@ -164,6 +164,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
srv.SessionManager = sm srv.SessionManager = sm
srv.RefreshTokenRepo = refTokRepo srv.RefreshTokenRepo = refTokRepo
srv.HealthChecks = append(srv.HealthChecks, db.NewHealthChecker(dbMap)) srv.HealthChecks = append(srv.HealthChecks, db.NewHealthChecker(dbMap))
srv.dbMap = dbMap
return nil return nil
} }
@ -290,6 +291,7 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
srv.SessionManager = sm srv.SessionManager = sm
srv.RefreshTokenRepo = refreshTokenRepo srv.RefreshTokenRepo = refreshTokenRepo
srv.HealthChecks = append(srv.HealthChecks, db.NewHealthChecker(dbc)) srv.HealthChecks = append(srv.HealthChecks, db.NewHealthChecker(dbc))
srv.dbMap = dbc
return nil return nil
} }

View file

@ -15,6 +15,7 @@ import (
"github.com/coreos/go-oidc/oauth2" "github.com/coreos/go-oidc/oauth2"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/coreos/pkg/health" "github.com/coreos/pkg/health"
"github.com/go-gorp/gorp"
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
@ -77,6 +78,7 @@ type Server struct {
EnableRegistration bool EnableRegistration bool
EnableClientRegistration bool EnableClientRegistration bool
dbMap *gorp.DbMap
localConnectorID string localConnectorID string
} }
@ -257,12 +259,10 @@ func (s *Server) HTTPHandler() http.Handler {
clientPath, clientHandler := registerClientResource(apiBasePath, s.ClientIdentityRepo) clientPath, clientHandler := registerClientResource(apiBasePath, s.ClientIdentityRepo)
mux.Handle(path.Join(apiBasePath, clientPath), s.NewClientTokenAuthHandler(clientHandler)) mux.Handle(path.Join(apiBasePath, clientPath), s.NewClientTokenAuthHandler(clientHandler))
usersAPI := usersapi.NewUsersAPI(s.UserManager, s.ClientIdentityRepo, s.UserEmailer, s.localConnectorID) usersAPI := usersapi.NewUsersAPI(s.dbMap, s.UserManager, s.UserEmailer, s.localConnectorID)
handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientIdentityRepo).HTTPHandler() handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientIdentityRepo).HTTPHandler()
path := path.Join(apiBasePath, UsersSubTree)
mux.Handle(path, handler) mux.Handle(apiBasePath+"/", handler)
mux.Handle(path+"/", handler)
return http.Handler(mux) return http.Handler(mux)
} }

View file

@ -30,6 +30,9 @@ var (
UsersGetEndpoint = addBasePath(UsersSubTree + "/:id") UsersGetEndpoint = addBasePath(UsersSubTree + "/:id")
UsersDisableEndpoint = addBasePath(UsersSubTree + "/:id/disable") UsersDisableEndpoint = addBasePath(UsersSubTree + "/:id/disable")
UsersResendInvitationEndpoint = addBasePath(UsersSubTree + "/:id/resend-invitation") UsersResendInvitationEndpoint = addBasePath(UsersSubTree + "/:id/resend-invitation")
AccountSubTree = "/account"
AccountListRefreshTokens = addBasePath(AccountSubTree + "/:userid/refresh")
AccountRevokeRefreshToken = addBasePath(AccountSubTree + "/:userid/refresh/:clientid")
) )
type UserMgmtServer struct { type UserMgmtServer struct {
@ -52,26 +55,48 @@ func (s *UserMgmtServer) HTTPHandler() http.Handler {
r := httprouter.New() r := httprouter.New()
r.RedirectTrailingSlash = false r.RedirectTrailingSlash = false
r.RedirectFixedPath = false r.RedirectFixedPath = false
r.GET(UsersListEndpoint, s.authAPIHandle(s.listUsers))
r.POST(UsersCreateEndpoint, s.authAPIHandle(s.createUser)) r.GET(UsersListEndpoint, s.authAdminUser(s.listUsers))
r.POST(UsersDisableEndpoint, s.authAPIHandle(s.disableUser)) r.POST(UsersCreateEndpoint, s.authAdminUser(s.createUser))
r.GET(UsersGetEndpoint, s.authAPIHandle(s.getUser)) r.POST(UsersDisableEndpoint, s.authAdminUser(s.disableUser))
r.POST(UsersResendInvitationEndpoint, s.authAPIHandle(s.resendInvitationEmail)) r.GET(UsersGetEndpoint, s.authAdminUser(s.getUser))
r.POST(UsersResendInvitationEndpoint, s.authAdminUser(s.resendInvitationEmail))
r.GET(AccountListRefreshTokens, s.authAccount(s.listClientsWithRefreshTokens))
r.DELETE(AccountRevokeRefreshToken, s.authAccount(s.revokeRefreshTokensForClient))
return r return r
} }
func (s *UserMgmtServer) authAdminUser(handle authedHandle) httprouter.Handle {
return s.authAPIHandle(handle, true)
}
func (s *UserMgmtServer) authAccount(handle authedHandle) httprouter.Handle {
return s.authAPIHandle(handle, false)
}
// authedHandle is an HTTP handle which requires requests to be authenticated as an admin user. // authedHandle is an HTTP handle which requires requests to be authenticated as an admin user.
type authedHandle func(w http.ResponseWriter, r *http.Request, ps httprouter.Params, creds api.Creds) type authedHandle func(w http.ResponseWriter, r *http.Request, ps httprouter.Params, creds api.Creds)
// authAPIHandle is a middleware function with authenticates an HTTP request before passing // authAPIHandle is a middleware function with authenticates an HTTP request before passing
// it along to the authedHandle. // it along to the authedHandle.
func (s *UserMgmtServer) authAPIHandle(handle authedHandle) httprouter.Handle { //
// The authorization checks for an ID token bearer token in the request header, requiring the
// audience (aud claim) be a client ID of an admin client.
//
// If requiresAdmin is true, the subject identifier (sub claim) of the ID token provided must be
// that of an admin user.
func (s *UserMgmtServer) authAPIHandle(handle authedHandle, requiresAdmin bool) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
creds, err := s.getCreds(r) creds, err := s.getCreds(r)
if err != nil { if err != nil {
s.writeError(w, err) s.writeError(w, err)
return return
} }
if creds.User.Disabled || (requiresAdmin && !creds.User.Admin) {
s.writeError(w, api.ErrorUnauthorized)
return
}
handle(w, r, ps, creds) handle(w, r, ps, creds)
} }
} }
@ -191,6 +216,23 @@ func (s *UserMgmtServer) resendInvitationEmail(w http.ResponseWriter, r *http.Re
writeResponseWithBody(w, http.StatusOK, resendEmailInvitationResponse) writeResponseWithBody(w, http.StatusOK, resendEmailInvitationResponse)
} }
func (s *UserMgmtServer) listClientsWithRefreshTokens(w http.ResponseWriter, r *http.Request, ps httprouter.Params, creds api.Creds) {
clients, err := s.api.ListClientsWithRefreshTokens(creds, ps.ByName("userid"))
if err != nil {
s.writeError(w, err)
return
}
writeResponseWithBody(w, http.StatusOK, schema.RefreshClientList{Clients: clients})
}
func (s *UserMgmtServer) revokeRefreshTokensForClient(w http.ResponseWriter, r *http.Request, ps httprouter.Params, creds api.Creds) {
if err := s.api.RevokeRefreshTokensForClient(creds, ps.ByName("userid"), ps.ByName("clientid")); err != nil {
s.writeError(w, err)
return
}
w.WriteHeader(http.StatusOK) // NOTE (ericchiang): http.StatusNoContent or return an empty JSON object?
}
func (s *UserMgmtServer) writeError(w http.ResponseWriter, err error) { func (s *UserMgmtServer) writeError(w http.ResponseWriter, err error) {
log.Errorf("Error calling user management API: %v: ", err) log.Errorf("Error calling user management API: %v: ", err)
if apiErr, ok := err.(api.Error); ok { if apiErr, ok := err.(api.Error); ok {

View file

@ -9,8 +9,12 @@ import (
"net/url" "net/url"
"time" "time"
"github.com/go-gorp/gorp"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/db"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager" "github.com/coreos/dex/user/manager"
@ -87,6 +91,7 @@ type UsersAPI struct {
manager *manager.UserManager manager *manager.UserManager
localConnectorID string localConnectorID string
clientIdentityRepo client.ClientIdentityRepo clientIdentityRepo client.ClientIdentityRepo
refreshRepo refresh.RefreshTokenRepo
emailer Emailer emailer Emailer
} }
@ -99,10 +104,12 @@ type Creds struct {
User user.User User user.User
} }
func NewUsersAPI(manager *manager.UserManager, cir client.ClientIdentityRepo, emailer Emailer, localConnectorID string) *UsersAPI { // TODO(ericchiang): Don't pass a dbMap. See #385.
func NewUsersAPI(dbMap *gorp.DbMap, userManager *manager.UserManager, emailer Emailer, localConnectorID string) *UsersAPI {
return &UsersAPI{ return &UsersAPI{
manager: manager, manager: userManager,
clientIdentityRepo: cir, refreshRepo: db.NewRefreshTokenRepo(dbMap),
clientIdentityRepo: db.NewClientIdentityRepo(dbMap),
localConnectorID: localConnectorID, localConnectorID: localConnectorID,
emailer: emailer, emailer: emailer,
} }
@ -258,6 +265,47 @@ func (u *UsersAPI) ListUsers(creds Creds, maxResults int, nextPageToken string)
return list, tok, nil return list, tok, nil
} }
// ListClientsWithRefreshTokens returns all clients issued refresh tokens
// for the authenticated user.
func (u *UsersAPI) ListClientsWithRefreshTokens(creds Creds, userID string) ([]*schema.RefreshClient, error) {
// Users must either be an admin or be requesting data associated with their own account.
if !creds.User.Admin && (creds.User.ID != userID) {
return nil, ErrorUnauthorized
}
clientIdentities, err := u.refreshRepo.ClientsWithRefreshTokens(userID)
if err != nil {
return nil, err
}
clients := make([]*schema.RefreshClient, len(clientIdentities))
urlToString := func(u *url.URL) string {
if u == nil {
return ""
}
return u.String()
}
for i, identity := range clientIdentities {
clients[i] = &schema.RefreshClient{
ClientID: identity.Credentials.ID,
ClientName: identity.Metadata.ClientName,
ClientURI: urlToString(identity.Metadata.ClientURI),
LogoURI: urlToString(identity.Metadata.LogoURI),
}
}
return clients, nil
}
// RevokeClient revokes all refresh tokens issued to this client for the
// authenticiated user.
func (u *UsersAPI) RevokeRefreshTokensForClient(creds Creds, userID, clientID string) error {
// Users must either be an admin or be requesting data associated with their own account.
if !creds.User.Admin && (creds.User.ID != userID) {
return ErrorUnauthorized
}
return u.refreshRepo.RevokeTokensForClient(userID, clientID)
}
func (u *UsersAPI) Authorize(creds Creds) bool { func (u *UsersAPI) Authorize(creds Creds) bool {
return creds.User.Admin && !creds.User.Disabled return creds.User.Admin && !creds.User.Disabled
} }

View file

@ -3,6 +3,7 @@ package api
import ( import (
"encoding/base64" "encoding/base64"
"net/url" "net/url"
"sort"
"testing" "testing"
"time" "time"
@ -10,7 +11,6 @@ import (
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
@ -166,16 +166,27 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
}, },
}, },
} }
cir := func() client.ClientIdentityRepo { if _, err := db.NewClientIdentityRepoFromClients(dbMap, []oidc.ClientIdentity{ci}); err != nil {
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci}) panic("Failed to create client identity repo: " + err.Error())
if err != nil { }
panic("Failed to create client identity repo: " + err.Error())
// Used in TestRevokeRefreshToken test.
refreshTokens := []struct {
clientID string
userID string
}{
{"XXX", "ID-1"},
{"XXX", "ID-2"},
}
refreshRepo := db.NewRefreshTokenRepo(dbMap)
for _, token := range refreshTokens {
if _, err := refreshRepo.Create(token.userID, token.clientID); err != nil {
panic("Failed to create refresh token: " + err.Error())
} }
return repo }
}()
emailer := &testEmailer{} emailer := &testEmailer{}
api := NewUsersAPI(mgr, cir, emailer, "local") api := NewUsersAPI(dbMap, mgr, emailer, "local")
return api, emailer return api, emailer
} }
@ -562,3 +573,57 @@ func TestResendEmailInvitation(t *testing.T) {
} }
} }
} }
func TestRevokeRefreshToken(t *testing.T) {
tests := []struct {
userID string
toRevoke string
before []string // clientIDs expected before the change.
after []string // clientIDs expected after the change.
}{
{"ID-1", "XXX", []string{"XXX"}, []string{}},
{"ID-2", "XXX", []string{"XXX"}, []string{}},
}
api, _ := makeTestFixtures()
listClientsWithRefreshTokens := func(creds Creds, userID string) ([]string, error) {
clients, err := api.ListClientsWithRefreshTokens(creds, userID)
if err != nil {
return nil, err
}
clientIDs := make([]string, len(clients))
for i, client := range clients {
clientIDs[i] = client.ClientID
}
sort.Strings(clientIDs)
return clientIDs, nil
}
for i, tt := range tests {
creds := Creds{User: user.User{ID: tt.userID}}
gotBefore, err := listClientsWithRefreshTokens(creds, tt.userID)
if err != nil {
t.Errorf("case %d: list clients failed: %v", i, err)
} else {
if diff := pretty.Compare(tt.before, gotBefore); diff != "" {
t.Errorf("case %d: before exp!=got: %s", i, diff)
}
}
if err := api.RevokeRefreshTokensForClient(creds, tt.userID, tt.toRevoke); err != nil {
t.Errorf("case %d: failed to revoke client: %v", i, err)
continue
}
gotAfter, err := listClientsWithRefreshTokens(creds, tt.userID)
if err != nil {
t.Errorf("case %d: list clients failed: %v", i, err)
} else {
if diff := pretty.Compare(tt.after, gotAfter); diff != "" {
t.Errorf("case %d: after exp!=got: %s", i, diff)
}
}
}
}