Merge pull request #465 from bobbyrullo/cross_client_refresh_tokens
Cross client refresh tokens
This commit is contained in:
commit
b7e19b6e84
17 changed files with 679 additions and 407 deletions
|
@ -39,7 +39,8 @@ CREATE TABLE refresh_token (
|
||||||
id integer PRIMARY KEY,
|
id integer PRIMARY KEY,
|
||||||
payload_hash blob,
|
payload_hash blob,
|
||||||
user_id text,
|
user_id text,
|
||||||
client_id text
|
client_id text,
|
||||||
|
scopes text
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE remote_identity_mapping (
|
CREATE TABLE remote_identity_mapping (
|
||||||
|
|
4
db/migrations/0013_add_scopes_to_refresh_tokens.sql
Normal file
4
db/migrations/0013_add_scopes_to_refresh_tokens.sql
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
-- +migrate Up
|
||||||
|
ALTER TABLE refresh_token ADD COLUMN "scopes" text;
|
||||||
|
|
||||||
|
UPDATE refresh_token SET scopes = 'openid profile email offline_access';
|
|
@ -78,5 +78,11 @@ var PostgresMigrations migrate.MigrationSource = &migrate.MemoryMigrationSource{
|
||||||
"-- +migrate Up\nCREATE TABLE IF NOT EXISTS \"trusted_peers\" (\n \"client_id\" text not null,\n \"trusted_client_id\" text not null,\n primary key (\"client_id\", \"trusted_client_id\")) ;\n",
|
"-- +migrate Up\nCREATE TABLE IF NOT EXISTS \"trusted_peers\" (\n \"client_id\" text not null,\n \"trusted_client_id\" text not null,\n primary key (\"client_id\", \"trusted_client_id\")) ;\n",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Id: "0013_add_scopes_to_refresh_tokens.sql",
|
||||||
|
Up: []string{
|
||||||
|
"-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"scopes\" text;\n\nUPDATE refresh_token SET scopes = 'openid profile email offline_access';\n",
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
"github.com/coreos/dex/refresh"
|
"github.com/coreos/dex/refresh"
|
||||||
"github.com/coreos/dex/repo"
|
"github.com/coreos/dex/repo"
|
||||||
|
"github.com/coreos/dex/scope"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -38,10 +39,9 @@ type refreshTokenRepo struct {
|
||||||
type refreshTokenModel struct {
|
type refreshTokenModel struct {
|
||||||
ID int64 `db:"id"`
|
ID int64 `db:"id"`
|
||||||
PayloadHash []byte `db:"payload_hash"`
|
PayloadHash []byte `db:"payload_hash"`
|
||||||
// TODO(yifan): Use some sort of foreign key to manage database level
|
UserID string `db:"user_id"`
|
||||||
// data integrity.
|
ClientID string `db:"client_id"`
|
||||||
UserID string `db:"user_id"`
|
Scopes string `db:"scopes"`
|
||||||
ClientID string `db:"client_id"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildToken combines the token ID and token payload to create a new token.
|
// buildToken combines the token ID and token payload to create a new token.
|
||||||
|
@ -89,7 +89,7 @@ func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenG
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
|
func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (string, error) {
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
return "", refresh.ErrorInvalidUserID
|
return "", refresh.ErrorInvalidUserID
|
||||||
}
|
}
|
||||||
|
@ -112,6 +112,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
|
||||||
PayloadHash: payloadHash,
|
PayloadHash: payloadHash,
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
|
Scopes: strings.Join(scopes, " "),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.executor(nil).Insert(record); err != nil {
|
if err := r.executor(nil).Insert(record); err != nil {
|
||||||
|
@ -121,27 +122,32 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
|
||||||
return buildToken(record.ID, tokenPayload), nil
|
return buildToken(record.ID, tokenPayload), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *refreshTokenRepo) Verify(clientID, token string) (string, error) {
|
func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes, error) {
|
||||||
tokenID, tokenPayload, err := parseToken(token)
|
tokenID, tokenPayload, err := parseToken(token)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
record, err := r.get(nil, tokenID)
|
record, err := r.get(nil, tokenID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if record.ClientID != clientID {
|
if record.ClientID != clientID {
|
||||||
return "", refresh.ErrorInvalidClientID
|
return "", nil, refresh.ErrorInvalidClientID
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
|
if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
|
||||||
return "", err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return record.UserID, nil
|
var scopes []string
|
||||||
|
if len(record.Scopes) > 0 {
|
||||||
|
scopes = strings.Split(record.Scopes, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
return record.UserID, scopes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *refreshTokenRepo) Revoke(userID, token string) error {
|
func (r *refreshTokenRepo) Revoke(userID, token string) error {
|
||||||
|
@ -190,7 +196,6 @@ func (r *refreshTokenRepo) ClientsWithRefreshTokens(userID string) ([]client.Cli
|
||||||
q := `SELECT c.* FROM %s as c
|
q := `SELECT c.* FROM %s as c
|
||||||
INNER JOIN %s as r ON c.id = r.client_id WHERE r.user_id = $1;`
|
INNER JOIN %s as r ON c.id = r.client_id WHERE r.user_id = $1;`
|
||||||
q = fmt.Sprintf(q, r.quote(clientTableName), r.quote(refreshTokenTableName))
|
q = fmt.Sprintf(q, r.quote(clientTableName), r.quote(refreshTokenTableName))
|
||||||
|
|
||||||
var clients []clientModel
|
var clients []clientModel
|
||||||
if _, err := r.executor(nil).Select(&clients, q, userID); err != nil {
|
if _, err := r.executor(nil).Select(&clients, q, userID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -206,6 +211,7 @@ func (r *refreshTokenRepo) ClientsWithRefreshTokens(userID string) ([]client.Cli
|
||||||
// Do not share the secret.
|
// Do not share the secret.
|
||||||
c[i].Credentials.Secret = ""
|
c[i].Credentials.Secret = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package functional
|
package functional
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
@ -16,7 +15,6 @@ import (
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
"github.com/coreos/dex/client/manager"
|
"github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/refresh"
|
|
||||||
"github.com/coreos/dex/session"
|
"github.com/coreos/dex/session"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -411,207 +409,3 @@ func TestDBClientAll(t *testing.T) {
|
||||||
t.Fatalf("Retrieved incorrect number of ClientIdentities: want=2 got=%d", count)
|
t.Fatalf("Retrieved incorrect number of ClientIdentities: want=2 got=%d", count)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildRefreshToken combines the token ID and token payload to create a new token.
|
|
||||||
// used in the tests to created a refresh token.
|
|
||||||
func buildRefreshToken(tokenID int64, tokenPayload []byte) string {
|
|
||||||
return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDBRefreshRepoCreate(t *testing.T) {
|
|
||||||
r := db.NewRefreshTokenRepo(connect(t))
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
userID string
|
|
||||||
clientID string
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"",
|
|
||||||
"client-foo",
|
|
||||||
refresh.ErrorInvalidUserID,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"user-foo",
|
|
||||||
"",
|
|
||||||
refresh.ErrorInvalidClientID,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"user-foo",
|
|
||||||
"client-foo",
|
|
||||||
nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tt := range tests {
|
|
||||||
token, err := r.Create(tt.userID, tt.clientID)
|
|
||||||
if err != nil {
|
|
||||||
if tt.err == nil {
|
|
||||||
t.Errorf("case %d: create failed: %v", i, err)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if tt.err != nil {
|
|
||||||
t.Errorf("case %d: expected error, didn't get one", i)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
userID, err := r.Verify(tt.clientID, token)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("case %d: failed to verify good token: %v", i, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if userID != tt.userID {
|
|
||||||
t.Errorf("case %d: want userID=%s, got userID=%s", i, tt.userID, userID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDBRefreshRepoVerify(t *testing.T) {
|
|
||||||
r := db.NewRefreshTokenRepo(connect(t))
|
|
||||||
|
|
||||||
token, err := r.Create("user-foo", "client-foo")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
tokenWithBadID := "404" + token[1:]
|
|
||||||
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
token string
|
|
||||||
creds oidc.ClientCredentials
|
|
||||||
err error
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"invalid-token-format",
|
|
||||||
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
||||||
refresh.ErrorInvalidToken,
|
|
||||||
"",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"b/invalid-base64-encoded-format",
|
|
||||||
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
||||||
refresh.ErrorInvalidToken,
|
|
||||||
"",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"1/invalid-base64-encoded-format",
|
|
||||||
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
||||||
refresh.ErrorInvalidToken,
|
|
||||||
"",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
token + "corrupted-token-payload",
|
|
||||||
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
||||||
refresh.ErrorInvalidToken,
|
|
||||||
"",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// The token's ID content is invalid.
|
|
||||||
tokenWithBadID,
|
|
||||||
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
||||||
refresh.ErrorInvalidToken,
|
|
||||||
"",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// The token's payload content is invalid.
|
|
||||||
tokenWithBadPayload,
|
|
||||||
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
||||||
refresh.ErrorInvalidToken,
|
|
||||||
"",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
token,
|
|
||||||
oidc.ClientCredentials{ID: "invalid-client", Secret: "secret-foo"},
|
|
||||||
refresh.ErrorInvalidClientID,
|
|
||||||
"",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
token,
|
|
||||||
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
||||||
nil,
|
|
||||||
"user-foo",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tt := range tests {
|
|
||||||
result, err := r.Verify(tt.creds.ID, tt.token)
|
|
||||||
if err != tt.err {
|
|
||||||
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
|
|
||||||
}
|
|
||||||
if result != tt.expected {
|
|
||||||
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.expected, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDBRefreshRepoRevoke(t *testing.T) {
|
|
||||||
r := db.NewRefreshTokenRepo(connect(t))
|
|
||||||
|
|
||||||
token, err := r.Create("user-foo", "client-foo")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
tokenWithBadID := "404" + token[1:]
|
|
||||||
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
token string
|
|
||||||
userID string
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"invalid-token-format",
|
|
||||||
"user-foo",
|
|
||||||
refresh.ErrorInvalidToken,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"1/invalid-base64-encoded-format",
|
|
||||||
"user-foo",
|
|
||||||
refresh.ErrorInvalidToken,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
token + "corrupted-token-payload",
|
|
||||||
"user-foo",
|
|
||||||
refresh.ErrorInvalidToken,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// The token's ID is invalid.
|
|
||||||
tokenWithBadID,
|
|
||||||
"user-foo",
|
|
||||||
refresh.ErrorInvalidToken,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// The token's payload is invalid.
|
|
||||||
tokenWithBadPayload,
|
|
||||||
"user-foo",
|
|
||||||
refresh.ErrorInvalidToken,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
token,
|
|
||||||
"invalid-user",
|
|
||||||
refresh.ErrorInvalidUserID,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
token,
|
|
||||||
"user-foo",
|
|
||||||
nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tt := range tests {
|
|
||||||
if err := r.Revoke(tt.userID, tt.token); err != tt.err {
|
|
||||||
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -2,13 +2,13 @@ package repo
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/oidc"
|
"github.com/coreos/go-oidc/oidc"
|
||||||
"github.com/go-gorp/gorp"
|
|
||||||
"github.com/kylelemons/godebug/pretty"
|
"github.com/kylelemons/godebug/pretty"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
@ -17,40 +17,43 @@ import (
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newRefreshRepo(t *testing.T, users []user.UserWithRemoteIdentities, clients []client.Client) refresh.RefreshTokenRepo {
|
var (
|
||||||
var dbMap *gorp.DbMap
|
testRefreshClientID = "client1"
|
||||||
if dsn := os.Getenv("DEX_TEST_DSN"); dsn == "" {
|
testRefreshClientID2 = "client2"
|
||||||
dbMap = db.NewMemDB()
|
testRefreshClients = []client.LoadableClient{
|
||||||
} else {
|
|
||||||
dbMap = connect(t)
|
|
||||||
}
|
|
||||||
if _, err := db.NewUserRepoFromUsers(dbMap, users); err != nil {
|
|
||||||
t.Fatalf("Unable to add users: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return db.NewRefreshTokenRepo(dbMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRefreshTokenRepo(t *testing.T) {
|
|
||||||
clientID := "client1"
|
|
||||||
userID := "user1"
|
|
||||||
clients := []client.Client{
|
|
||||||
{
|
{
|
||||||
Credentials: oidc.ClientCredentials{
|
Client: client.Client{
|
||||||
ID: clientID,
|
Credentials: oidc.ClientCredentials{
|
||||||
Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")),
|
ID: testRefreshClientID,
|
||||||
|
Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")),
|
||||||
|
},
|
||||||
|
Metadata: oidc.ClientMetadata{
|
||||||
|
RedirectURIs: []url.URL{
|
||||||
|
url.URL{Scheme: "https", Host: "client1.example.com", Path: "/callback"},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Metadata: oidc.ClientMetadata{
|
},
|
||||||
RedirectURIs: []url.URL{
|
{
|
||||||
url.URL{Scheme: "https", Host: "client1.example.com", Path: "/callback"},
|
Client: client.Client{
|
||||||
|
Credentials: oidc.ClientCredentials{
|
||||||
|
ID: testRefreshClientID2,
|
||||||
|
Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")),
|
||||||
|
},
|
||||||
|
Metadata: oidc.ClientMetadata{
|
||||||
|
RedirectURIs: []url.URL{
|
||||||
|
url.URL{Scheme: "https", Host: "client2.example.com", Path: "/callback"},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
users := []user.UserWithRemoteIdentities{
|
|
||||||
|
testRefreshUserID = "user1"
|
||||||
|
testRefreshUsers = []user.UserWithRemoteIdentities{
|
||||||
{
|
{
|
||||||
User: user.User{
|
User: user.User{
|
||||||
ID: userID,
|
ID: testRefreshUserID,
|
||||||
Email: "Email-1@example.com",
|
Email: "Email-1@example.com",
|
||||||
CreatedAt: time.Now().Truncate(time.Second),
|
CreatedAt: time.Now().Truncate(time.Second),
|
||||||
},
|
},
|
||||||
|
@ -62,31 +65,318 @@ func TestRefreshTokenRepo(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
|
||||||
repo := newRefreshRepo(t, users, clients)
|
func newRefreshRepo(t *testing.T, users []user.UserWithRemoteIdentities, clients []client.LoadableClient) refresh.RefreshTokenRepo {
|
||||||
tok, err := repo.Create(userID, clientID)
|
dbMap := connect(t)
|
||||||
if err != nil {
|
if _, err := db.NewUserRepoFromUsers(dbMap, users); err != nil {
|
||||||
t.Fatalf("failed to create refresh token: %v", err)
|
t.Fatalf("Unable to add users: %v", err)
|
||||||
}
|
|
||||||
if tokUserID, err := repo.Verify(clientID, tok); err != nil {
|
|
||||||
t.Errorf("Could not verify token: %v", err)
|
|
||||||
} else if tokUserID != userID {
|
|
||||||
t.Errorf("Verified token returned wrong user id, want=%s, got=%s", userID, tokUserID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if userClients, err := repo.ClientsWithRefreshTokens(userID); err != nil {
|
if _, err := db.NewClientRepoFromClients(dbMap, clients); err != nil {
|
||||||
t.Errorf("Failed to get the list of clients the user was logged into: %v", err)
|
t.Fatalf("Unable to add clients: %v", err)
|
||||||
} else {
|
}
|
||||||
if diff := pretty.Compare(userClients, clients); diff == "" {
|
|
||||||
t.Errorf("Clients user logged into: want did not equal got %s", diff)
|
return db.NewRefreshTokenRepo(dbMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshTokenRepoCreateVerify(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
createScopes []string
|
||||||
|
verifyClientID string
|
||||||
|
wantVerifyErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
createScopes: []string{"openid", "profile"},
|
||||||
|
verifyClientID: testRefreshClientID,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
createScopes: []string{},
|
||||||
|
verifyClientID: testRefreshClientID,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
createScopes: []string{"openid", "profile"},
|
||||||
|
verifyClientID: "not-a-client",
|
||||||
|
wantVerifyErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
|
||||||
|
tok, err := repo.Create(testRefreshUserID, testRefreshClientID, tt.createScopes)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("case %d: failed to create refresh token: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokUserID, gotScopes, err := repo.Verify(tt.verifyClientID, tok)
|
||||||
|
if tt.wantVerifyErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("case %d: want non-nil error.", i)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := pretty.Compare(tt.createScopes, gotScopes); diff != "" {
|
||||||
|
t.Errorf("case %d: Compare(want, got): %v", i, diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("case %d: Could not verify token: %v", i, err)
|
||||||
|
} else if tokUserID != testRefreshUserID {
|
||||||
|
t.Errorf("case %d: Verified token returned wrong user id, want=%s, got=%s", i,
|
||||||
|
testRefreshUserID, tokUserID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := repo.RevokeTokensForClient(userID, clientID); err != nil {
|
// buildRefreshToken combines the token ID and token payload to create a new token.
|
||||||
t.Errorf("Failed to revoke refresh token: %v", err)
|
// used in the tests to created a refresh token.
|
||||||
|
func buildRefreshToken(tokenID int64, tokenPayload []byte) string {
|
||||||
|
return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshRepoVerifyInvalidTokens(t *testing.T) {
|
||||||
|
r := db.NewRefreshTokenRepo(connect(t))
|
||||||
|
|
||||||
|
token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := repo.Verify(clientID, tok); err == nil {
|
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
|
||||||
t.Errorf("Token which should have been revoked was verified")
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
tokenWithBadID := "404" + token[1:]
|
||||||
|
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
token string
|
||||||
|
creds oidc.ClientCredentials
|
||||||
|
err error
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"invalid-token-format",
|
||||||
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
||||||
|
refresh.ErrorInvalidToken,
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"b/invalid-base64-encoded-format",
|
||||||
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
||||||
|
refresh.ErrorInvalidToken,
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"1/invalid-base64-encoded-format",
|
||||||
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
||||||
|
refresh.ErrorInvalidToken,
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
token + "corrupted-token-payload",
|
||||||
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
||||||
|
refresh.ErrorInvalidToken,
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// The token's ID content is invalid.
|
||||||
|
tokenWithBadID,
|
||||||
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
||||||
|
refresh.ErrorInvalidToken,
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// The token's payload content is invalid.
|
||||||
|
tokenWithBadPayload,
|
||||||
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
||||||
|
refresh.ErrorInvalidToken,
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
token,
|
||||||
|
oidc.ClientCredentials{ID: "invalid-client", Secret: "secret-foo"},
|
||||||
|
refresh.ErrorInvalidClientID,
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
token,
|
||||||
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
||||||
|
nil,
|
||||||
|
"user-foo",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
result, _, err := r.Verify(tt.creds.ID, tt.token)
|
||||||
|
if err != tt.err {
|
||||||
|
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
|
||||||
|
}
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshTokenRepoClientsWithRefreshTokens(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
clientIDs []string
|
||||||
|
}{
|
||||||
|
{clientIDs: []string{"client1", "client2"}},
|
||||||
|
{clientIDs: []string{"client1"}},
|
||||||
|
{clientIDs: []string{}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
|
||||||
|
|
||||||
|
for _, clientID := range tt.clientIDs {
|
||||||
|
_, err := repo.Create(testRefreshUserID, clientID, []string{"openid"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
clients, err := repo.ClientsWithRefreshTokens(testRefreshUserID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("case %d: unexpected error fetching clients %q", i, err)
|
||||||
|
}
|
||||||
|
var clientIDs []string
|
||||||
|
for _, client := range clients {
|
||||||
|
clientIDs = append(clientIDs, client.Credentials.ID)
|
||||||
|
}
|
||||||
|
sort.Strings(clientIDs)
|
||||||
|
|
||||||
|
if diff := pretty.Compare(clientIDs, tt.clientIDs); diff != "" {
|
||||||
|
t.Errorf("case %d: Compare(want, got): %v", i, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshTokenRepoRevokeForClient(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
createIDs []string
|
||||||
|
revokeID string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
createIDs: []string{"client1", "client2"},
|
||||||
|
revokeID: "client1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
createIDs: []string{"client2"},
|
||||||
|
revokeID: "client1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
createIDs: []string{"client1"},
|
||||||
|
revokeID: "client1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
createIDs: []string{},
|
||||||
|
revokeID: "oops",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
|
||||||
|
|
||||||
|
for _, clientID := range tt.createIDs {
|
||||||
|
_, err := repo.Create(testRefreshUserID, clientID, []string{"openid"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := repo.RevokeTokensForClient(testRefreshUserID, tt.revokeID); err != nil {
|
||||||
|
t.Fatalf("case %d: couldn't revoke refresh token(s): %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var wantIDs []string
|
||||||
|
for _, id := range tt.createIDs {
|
||||||
|
if id != tt.revokeID {
|
||||||
|
wantIDs = append(wantIDs, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
clients, err := repo.ClientsWithRefreshTokens(testRefreshUserID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("case %d: unexpected error fetching clients %q", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var gotIDs []string
|
||||||
|
for _, client := range clients {
|
||||||
|
gotIDs = append(gotIDs, client.Credentials.ID)
|
||||||
|
}
|
||||||
|
sort.Strings(gotIDs)
|
||||||
|
|
||||||
|
if diff := pretty.Compare(wantIDs, gotIDs); diff != "" {
|
||||||
|
t.Errorf("case %d: Compare(wantIDs, gotIDs): %v", i, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshRepoRevoke(t *testing.T) {
|
||||||
|
r := db.NewRefreshTokenRepo(connect(t))
|
||||||
|
|
||||||
|
token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
tokenWithBadID := "404" + token[1:]
|
||||||
|
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
token string
|
||||||
|
userID string
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"invalid-token-format",
|
||||||
|
"user-foo",
|
||||||
|
refresh.ErrorInvalidToken,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"1/invalid-base64-encoded-format",
|
||||||
|
"user-foo",
|
||||||
|
refresh.ErrorInvalidToken,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
token + "corrupted-token-payload",
|
||||||
|
"user-foo",
|
||||||
|
refresh.ErrorInvalidToken,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// The token's ID is invalid.
|
||||||
|
tokenWithBadID,
|
||||||
|
"user-foo",
|
||||||
|
refresh.ErrorInvalidToken,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// The token's payload is invalid.
|
||||||
|
tokenWithBadPayload,
|
||||||
|
"user-foo",
|
||||||
|
refresh.ErrorInvalidToken,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
token,
|
||||||
|
"invalid-user",
|
||||||
|
refresh.ErrorInvalidUserID,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
token,
|
||||||
|
"user-foo",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
if err := r.Revoke(tt.userID, tt.token); err != tt.err {
|
||||||
|
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,7 +12,8 @@ import (
|
||||||
func connect(t *testing.T) *gorp.DbMap {
|
func connect(t *testing.T) *gorp.DbMap {
|
||||||
dsn := os.Getenv("DEX_TEST_DSN")
|
dsn := os.Getenv("DEX_TEST_DSN")
|
||||||
if dsn == "" {
|
if dsn == "" {
|
||||||
t.Fatal("DEX_TEST_DSN environment variable not set")
|
return db.NewMemDB()
|
||||||
|
|
||||||
}
|
}
|
||||||
c, err := db.NewConnection(db.Config{DSN: dsn})
|
c, err := db.NewConnection(db.Config{DSN: dsn})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -231,7 +231,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
|
||||||
|
|
||||||
// this will actually happen due to some interaction between the
|
// this will actually happen due to some interaction between the
|
||||||
// end-user and a remote identity provider
|
// end-user and a remote identity provider
|
||||||
sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access"})
|
sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access", "email", "profile"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -148,7 +148,8 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
|
||||||
|
|
||||||
refreshRepo := db.NewRefreshTokenRepo(dbMap)
|
refreshRepo := db.NewRefreshTokenRepo(dbMap)
|
||||||
for _, user := range userUsers {
|
for _, user := range userUsers {
|
||||||
if _, err := refreshRepo.Create(user.User.ID, testClientID); err != nil {
|
if _, err := refreshRepo.Create(user.User.ID, testClientID,
|
||||||
|
append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil {
|
||||||
panic("Failed to create refresh token: " + err.Error())
|
panic("Failed to create refresh token: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
"github.com/coreos/dex/scope"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -40,12 +41,15 @@ func DefaultRefreshTokenGenerator() ([]byte, error) {
|
||||||
|
|
||||||
type RefreshTokenRepo interface {
|
type RefreshTokenRepo interface {
|
||||||
// Create generates and returns a new refresh token for the given client-user pair.
|
// Create generates and returns a new refresh token for the given client-user pair.
|
||||||
// On success the token will be return.
|
// The scopes will be stored with the refresh token, and used to verify
|
||||||
Create(userID, clientID string) (string, error)
|
// against future OIDC refresh requests' scopes.
|
||||||
|
// On success the token will be returned.
|
||||||
|
Create(userID, clientID string, scope []string) (string, error)
|
||||||
|
|
||||||
// Verify verifies that a token belongs to the client, and returns the corresponding user ID.
|
// Verify verifies that a token belongs to the client.
|
||||||
// Note that this assumes the client validation is currently done in the application layer,
|
// It returns the user ID to which the token belongs, and the scopes stored
|
||||||
Verify(clientID, token string) (string, error)
|
// with token.
|
||||||
|
Verify(clientID, token string) (string, scope.Scopes, error)
|
||||||
|
|
||||||
// Revoke deletes the refresh token if the token belongs to the given userID.
|
// Revoke deletes the refresh token if the token belongs to the given userID.
|
||||||
Revoke(userID, token string) error
|
Revoke(userID, token string) error
|
||||||
|
|
|
@ -32,3 +32,17 @@ func (s Scopes) CrossClientIDs() []string {
|
||||||
}
|
}
|
||||||
return clients
|
return clients
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s Scopes) Contains(other Scopes) bool {
|
||||||
|
rScopes := map[string]struct{}{}
|
||||||
|
for _, scope := range s {
|
||||||
|
rScopes[scope] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, scope := range other {
|
||||||
|
if _, ok := rScopes[scope]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
|
@ -14,29 +14,24 @@ import (
|
||||||
"github.com/kylelemons/godebug/pretty"
|
"github.com/kylelemons/godebug/pretty"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
clientmanager "github.com/coreos/dex/client/manager"
|
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/scope"
|
"github.com/coreos/dex/scope"
|
||||||
)
|
)
|
||||||
|
|
||||||
func makeCrossClientTestFixtures() (*testFixtures, error) {
|
func makeCrossClientTestFixtures() (*testFixtures, error) {
|
||||||
f, err := makeTestFixtures()
|
xClients := []client.LoadableClient{}
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("couldn't make test fixtures: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, cliData := range []struct {
|
for _, cliData := range []struct {
|
||||||
id string
|
id string
|
||||||
authorized []string
|
trustedPeers []string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
id: "client_a",
|
id: "client_a",
|
||||||
}, {
|
}, {
|
||||||
id: "client_b",
|
id: "client_b",
|
||||||
authorized: []string{"client_a"},
|
trustedPeers: []string{"client_a"},
|
||||||
}, {
|
}, {
|
||||||
id: "client_c",
|
id: "client_c",
|
||||||
authorized: []string{"client_a", "client_b"},
|
trustedPeers: []string{"client_a", "client_b"},
|
||||||
},
|
},
|
||||||
} {
|
} {
|
||||||
u := url.URL{
|
u := url.URL{
|
||||||
|
@ -44,20 +39,27 @@ func makeCrossClientTestFixtures() (*testFixtures, error) {
|
||||||
Path: cliData.id,
|
Path: cliData.id,
|
||||||
Host: cliData.id,
|
Host: cliData.id,
|
||||||
}
|
}
|
||||||
cliCreds, err := f.clientManager.New(client.Client{
|
xClients = append(xClients, client.LoadableClient{
|
||||||
Credentials: oidc.ClientCredentials{
|
Client: client.Client{
|
||||||
ID: cliData.id,
|
Credentials: oidc.ClientCredentials{
|
||||||
|
ID: cliData.id,
|
||||||
|
Secret: base64.URLEncoding.EncodeToString(
|
||||||
|
[]byte(cliData.id + "_secret")),
|
||||||
|
},
|
||||||
|
Metadata: oidc.ClientMetadata{
|
||||||
|
RedirectURIs: []url.URL{u},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Metadata: oidc.ClientMetadata{
|
TrustedPeers: cliData.trustedPeers,
|
||||||
RedirectURIs: []url.URL{u},
|
|
||||||
},
|
|
||||||
}, &clientmanager.ClientOptions{
|
|
||||||
TrustedPeers: cliData.authorized,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
}
|
||||||
return nil, fmt.Errorf("Unexpected error creating clients: %v", err)
|
|
||||||
}
|
xClients = append(xClients, testClients...)
|
||||||
f.clientCreds[cliData.id] = *cliCreds
|
f, err := makeTestFixturesWithOptions(testFixtureOptions{
|
||||||
|
clients: xClients,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("couldn't make test fixtures: %v", err)
|
||||||
}
|
}
|
||||||
return f, nil
|
return f, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -518,11 +518,12 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc {
|
||||||
}
|
}
|
||||||
case oauth2.GrantTypeRefreshToken:
|
case oauth2.GrantTypeRefreshToken:
|
||||||
token := r.PostForm.Get("refresh_token")
|
token := r.PostForm.Get("refresh_token")
|
||||||
|
scopes := r.PostForm.Get("scope")
|
||||||
if token == "" {
|
if token == "" {
|
||||||
writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
|
writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
jwt, err = srv.RefreshToken(creds, token)
|
jwt, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeTokenError(w, err, state)
|
writeTokenError(w, err, state)
|
||||||
return
|
return
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
"github.com/coreos/dex/refresh"
|
"github.com/coreos/dex/refresh"
|
||||||
|
"github.com/coreos/dex/scope"
|
||||||
"github.com/coreos/dex/session"
|
"github.com/coreos/dex/session"
|
||||||
sessionmanager "github.com/coreos/dex/session/manager"
|
sessionmanager "github.com/coreos/dex/session/manager"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
|
@ -53,7 +54,7 @@ type OIDCServer interface {
|
||||||
|
|
||||||
// RefreshToken takes a previously generated refresh token and returns a new ID token
|
// RefreshToken takes a previously generated refresh token and returns a new ID token
|
||||||
// if the token is valid.
|
// if the token is valid.
|
||||||
RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error)
|
RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error)
|
||||||
|
|
||||||
KillSession(string) error
|
KillSession(string) error
|
||||||
|
|
||||||
|
@ -444,35 +445,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
|
||||||
claims := ses.Claims(s.IssuerURL.String())
|
claims := ses.Claims(s.IssuerURL.String())
|
||||||
user.AddToClaims(claims)
|
user.AddToClaims(claims)
|
||||||
|
|
||||||
crossClientIDs := ses.Scope.CrossClientIDs()
|
s.addClaimsFromScope(claims, ses.Scope, ses.ClientID)
|
||||||
if len(crossClientIDs) > 0 {
|
|
||||||
var aud []string
|
|
||||||
for _, id := range crossClientIDs {
|
|
||||||
if ses.ClientID == id {
|
|
||||||
aud = append(aud, id)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
allowed, err := s.CrossClientAuthAllowed(ses.ClientID, id)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to check cross client auth. reqClientID %v; authClient:ID %v; err: %v", ses.ClientID, id, err)
|
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
|
||||||
}
|
|
||||||
if !allowed {
|
|
||||||
err := oauth2.NewError(oauth2.ErrorInvalidRequest)
|
|
||||||
err.Description = fmt.Sprintf(
|
|
||||||
"%q is not authorized to perform cross-client requests for %q",
|
|
||||||
ses.ClientID, id)
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
aud = append(aud, id)
|
|
||||||
}
|
|
||||||
if len(aud) == 1 {
|
|
||||||
claims.Add("aud", aud[0])
|
|
||||||
} else {
|
|
||||||
claims.Add("aud", aud)
|
|
||||||
}
|
|
||||||
claims.Add("azp", ses.ClientID)
|
|
||||||
}
|
|
||||||
|
|
||||||
jwt, err := jose.NewSignedJWT(claims, signer)
|
jwt, err := jose.NewSignedJWT(claims, signer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -487,7 +460,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
|
||||||
if scope == "offline_access" {
|
if scope == "offline_access" {
|
||||||
log.Infof("Session %s requests offline access, will generate refresh token", sessionID)
|
log.Infof("Session %s requests offline access, will generate refresh token", sessionID)
|
||||||
|
|
||||||
refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID)
|
refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.Scope)
|
||||||
switch err {
|
switch err {
|
||||||
case nil:
|
case nil:
|
||||||
break
|
break
|
||||||
|
@ -503,7 +476,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
|
||||||
return jwt, refreshToken, nil
|
return jwt, refreshToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) {
|
func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error) {
|
||||||
ok, err := s.ClientManager.Authenticate(creds)
|
ok, err := s.ClientManager.Authenticate(creds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
|
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
|
||||||
|
@ -514,7 +487,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose
|
||||||
return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
|
return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, err := s.RefreshTokenRepo.Verify(creds.ID, token)
|
userID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
|
||||||
switch err {
|
switch err {
|
||||||
case nil:
|
case nil:
|
||||||
break
|
break
|
||||||
|
@ -526,6 +499,14 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose
|
||||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(scopes) == 0 {
|
||||||
|
scopes = rtScopes
|
||||||
|
} else {
|
||||||
|
if !rtScopes.Contains(scopes) {
|
||||||
|
return nil, oauth2.NewError(oauth2.ErrorInvalidRequest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
user, err := s.UserRepo.Get(nil, userID)
|
user, err := s.UserRepo.Get(nil, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// The error can be user.ErrorNotFound, but we are not deleting
|
// The error can be user.ErrorNotFound, but we are not deleting
|
||||||
|
@ -546,6 +527,8 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose
|
||||||
claims := oidc.NewClaims(s.IssuerURL.String(), user.ID, creds.ID, now, expireAt)
|
claims := oidc.NewClaims(s.IssuerURL.String(), user.ID, creds.ID, now, expireAt)
|
||||||
user.AddToClaims(claims)
|
user.AddToClaims(claims)
|
||||||
|
|
||||||
|
s.addClaimsFromScope(claims, scope.Scopes(scopes), creds.ID)
|
||||||
|
|
||||||
jwt, err := jose.NewSignedJWT(claims, signer)
|
jwt, err := jose.NewSignedJWT(claims, signer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to generate ID token: %v", err)
|
log.Errorf("Failed to generate ID token: %v", err)
|
||||||
|
@ -587,6 +570,41 @@ func (s *Server) JWTVerifierFactory() JWTVerifierFactory {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addClaimsFromScope adds claims that are based on the scopes that the client requested.
|
||||||
|
// Currently, these include cross-client claims (aud, azp).
|
||||||
|
func (s *Server) addClaimsFromScope(claims jose.Claims, scopes scope.Scopes, clientID string) error {
|
||||||
|
crossClientIDs := scopes.CrossClientIDs()
|
||||||
|
if len(crossClientIDs) > 0 {
|
||||||
|
var aud []string
|
||||||
|
for _, id := range crossClientIDs {
|
||||||
|
if clientID == id {
|
||||||
|
aud = append(aud, id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
allowed, err := s.CrossClientAuthAllowed(clientID, id)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to check cross client auth. reqClientID %v; authClient:ID %v; err: %v", clientID, id, err)
|
||||||
|
return oauth2.NewError(oauth2.ErrorServerError)
|
||||||
|
}
|
||||||
|
if !allowed {
|
||||||
|
err := oauth2.NewError(oauth2.ErrorInvalidRequest)
|
||||||
|
err.Description = fmt.Sprintf(
|
||||||
|
"%q is not authorized to perform cross-client requests for %q",
|
||||||
|
clientID, id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
aud = append(aud, id)
|
||||||
|
}
|
||||||
|
if len(aud) == 1 {
|
||||||
|
claims.Add("aud", aud[0])
|
||||||
|
} else {
|
||||||
|
claims.Add("aud", aud)
|
||||||
|
}
|
||||||
|
claims.Add("azp", clientID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type sortableIDPCs []connector.Connector
|
type sortableIDPCs []connector.Connector
|
||||||
|
|
||||||
func (s sortableIDPCs) Len() int {
|
func (s sortableIDPCs) Len() int {
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/refresh/refreshtest"
|
"github.com/coreos/dex/refresh/refreshtest"
|
||||||
|
"github.com/coreos/dex/scope"
|
||||||
"github.com/coreos/dex/session/manager"
|
"github.com/coreos/dex/session/manager"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
)
|
)
|
||||||
|
@ -484,91 +485,197 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
// NOTE(ericchiang): These tests assume that the database ID of the first
|
// NOTE(ericchiang): These tests assume that the database ID of the first
|
||||||
// refresh token will be "1".
|
// refresh token will be "1".
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
token string
|
token string
|
||||||
clientID string // The client that associates with the token.
|
clientID string // The client that associates with the token.
|
||||||
creds oidc.ClientCredentials
|
creds oidc.ClientCredentials
|
||||||
signer jose.Signer
|
signer jose.Signer
|
||||||
err error
|
createScopes []string
|
||||||
|
refreshScopes []string
|
||||||
|
expectedAud []string
|
||||||
|
err error
|
||||||
}{
|
}{
|
||||||
// Everything is good.
|
// Everything is good.
|
||||||
{
|
{
|
||||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
testClientID,
|
clientID: testClientID,
|
||||||
testClientCredentials,
|
creds: testClientCredentials,
|
||||||
signerFixture,
|
signer: signerFixture,
|
||||||
nil,
|
createScopes: []string{"openid", "profile"},
|
||||||
|
refreshScopes: []string{"openid", "profile"},
|
||||||
|
},
|
||||||
|
// Asking for a scope not originally granted to you.
|
||||||
|
{
|
||||||
|
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
|
clientID: testClientID,
|
||||||
|
creds: testClientCredentials,
|
||||||
|
signer: signerFixture,
|
||||||
|
createScopes: []string{"openid", "profile"},
|
||||||
|
refreshScopes: []string{"openid", "profile", "extra_scope"},
|
||||||
|
err: oauth2.NewError(oauth2.ErrorInvalidRequest),
|
||||||
},
|
},
|
||||||
// Invalid refresh token(malformatted).
|
// Invalid refresh token(malformatted).
|
||||||
{
|
{
|
||||||
"invalid-token",
|
token: "invalid-token",
|
||||||
testClientID,
|
clientID: testClientID,
|
||||||
testClientCredentials,
|
creds: testClientCredentials,
|
||||||
signerFixture,
|
signer: signerFixture,
|
||||||
oauth2.NewError(oauth2.ErrorInvalidRequest),
|
createScopes: []string{"openid", "profile"},
|
||||||
|
refreshScopes: []string{"openid", "profile"},
|
||||||
|
err: oauth2.NewError(oauth2.ErrorInvalidRequest),
|
||||||
},
|
},
|
||||||
// Invalid refresh token(invalid payload content).
|
// Invalid refresh token(invalid payload content).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
|
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
|
||||||
testClientID,
|
clientID: testClientID,
|
||||||
testClientCredentials,
|
creds: testClientCredentials,
|
||||||
signerFixture,
|
signer: signerFixture,
|
||||||
oauth2.NewError(oauth2.ErrorInvalidRequest),
|
createScopes: []string{"openid", "profile"},
|
||||||
|
refreshScopes: []string{"openid", "profile"},
|
||||||
|
err: oauth2.NewError(oauth2.ErrorInvalidRequest),
|
||||||
},
|
},
|
||||||
// Invalid refresh token(invalid ID content).
|
// Invalid refresh token(invalid ID content).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
testClientID,
|
clientID: testClientID,
|
||||||
testClientCredentials,
|
creds: testClientCredentials,
|
||||||
signerFixture,
|
signer: signerFixture,
|
||||||
oauth2.NewError(oauth2.ErrorInvalidRequest),
|
createScopes: []string{"openid", "profile"},
|
||||||
|
refreshScopes: []string{"openid", "profile"},
|
||||||
|
err: oauth2.NewError(oauth2.ErrorInvalidRequest),
|
||||||
},
|
},
|
||||||
// Invalid client(client is not associated with the token).
|
// Invalid client(client is not associated with the token).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
testClientID,
|
clientID: testClientID,
|
||||||
clientB.Credentials,
|
creds: clientB.Credentials,
|
||||||
signerFixture,
|
signer: signerFixture,
|
||||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
createScopes: []string{"openid", "profile"},
|
||||||
|
refreshScopes: []string{"openid", "profile"},
|
||||||
|
err: oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||||
},
|
},
|
||||||
// Invalid client(no client ID).
|
// Invalid client(no client ID).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
testClientID,
|
clientID: testClientID,
|
||||||
oidc.ClientCredentials{ID: "", Secret: "aaa"},
|
creds: oidc.ClientCredentials{ID: "", Secret: "aaa"},
|
||||||
signerFixture,
|
signer: signerFixture,
|
||||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
createScopes: []string{"openid", "profile"},
|
||||||
|
refreshScopes: []string{"openid", "profile"},
|
||||||
|
err: oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||||
},
|
},
|
||||||
// Invalid client(no such client).
|
// Invalid client(no such client).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
testClientID,
|
clientID: testClientID,
|
||||||
oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
|
creds: oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
|
||||||
signerFixture,
|
signer: signerFixture,
|
||||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
createScopes: []string{"openid", "profile"},
|
||||||
|
refreshScopes: []string{"openid", "profile"},
|
||||||
|
err: oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||||
},
|
},
|
||||||
// Invalid client(no secrets).
|
// Invalid client(no secrets).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
testClientID,
|
clientID: testClientID,
|
||||||
oidc.ClientCredentials{ID: testClientID},
|
creds: oidc.ClientCredentials{ID: testClientID},
|
||||||
signerFixture,
|
signer: signerFixture,
|
||||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
createScopes: []string{"openid", "profile"},
|
||||||
|
refreshScopes: []string{"openid", "profile"},
|
||||||
|
err: oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||||
},
|
},
|
||||||
// Invalid client(invalid secret).
|
// Invalid client(invalid secret).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
testClientID,
|
clientID: testClientID,
|
||||||
oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"},
|
creds: oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"},
|
||||||
signerFixture,
|
signer: signerFixture,
|
||||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
createScopes: []string{"openid", "profile"},
|
||||||
|
refreshScopes: []string{"openid", "profile"},
|
||||||
|
err: oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||||
},
|
},
|
||||||
// Signing operation fails.
|
// Signing operation fails.
|
||||||
{
|
{
|
||||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
testClientID,
|
clientID: testClientID,
|
||||||
testClientCredentials,
|
creds: testClientCredentials,
|
||||||
&StaticSigner{sig: nil, err: errors.New("fail")},
|
signer: &StaticSigner{sig: nil, err: errors.New("fail")},
|
||||||
oauth2.NewError(oauth2.ErrorServerError),
|
createScopes: []string{"openid", "profile"},
|
||||||
|
refreshScopes: []string{"openid", "profile"},
|
||||||
|
err: oauth2.NewError(oauth2.ErrorServerError),
|
||||||
|
},
|
||||||
|
// Valid Cross-Client
|
||||||
|
{
|
||||||
|
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
|
clientID: "client_a",
|
||||||
|
creds: oidc.ClientCredentials{
|
||||||
|
ID: "client_a",
|
||||||
|
Secret: base64.URLEncoding.EncodeToString(
|
||||||
|
[]byte("client_a_secret")),
|
||||||
|
},
|
||||||
|
signer: signerFixture,
|
||||||
|
createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"},
|
||||||
|
refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"},
|
||||||
|
expectedAud: []string{"client_b"},
|
||||||
|
},
|
||||||
|
// Valid Cross-Client - but this time we leave out the scopes in the
|
||||||
|
// refresh request, which should result in the original stored scopes
|
||||||
|
// being used.
|
||||||
|
{
|
||||||
|
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
|
clientID: "client_a",
|
||||||
|
creds: oidc.ClientCredentials{
|
||||||
|
ID: "client_a",
|
||||||
|
Secret: base64.URLEncoding.EncodeToString(
|
||||||
|
[]byte("client_a_secret")),
|
||||||
|
},
|
||||||
|
signer: signerFixture,
|
||||||
|
createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"},
|
||||||
|
refreshScopes: []string{},
|
||||||
|
expectedAud: []string{"client_b"},
|
||||||
|
},
|
||||||
|
// Valid Cross-Client - asking for fewer scopes than originally used
|
||||||
|
// when creating the refresh token, which is ok.
|
||||||
|
{
|
||||||
|
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
|
clientID: "client_a",
|
||||||
|
creds: oidc.ClientCredentials{
|
||||||
|
ID: "client_a",
|
||||||
|
Secret: base64.URLEncoding.EncodeToString(
|
||||||
|
[]byte("client_a_secret")),
|
||||||
|
},
|
||||||
|
signer: signerFixture,
|
||||||
|
createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b", scope.ScopeGoogleCrossClient + "client_c"},
|
||||||
|
refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"},
|
||||||
|
expectedAud: []string{"client_b"},
|
||||||
|
},
|
||||||
|
// Valid Cross-Client - asking for multiple clients in the audience.
|
||||||
|
{
|
||||||
|
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
|
clientID: "client_a",
|
||||||
|
creds: oidc.ClientCredentials{
|
||||||
|
ID: "client_a",
|
||||||
|
Secret: base64.URLEncoding.EncodeToString(
|
||||||
|
[]byte("client_a_secret")),
|
||||||
|
},
|
||||||
|
signer: signerFixture,
|
||||||
|
createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b", scope.ScopeGoogleCrossClient + "client_c"},
|
||||||
|
refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b", scope.ScopeGoogleCrossClient + "client_c"},
|
||||||
|
expectedAud: []string{"client_b", "client_c"},
|
||||||
|
},
|
||||||
|
// Invalid Cross-Client - didn't orignally request cross-client when
|
||||||
|
// refresh token was created.
|
||||||
|
{
|
||||||
|
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
|
clientID: "client_a",
|
||||||
|
creds: oidc.ClientCredentials{
|
||||||
|
ID: "client_a",
|
||||||
|
Secret: base64.URLEncoding.EncodeToString(
|
||||||
|
[]byte("client_a_secret")),
|
||||||
|
},
|
||||||
|
signer: signerFixture,
|
||||||
|
createScopes: []string{"openid", "profile"},
|
||||||
|
refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"},
|
||||||
|
err: oauth2.NewError(oauth2.ErrorInvalidRequest),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -576,7 +683,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
km := &StaticKeyManager{
|
km := &StaticKeyManager{
|
||||||
signer: tt.signer,
|
signer: tt.signer,
|
||||||
}
|
}
|
||||||
f, err := makeTestFixtures()
|
f, err := makeCrossClientTestFixtures()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error making test fixtures: %v", err)
|
t.Fatalf("error making test fixtures: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -587,11 +694,12 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
t.Errorf("case %d: error creating other client: %v", i, err)
|
t.Errorf("case %d: error creating other client: %v", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID); err != nil {
|
if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID,
|
||||||
|
tt.createScopes); err != nil {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwt, err := f.srv.RefreshToken(tt.creds, tt.token)
|
jwt, err := f.srv.RefreshToken(tt.creds, tt.refreshScopes, tt.token)
|
||||||
if !reflect.DeepEqual(err, tt.err) {
|
if !reflect.DeepEqual(err, tt.err) {
|
||||||
t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err)
|
t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err)
|
||||||
}
|
}
|
||||||
|
@ -604,8 +712,27 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Case %d: unexpected error: %v", i, err)
|
t.Errorf("Case %d: unexpected error: %v", i, err)
|
||||||
}
|
}
|
||||||
if claims["iss"] != testIssuerURL.String() || claims["sub"] != testUserID1 || claims["aud"] != testClientID {
|
|
||||||
t.Errorf("Case %d: invalid claims: %v", i, claims)
|
var expectedAud interface{}
|
||||||
|
if tt.expectedAud == nil {
|
||||||
|
expectedAud = testClientID
|
||||||
|
} else if len(tt.expectedAud) == 1 {
|
||||||
|
expectedAud = tt.expectedAud[0]
|
||||||
|
} else {
|
||||||
|
expectedAud = tt.expectedAud
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims["iss"] != testIssuerURL.String() {
|
||||||
|
t.Errorf("Case %d: want=%v, got=%v", i,
|
||||||
|
testIssuerURL.String(), claims["iss"])
|
||||||
|
}
|
||||||
|
if claims["sub"] != testUserID1 {
|
||||||
|
t.Errorf("Case %d: want=%v, got=%v", i,
|
||||||
|
testUserID1, claims["sub"])
|
||||||
|
}
|
||||||
|
if diff := pretty.Compare(claims["aud"], expectedAud); diff != "" {
|
||||||
|
t.Errorf("Case %d: want=%v, got=%v", i,
|
||||||
|
expectedAud, claims["aud"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,6 +39,18 @@ var (
|
||||||
ID: testClientID,
|
ID: testClientID,
|
||||||
Secret: clientTestSecret,
|
Secret: clientTestSecret,
|
||||||
}
|
}
|
||||||
|
testClients = []client.LoadableClient{
|
||||||
|
{
|
||||||
|
Client: client.Client{
|
||||||
|
Credentials: testClientCredentials,
|
||||||
|
Metadata: oidc.ClientMetadata{
|
||||||
|
RedirectURIs: []url.URL{
|
||||||
|
testRedirectURL,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
testConnectorID1 = "IDPC-1"
|
testConnectorID1 = "IDPC-1"
|
||||||
|
|
||||||
|
@ -169,18 +181,7 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err
|
||||||
|
|
||||||
var clients []client.LoadableClient
|
var clients []client.LoadableClient
|
||||||
if options.clients == nil {
|
if options.clients == nil {
|
||||||
clients = []client.LoadableClient{
|
clients = testClients
|
||||||
{
|
|
||||||
Client: client.Client{
|
|
||||||
Credentials: testClientCredentials,
|
|
||||||
Metadata: oidc.ClientMetadata{
|
|
||||||
RedirectURIs: []url.URL{
|
|
||||||
testRedirectURL,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
clients = options.clients
|
clients = options.clients
|
||||||
}
|
}
|
||||||
|
@ -247,6 +248,10 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err
|
||||||
srv.absURL(httpPathAcceptInvitation),
|
srv.absURL(httpPathAcceptInvitation),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
clientCreds := map[string]oidc.ClientCredentials{}
|
||||||
|
for _, c := range clients {
|
||||||
|
clientCreds[c.Client.Credentials.ID] = c.Client.Credentials
|
||||||
|
}
|
||||||
return &testFixtures{
|
return &testFixtures{
|
||||||
srv: srv,
|
srv: srv,
|
||||||
redirectURL: testRedirectURL,
|
redirectURL: testRedirectURL,
|
||||||
|
@ -255,9 +260,7 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err
|
||||||
emailer: emailer,
|
emailer: emailer,
|
||||||
clientRepo: clientRepo,
|
clientRepo: clientRepo,
|
||||||
clientManager: clientManager,
|
clientManager: clientManager,
|
||||||
clientCreds: map[string]oidc.ClientCredentials{
|
clientCreds: clientCreds,
|
||||||
testClientID: testClientCreds,
|
|
||||||
},
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -192,7 +192,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
||||||
}
|
}
|
||||||
refreshRepo := db.NewRefreshTokenRepo(dbMap)
|
refreshRepo := db.NewRefreshTokenRepo(dbMap)
|
||||||
for _, token := range refreshTokens {
|
for _, token := range refreshTokens {
|
||||||
if _, err := refreshRepo.Create(token.userID, token.clientID); err != nil {
|
if _, err := refreshRepo.Create(token.userID, token.clientID, []string{"openid"}); err != nil {
|
||||||
panic("Failed to create refresh token: " + err.Error())
|
panic("Failed to create refresh token: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Reference in a new issue