dex/functional/repo/refresh_repo_test.go
2016-07-19 11:23:04 -07:00

389 lines
9.4 KiB
Go

package repo
import (
	"encoding/base64"
	"fmt"
	"net/url"
	"sort"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/coreos/go-oidc/oidc"
	"github.com/kylelemons/godebug/pretty"

	"github.com/coreos/dex/client"
	"github.com/coreos/dex/db"
	"github.com/coreos/dex/refresh"
	"github.com/coreos/dex/user"
)
var (
	testRefreshClientID    = "client1"
	testRefreshClientID2   = "client2"
	testRefreshConnectorID = "IDPC-1"

	// testRefreshClients are the fixture clients the tests pass to
	// newRefreshRepo. Each client has its own secret and redirect URI.
	testRefreshClients = []client.LoadableClient{
		{
			Client: client.Client{
				Credentials: oidc.ClientCredentials{
					ID:     testRefreshClientID,
					Secret: base64.URLEncoding.EncodeToString([]byte("secret-1")),
				},
				Metadata: oidc.ClientMetadata{
					RedirectURIs: []url.URL{
						{Scheme: "https", Host: "client1.example.com", Path: "/callback"},
					},
				},
			},
		},
		{
			Client: client.Client{
				Credentials: oidc.ClientCredentials{
					ID:     testRefreshClientID2,
					Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")),
				},
				Metadata: oidc.ClientMetadata{
					RedirectURIs: []url.URL{
						{Scheme: "https", Host: "client2.example.com", Path: "/callback"},
					},
				},
			},
		},
	}

	testRefreshUserID = "user1"

	// testRefreshUsers holds a single user linked to testRefreshConnectorID
	// through one remote identity.
	testRefreshUsers = []user.UserWithRemoteIdentities{
		{
			User: user.User{
				ID:    testRefreshUserID,
				Email: "Email-1@example.com",
				// Truncated to whole seconds, presumably so the value survives
				// a round trip through the DB's timestamp precision.
				CreatedAt: time.Now().Truncate(time.Second),
			},
			RemoteIdentities: []user.RemoteIdentity{
				{
					ConnectorID: testRefreshConnectorID,
					ID:          "RID-1",
				},
			},
		},
	}
)
// newRefreshRepo seeds the database handle returned by connect(t) with the
// given users and clients. It then returns a refresh-token repository on that
// same handle. Any seeding failure aborts the test.
func newRefreshRepo(t *testing.T, users []user.UserWithRemoteIdentities, clients []client.LoadableClient) refresh.RefreshTokenRepo {
	conn := connect(t)

	// Only the seeding side effect matters here; the returned repos are unused.
	_, err := db.NewUserRepoFromUsers(conn, users)
	if err != nil {
		t.Fatalf("Unable to add users: %v", err)
	}

	_, err = db.NewClientRepoFromClients(conn, clients)
	if err != nil {
		t.Fatalf("Unable to add clients: %v", err)
	}

	return db.NewRefreshTokenRepo(conn)
}
// TestRefreshTokenRepoCreateVerify creates a refresh token for
// testRefreshUserID and testRefreshClientID. It then calls Verify with each
// case's client ID. Verify must either fail (wantVerifyErr) or return the
// original user ID, connector ID and scopes.
func TestRefreshTokenRepoCreateVerify(t *testing.T) {
	tests := []struct {
		createScopes   []string
		verifyClientID string
		wantVerifyErr  bool
	}{
		{
			createScopes:   []string{"openid", "profile"},
			verifyClientID: testRefreshClientID,
		},
		{
			createScopes:   []string{},
			verifyClientID: testRefreshClientID,
		},
		{
			createScopes:   []string{"openid", "profile"},
			verifyClientID: "not-a-client",
			wantVerifyErr:  true,
		},
	}

	for i, tt := range tests {
		repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
		tok, err := repo.Create(testRefreshUserID, testRefreshClientID, testRefreshConnectorID, tt.createScopes)
		if err != nil {
			t.Fatalf("case %d: failed to create refresh token: %v", i, err)
		}

		tokUserID, gotConnectorID, gotScopes, err := repo.Verify(tt.verifyClientID, tok)
		if tt.wantVerifyErr {
			if err == nil {
				t.Errorf("case %d: want non-nil error.", i)
			}
			continue
		}
		// Check the error before any returned value. On failure those values
		// are meaningless, and comparing them would only bury the real error.
		if err != nil {
			t.Errorf("case %d: Could not verify token: %v", i, err)
			continue
		}

		if tokUserID != testRefreshUserID {
			t.Errorf("case %d: Verified token returned wrong user id, want=%s, got=%s", i,
				testRefreshUserID, tokUserID)
		}
		if diff := pretty.Compare(tt.createScopes, gotScopes); diff != "" {
			t.Errorf("case %d: Compare(want, got): %v", i, diff)
		}
		if gotConnectorID != testRefreshConnectorID {
			t.Errorf("case %d: wanted connector_id=%q got=%q", i, testRefreshConnectorID, gotConnectorID)
		}
	}
}
// buildRefreshToken assembles a refresh token string in the form
// "<tokenID><refresh.TokenDelimer><base64url(tokenPayload)>". The tests use it
// to forge tokens with a chosen ID and payload.
func buildRefreshToken(tokenID int64, tokenPayload []byte) string {
	encodedPayload := base64.URLEncoding.EncodeToString(tokenPayload)
	return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, encodedPayload)
}
// TestRefreshRepoVerifyInvalidTokens creates a single refresh token and checks
// that Verify rejects the following with the expected sentinel error:
//   - malformed tokens
//   - corrupted tokens
//   - tokens with a wrong ID or payload
//   - a mismatched client ID
//
// It also checks that Verify accepts the genuine token for its own client.
func TestRefreshRepoVerifyInvalidTokens(t *testing.T) {
	r := db.NewRefreshTokenRepo(connect(t))
	token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	// Derive the corrupted variants from the real token's ID rather than
	// assuming it is the single digit "1".
	delim := strings.Index(token, refresh.TokenDelimer)
	if delim < 0 {
		t.Fatalf("token %q does not contain delimiter %q", token, refresh.TokenDelimer)
	}
	tokenID, err := strconv.ParseInt(token[:delim], 10, 64)
	if err != nil {
		t.Fatalf("token %q has a non-numeric ID: %v", token, err)
	}
	// Only one token has been created in this repo, so tokenID+1 should match
	// no row. The genuine payload is kept so that only the ID is wrong.
	tokenWithBadID := strconv.FormatInt(tokenID+1, 10) + token[delim:]
	// Keep the real ID so the row is found and the payload must be rejected.
	tokenWithBadPayload := buildRefreshToken(tokenID, badTokenPayload)

	tests := []struct {
		token    string
		creds    oidc.ClientCredentials
		err      error
		expected string
	}{
		{
			"invalid-token-format",
			oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
			refresh.ErrorInvalidToken,
			"",
		},
		{
			"b/invalid-base64-encoded-format",
			oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
			refresh.ErrorInvalidToken,
			"",
		},
		{
			"1/invalid-base64-encoded-format",
			oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
			refresh.ErrorInvalidToken,
			"",
		},
		{
			token + "corrupted-token-payload",
			oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
			refresh.ErrorInvalidToken,
			"",
		},
		{
			// The token's ID content is invalid.
			tokenWithBadID,
			oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
			refresh.ErrorInvalidToken,
			"",
		},
		{
			// The token's payload content is invalid.
			tokenWithBadPayload,
			oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
			refresh.ErrorInvalidToken,
			"",
		},
		{
			token,
			oidc.ClientCredentials{ID: "invalid-client", Secret: "secret-foo"},
			refresh.ErrorInvalidClientID,
			"",
		},
		{
			token,
			oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
			nil,
			"user-foo",
		},
	}

	for i, tt := range tests {
		result, _, _, err := r.Verify(tt.creds.ID, tt.token)
		if err != tt.err {
			t.Errorf("Case #%d: expected error: %v, got: %v", i, tt.err, err)
		}
		if result != tt.expected {
			t.Errorf("Case #%d: expected user ID: %q, got: %q", i, tt.expected, result)
		}
	}
}
// TestRefreshTokenRepoClientsWithRefreshTokens creates one refresh token for
// testRefreshUserID per listed client. It checks that ClientsWithRefreshTokens
// reports exactly those clients. The expected IDs in the table are already in
// sorted order; only the returned IDs are sorted before comparison.
func TestRefreshTokenRepoClientsWithRefreshTokens(t *testing.T) {
	tests := []struct {
		clientIDs []string
	}{
		{clientIDs: []string{"client1", "client2"}},
		{clientIDs: []string{"client1"}},
		{clientIDs: []string{}},
	}

	for i, tt := range tests {
		repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
		for _, clientID := range tt.clientIDs {
			_, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"})
			if err != nil {
				t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
			}
		}

		clients, err := repo.ClientsWithRefreshTokens(testRefreshUserID)
		if err != nil {
			t.Fatalf("case %d: unexpected error fetching clients %v", i, err)
		}

		// Named c, not client, to avoid shadowing the imported client package.
		var gotIDs []string
		for _, c := range clients {
			gotIDs = append(gotIDs, c.Credentials.ID)
		}
		sort.Strings(gotIDs)

		if diff := pretty.Compare(tt.clientIDs, gotIDs); diff != "" {
			t.Errorf("case %d: Compare(want, got): %v", i, diff)
		}
	}
}
// TestRefreshTokenRepoRevokeForClient creates one refresh token per client in
// createIDs and then revokes the user's tokens for revokeID. It checks that
// ClientsWithRefreshTokens still reports exactly the clients that were not
// revoked.
func TestRefreshTokenRepoRevokeForClient(t *testing.T) {
	tests := []struct {
		createIDs []string
		revokeID  string
	}{
		{
			createIDs: []string{"client1", "client2"},
			revokeID:  "client1",
		},
		{
			createIDs: []string{"client2"},
			revokeID:  "client1",
		},
		{
			createIDs: []string{"client1"},
			revokeID:  "client1",
		},
		{
			createIDs: []string{},
			revokeID:  "oops",
		},
	}

	for i, tt := range tests {
		repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
		for _, clientID := range tt.createIDs {
			_, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"})
			if err != nil {
				t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
			}
		}

		// Revoke exactly once, after every token exists. This way the call is
		// also exercised when createIDs is empty.
		if err := repo.RevokeTokensForClient(testRefreshUserID, tt.revokeID); err != nil {
			t.Fatalf("case %d: couldn't revoke refresh token(s): %v", i, err)
		}

		var wantIDs []string
		for _, id := range tt.createIDs {
			if id != tt.revokeID {
				wantIDs = append(wantIDs, id)
			}
		}

		clients, err := repo.ClientsWithRefreshTokens(testRefreshUserID)
		if err != nil {
			t.Fatalf("case %d: unexpected error fetching clients %v", i, err)
		}

		// Named c, not client, to avoid shadowing the imported client package.
		var gotIDs []string
		for _, c := range clients {
			gotIDs = append(gotIDs, c.Credentials.ID)
		}
		sort.Strings(gotIDs)

		if diff := pretty.Compare(wantIDs, gotIDs); diff != "" {
			t.Errorf("case %d: Compare(wantIDs, gotIDs): %v", i, diff)
		}
	}
}
// TestRefreshRepoRevoke creates a single refresh token and checks that Revoke
// rejects malformed or corrupted tokens and a mismatched user ID. It then
// checks that Revoke succeeds for the genuine token/user pair.
//
// All cases share one repo and run in order. The successful revocation is
// kept last: it presumably removes the token, which would change the errors
// seen by any later case that reuses it.
func TestRefreshRepoRevoke(t *testing.T) {
	r := db.NewRefreshTokenRepo(connect(t))
	token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	// Derive the corrupted variants from the real token's ID rather than
	// assuming it is the single digit "1".
	delim := strings.Index(token, refresh.TokenDelimer)
	if delim < 0 {
		t.Fatalf("token %q does not contain delimiter %q", token, refresh.TokenDelimer)
	}
	tokenID, err := strconv.ParseInt(token[:delim], 10, 64)
	if err != nil {
		t.Fatalf("token %q has a non-numeric ID: %v", token, err)
	}
	// Only one token has been created in this repo, so tokenID+1 should match
	// no row. The genuine payload is kept so that only the ID is wrong.
	tokenWithBadID := strconv.FormatInt(tokenID+1, 10) + token[delim:]
	// Keep the real ID so the row is found and the payload must be rejected.
	tokenWithBadPayload := buildRefreshToken(tokenID, badTokenPayload)

	tests := []struct {
		token  string
		userID string
		err    error
	}{
		{
			"invalid-token-format",
			"user-foo",
			refresh.ErrorInvalidToken,
		},
		{
			"1/invalid-base64-encoded-format",
			"user-foo",
			refresh.ErrorInvalidToken,
		},
		{
			token + "corrupted-token-payload",
			"user-foo",
			refresh.ErrorInvalidToken,
		},
		{
			// The token's ID is invalid.
			tokenWithBadID,
			"user-foo",
			refresh.ErrorInvalidToken,
		},
		{
			// The token's payload is invalid.
			tokenWithBadPayload,
			"user-foo",
			refresh.ErrorInvalidToken,
		},
		{
			token,
			"invalid-user",
			refresh.ErrorInvalidUserID,
		},
		{
			token,
			"user-foo",
			nil,
		},
	}

	for i, tt := range tests {
		if err := r.Revoke(tt.userID, tt.token); err != tt.err {
			t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
		}
	}
}