389 lines
9.4 KiB
Go
389 lines
9.4 KiB
Go
package repo
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"fmt"
|
|
"net/url"
|
|
"sort"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/coreos/go-oidc/oidc"
|
|
"github.com/kylelemons/godebug/pretty"
|
|
|
|
"github.com/coreos/dex/client"
|
|
"github.com/coreos/dex/db"
|
|
"github.com/coreos/dex/refresh"
|
|
"github.com/coreos/dex/user"
|
|
)
|
|
|
|
var (
|
|
testRefreshClientID = "client1"
|
|
testRefreshClientID2 = "client2"
|
|
|
|
testRefreshConnectorID = "IDPC-1"
|
|
|
|
testRefreshClients = []client.LoadableClient{
|
|
{
|
|
Client: client.Client{
|
|
Credentials: oidc.ClientCredentials{
|
|
ID: testRefreshClientID,
|
|
Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")),
|
|
},
|
|
Metadata: oidc.ClientMetadata{
|
|
RedirectURIs: []url.URL{
|
|
url.URL{Scheme: "https", Host: "client1.example.com", Path: "/callback"},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Client: client.Client{
|
|
Credentials: oidc.ClientCredentials{
|
|
ID: testRefreshClientID2,
|
|
Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")),
|
|
},
|
|
Metadata: oidc.ClientMetadata{
|
|
RedirectURIs: []url.URL{
|
|
url.URL{Scheme: "https", Host: "client2.example.com", Path: "/callback"},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
testRefreshUserID = "user1"
|
|
testRefreshUsers = []user.UserWithRemoteIdentities{
|
|
{
|
|
User: user.User{
|
|
ID: testRefreshUserID,
|
|
Email: "Email-1@example.com",
|
|
CreatedAt: time.Now().Truncate(time.Second),
|
|
},
|
|
RemoteIdentities: []user.RemoteIdentity{
|
|
{
|
|
ConnectorID: testRefreshConnectorID,
|
|
ID: "RID-1",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
)
|
|
|
|
func newRefreshRepo(t *testing.T, users []user.UserWithRemoteIdentities, clients []client.LoadableClient) refresh.RefreshTokenRepo {
|
|
dbMap := connect(t)
|
|
if _, err := db.NewUserRepoFromUsers(dbMap, users); err != nil {
|
|
t.Fatalf("Unable to add users: %v", err)
|
|
}
|
|
|
|
if _, err := db.NewClientRepoFromClients(dbMap, clients); err != nil {
|
|
t.Fatalf("Unable to add clients: %v", err)
|
|
}
|
|
|
|
return db.NewRefreshTokenRepo(dbMap)
|
|
}
|
|
|
|
func TestRefreshTokenRepoCreateVerify(t *testing.T) {
|
|
tests := []struct {
|
|
createScopes []string
|
|
verifyClientID string
|
|
wantVerifyErr bool
|
|
}{
|
|
{
|
|
createScopes: []string{"openid", "profile"},
|
|
verifyClientID: testRefreshClientID,
|
|
},
|
|
{
|
|
createScopes: []string{},
|
|
verifyClientID: testRefreshClientID,
|
|
},
|
|
{
|
|
createScopes: []string{"openid", "profile"},
|
|
verifyClientID: "not-a-client",
|
|
wantVerifyErr: true,
|
|
},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
|
|
tok, err := repo.Create(testRefreshUserID, testRefreshClientID, testRefreshConnectorID, tt.createScopes)
|
|
if err != nil {
|
|
t.Fatalf("case %d: failed to create refresh token: %v", i, err)
|
|
}
|
|
|
|
tokUserID, gotConnectorID, gotScopes, err := repo.Verify(tt.verifyClientID, tok)
|
|
if tt.wantVerifyErr {
|
|
if err == nil {
|
|
t.Errorf("case %d: want non-nil error.", i)
|
|
}
|
|
continue
|
|
}
|
|
|
|
if diff := pretty.Compare(tt.createScopes, gotScopes); diff != "" {
|
|
t.Errorf("case %d: Compare(want, got): %v", i, diff)
|
|
}
|
|
|
|
if err != nil {
|
|
t.Errorf("case %d: Could not verify token: %v", i, err)
|
|
} else if tokUserID != testRefreshUserID {
|
|
t.Errorf("case %d: Verified token returned wrong user id, want=%s, got=%s", i,
|
|
testRefreshUserID, tokUserID)
|
|
}
|
|
|
|
if gotConnectorID != testRefreshConnectorID {
|
|
t.Errorf("case %d: wanted connector_id=%q got=%q", i, testRefreshConnectorID, gotConnectorID)
|
|
}
|
|
}
|
|
}
|
|
|
|
// buildRefreshToken combines the token ID and token payload to create a new token.
|
|
// used in the tests to created a refresh token.
|
|
func buildRefreshToken(tokenID int64, tokenPayload []byte) string {
|
|
return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
|
|
}
|
|
|
|
func TestRefreshRepoVerifyInvalidTokens(t *testing.T) {
|
|
r := db.NewRefreshTokenRepo(connect(t))
|
|
|
|
token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
tokenWithBadID := "404" + token[1:]
|
|
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
|
|
|
|
tests := []struct {
|
|
token string
|
|
creds oidc.ClientCredentials
|
|
err error
|
|
expected string
|
|
}{
|
|
{
|
|
"invalid-token-format",
|
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
refresh.ErrorInvalidToken,
|
|
"",
|
|
},
|
|
{
|
|
"b/invalid-base64-encoded-format",
|
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
refresh.ErrorInvalidToken,
|
|
"",
|
|
},
|
|
{
|
|
"1/invalid-base64-encoded-format",
|
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
refresh.ErrorInvalidToken,
|
|
"",
|
|
},
|
|
{
|
|
token + "corrupted-token-payload",
|
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
refresh.ErrorInvalidToken,
|
|
"",
|
|
},
|
|
{
|
|
// The token's ID content is invalid.
|
|
tokenWithBadID,
|
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
refresh.ErrorInvalidToken,
|
|
"",
|
|
},
|
|
{
|
|
// The token's payload content is invalid.
|
|
tokenWithBadPayload,
|
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
refresh.ErrorInvalidToken,
|
|
"",
|
|
},
|
|
{
|
|
token,
|
|
oidc.ClientCredentials{ID: "invalid-client", Secret: "secret-foo"},
|
|
refresh.ErrorInvalidClientID,
|
|
"",
|
|
},
|
|
{
|
|
token,
|
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
nil,
|
|
"user-foo",
|
|
},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
result, _, _, err := r.Verify(tt.creds.ID, tt.token)
|
|
if err != tt.err {
|
|
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
|
|
}
|
|
if result != tt.expected {
|
|
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.expected, result)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRefreshTokenRepoClientsWithRefreshTokens(t *testing.T) {
|
|
tests := []struct {
|
|
clientIDs []string
|
|
}{
|
|
{clientIDs: []string{"client1", "client2"}},
|
|
{clientIDs: []string{"client1"}},
|
|
{clientIDs: []string{}},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
|
|
|
|
for _, clientID := range tt.clientIDs {
|
|
_, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"})
|
|
if err != nil {
|
|
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
|
|
}
|
|
}
|
|
|
|
clients, err := repo.ClientsWithRefreshTokens(testRefreshUserID)
|
|
if err != nil {
|
|
t.Fatalf("case %d: unexpected error fetching clients %q", i, err)
|
|
}
|
|
var clientIDs []string
|
|
for _, client := range clients {
|
|
clientIDs = append(clientIDs, client.Credentials.ID)
|
|
}
|
|
sort.Strings(clientIDs)
|
|
|
|
if diff := pretty.Compare(clientIDs, tt.clientIDs); diff != "" {
|
|
t.Errorf("case %d: Compare(want, got): %v", i, diff)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRefreshTokenRepoRevokeForClient(t *testing.T) {
|
|
tests := []struct {
|
|
createIDs []string
|
|
revokeID string
|
|
}{
|
|
{
|
|
createIDs: []string{"client1", "client2"},
|
|
revokeID: "client1",
|
|
},
|
|
{
|
|
createIDs: []string{"client2"},
|
|
revokeID: "client1",
|
|
},
|
|
{
|
|
createIDs: []string{"client1"},
|
|
revokeID: "client1",
|
|
},
|
|
{
|
|
createIDs: []string{},
|
|
revokeID: "oops",
|
|
},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
|
|
|
|
for _, clientID := range tt.createIDs {
|
|
_, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"})
|
|
if err != nil {
|
|
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
|
|
}
|
|
|
|
if err := repo.RevokeTokensForClient(testRefreshUserID, tt.revokeID); err != nil {
|
|
t.Fatalf("case %d: couldn't revoke refresh token(s): %v", i, err)
|
|
}
|
|
}
|
|
|
|
var wantIDs []string
|
|
for _, id := range tt.createIDs {
|
|
if id != tt.revokeID {
|
|
wantIDs = append(wantIDs, id)
|
|
}
|
|
}
|
|
|
|
clients, err := repo.ClientsWithRefreshTokens(testRefreshUserID)
|
|
if err != nil {
|
|
t.Fatalf("case %d: unexpected error fetching clients %q", i, err)
|
|
}
|
|
|
|
var gotIDs []string
|
|
for _, client := range clients {
|
|
gotIDs = append(gotIDs, client.Credentials.ID)
|
|
}
|
|
sort.Strings(gotIDs)
|
|
|
|
if diff := pretty.Compare(wantIDs, gotIDs); diff != "" {
|
|
t.Errorf("case %d: Compare(wantIDs, gotIDs): %v", i, diff)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRefreshRepoRevoke(t *testing.T) {
|
|
r := db.NewRefreshTokenRepo(connect(t))
|
|
|
|
token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
tokenWithBadID := "404" + token[1:]
|
|
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
|
|
|
|
tests := []struct {
|
|
token string
|
|
userID string
|
|
err error
|
|
}{
|
|
{
|
|
"invalid-token-format",
|
|
"user-foo",
|
|
refresh.ErrorInvalidToken,
|
|
},
|
|
{
|
|
"1/invalid-base64-encoded-format",
|
|
"user-foo",
|
|
refresh.ErrorInvalidToken,
|
|
},
|
|
{
|
|
token + "corrupted-token-payload",
|
|
"user-foo",
|
|
refresh.ErrorInvalidToken,
|
|
},
|
|
{
|
|
// The token's ID is invalid.
|
|
tokenWithBadID,
|
|
"user-foo",
|
|
refresh.ErrorInvalidToken,
|
|
},
|
|
{
|
|
// The token's payload is invalid.
|
|
tokenWithBadPayload,
|
|
"user-foo",
|
|
refresh.ErrorInvalidToken,
|
|
},
|
|
{
|
|
token,
|
|
"invalid-user",
|
|
refresh.ErrorInvalidUserID,
|
|
},
|
|
{
|
|
token,
|
|
"user-foo",
|
|
nil,
|
|
},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
if err := r.Revoke(tt.userID, tt.token); err != tt.err {
|
|
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
|
|
}
|
|
}
|
|
}
|