dex/functional/repo/user_repo_test.go

610 lines
12 KiB
Go
Raw Normal View History

2015-08-18 05:57:27 +05:30
package repo
import (
"fmt"
"os"
"reflect"
"testing"
"time"
2016-02-10 01:52:40 +05:30
"github.com/go-gorp/gorp"
2015-08-18 05:57:27 +05:30
"github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/db"
"github.com/coreos/dex/user"
)
var (
testUsers = []user.UserWithRemoteIdentities{
{
User: user.User{
ID: "ID-1",
Email: "Email-1@example.com",
CreatedAt: time.Now().Truncate(time.Second),
},
RemoteIdentities: []user.RemoteIdentity{
{
ConnectorID: "IDPC-1",
ID: "RID-1",
},
},
},
{
User: user.User{
ID: "ID-2",
Email: "Email-2@example.com",
CreatedAt: time.Now(),
Disabled: true,
2015-08-18 05:57:27 +05:30
},
RemoteIdentities: []user.RemoteIdentity{
{
ConnectorID: "IDPC-2",
ID: "RID-2",
},
},
},
}
)
func newUserRepo(t *testing.T, users []user.UserWithRemoteIdentities) user.UserRepo {
if users == nil {
users = []user.UserWithRemoteIdentities{}
2015-08-18 05:57:27 +05:30
}
2016-02-10 01:52:40 +05:30
var dbMap *gorp.DbMap
if os.Getenv("DEX_TEST_DSN") == "" {
2016-02-10 01:52:40 +05:30
dbMap = db.NewMemDB()
} else {
dbMap = connect(t)
2015-08-18 05:57:27 +05:30
}
repo, err := db.NewUserRepoFromUsers(dbMap, users)
if err != nil {
t.Fatalf("Unable to add users: %v", err)
}
return repo
2015-08-18 05:57:27 +05:30
}
func TestNewUser(t *testing.T) {
now := time.Now().UTC().Truncate(time.Second)
tests := []struct {
user user.User
err error
}{
{
user: user.User{
ID: "ID-bob",
Email: "bob@example.com",
CreatedAt: now,
},
err: nil,
},
{
user: user.User{
ID: "ID-admin",
Email: "admin@example.com",
Admin: true,
CreatedAt: now,
},
err: nil,
},
{
user: user.User{
ID: "ID-verified",
Email: "verified@example.com",
EmailVerified: true,
CreatedAt: now,
},
err: nil,
},
{
user: user.User{
ID: "ID-same",
Email: "Email-1@example.com",
DisplayName: "Oops Same Email",
CreatedAt: now,
},
err: user.ErrorDuplicateEmail,
},
{
user: user.User{
ID: "ID-same",
Email: "email-1@example.com",
DisplayName: "A lower case version of the original email",
CreatedAt: now,
},
err: user.ErrorDuplicateEmail,
},
2015-08-18 05:57:27 +05:30
{
user: user.User{
Email: "AnotherEmail@example.com",
DisplayName: "Can't set your own ID!",
CreatedAt: now,
},
err: user.ErrorInvalidID,
},
{
user: user.User{
ID: "ID-noemail",
DisplayName: "No Email",
CreatedAt: now,
},
err: user.ErrorInvalidEmail,
},
}
for i, tt := range tests {
repo := newUserRepo(t, testUsers)
2015-08-18 05:57:27 +05:30
err := repo.Create(nil, tt.user)
if tt.err != nil {
if err != tt.err {
t.Errorf("case %d: want=%v, got=%v", i, tt.err, err)
}
} else {
if err != nil {
t.Errorf("case %d: want nil err, got %v", i, err)
}
gotUser, err := repo.Get(nil, tt.user.ID)
if err != nil {
t.Errorf("case %d: want nil err, got %v", i, err)
}
if diff := pretty.Compare(tt.user, gotUser); diff != "" {
t.Errorf("case %d: Compare(want, got) = %v", i,
diff)
}
}
}
}
func TestUpdateUser(t *testing.T) {
tests := []struct {
user user.User
err error
}{
{
// Update the email.
user: user.User{
ID: "ID-1",
Email: "Email-1.1@example.com",
},
err: nil,
},
{
// No-op.
user: user.User{
ID: "ID-1",
Email: "Email-1@example.com",
},
err: nil,
},
{
// No email.
user: user.User{
ID: "ID-1",
Email: "",
},
err: user.ErrorInvalidEmail,
},
{
// Try Update on non-existent user.
user: user.User{
ID: "NonExistent",
Email: "GoodEmail@email.com",
},
err: user.ErrorNotFound,
},
{
// Try update to someone else's email.
user: user.User{
ID: "ID-2",
Email: "Email-1@example.com",
},
err: user.ErrorDuplicateEmail,
},
}
for i, tt := range tests {
repo := newUserRepo(t, testUsers)
2015-08-18 05:57:27 +05:30
err := repo.Update(nil, tt.user)
if tt.err != nil {
if err != tt.err {
t.Errorf("case %d: want=%q, got=%q", i, tt.err, err)
}
} else {
if err != nil {
t.Errorf("case %d: want nil err, got %q", i, err)
}
gotUser, err := repo.Get(nil, tt.user.ID)
if err != nil {
t.Errorf("case %d: want nil err, got %q", i, err)
}
if diff := pretty.Compare(tt.user, gotUser); diff != "" {
t.Errorf("case %d: Compare(want, got) = %v", i,
diff)
}
}
}
}
// TestDisableUser exercises repo.Disable against a repo seeded with
// testUsers, setting the disabled flag both ways for an initially enabled
// user (ID-1) and an initially disabled user (ID-2), and checking the errors
// returned for an unknown ID and for an empty ID.
func TestDisableUser(t *testing.T) {
	tests := []struct {
		id      string
		disable bool
		err     error
	}{
		{
			id: "ID-1",
		},
		{
			id:      "ID-1",
			disable: true,
		},
		{
			id: "ID-2",
		},
		{
			id:      "ID-2",
			disable: true,
		},
		{
			id:  "NO SUCH ID",
			err: user.ErrorNotFound,
		},
		{
			id:      "NO SUCH ID",
			err:     user.ErrorNotFound,
			disable: true,
		},
		{
			id:  "",
			err: user.ErrorInvalidID,
		},
	}

	for i, tt := range tests {
		repo := newUserRepo(t, testUsers)
		err := repo.Disable(nil, tt.id, tt.disable)
		switch {
		case err != tt.err:
			// Either side may be nil in this branch; %q would render a nil
			// error as "%!q(<nil>)", so use %v.
			t.Errorf("case %d: want=%v, got=%v", i, tt.err, err)
		case tt.err == nil:
			gotUser, err := repo.Get(nil, tt.id)
			if err != nil {
				t.Fatalf("case %d: want nil err, got %q", i, err)
			}
			if gotUser.Disabled != tt.disable {
				t.Errorf("case %d: disabled status want=%v got=%v",
					i, tt.disable, gotUser.Disabled)
			}
		}
	}
}
2015-08-18 05:57:27 +05:30
func TestAttachRemoteIdentity(t *testing.T) {
tests := []struct {
id string
rid user.RemoteIdentity
err error
}{
{
id: "ID-1",
rid: user.RemoteIdentity{
ConnectorID: "IDPC-1",
ID: "RID-1.1",
},
},
{
id: "ID-1",
rid: user.RemoteIdentity{
ConnectorID: "IDPC-2",
ID: "RID-2",
},
err: user.ErrorDuplicateRemoteIdentity,
},
{
id: "NoSuchUser",
rid: user.RemoteIdentity{
ConnectorID: "IDPC-3",
ID: "RID-3",
},
err: user.ErrorNotFound,
},
}
for i, tt := range tests {
repo := newUserRepo(t, testUsers)
2015-08-18 05:57:27 +05:30
err := repo.AddRemoteIdentity(nil, tt.id, tt.rid)
if tt.err != nil {
if err != tt.err {
t.Errorf("case %d: want=%q, got=%q", i, tt.err, err)
}
} else {
if err != nil {
t.Errorf("case %d: want nil err, got %q", i, err)
}
gotUser, err := repo.GetByRemoteIdentity(nil, tt.rid)
if err != nil {
t.Errorf("case %d: want nil err, got %q", i, err)
}
wantUser, err := repo.Get(nil, tt.id)
if err != nil {
t.Errorf("case %d: want nil err, got %q", i, err)
}
gotRIDs, err := repo.GetRemoteIdentities(nil, tt.id)
if err != nil {
t.Errorf("case %d: want nil err, got %q", i, err)
}
if findRemoteIdentity(gotRIDs, tt.rid) == -1 {
t.Errorf("case %d: user.RemoteIdentity not found", i)
}
if !reflect.DeepEqual(wantUser, gotUser) {
t.Errorf("case %d: want=%#v, got=%#v", i,
wantUser, gotUser)
}
}
}
}
func TestRemoveRemoteIdentity(t *testing.T) {
tests := []struct {
id string
rid user.RemoteIdentity
err error
}{
{
id: "ID-1",
rid: user.RemoteIdentity{
ConnectorID: "IDPC-1",
ID: "RID-1",
},
},
{
id: "ID-1",
rid: user.RemoteIdentity{
ConnectorID: "IDPC-2",
ID: "RID-2",
},
err: user.ErrorNotFound,
},
{
id: "NoSuchUser",
rid: user.RemoteIdentity{
ConnectorID: "IDPC-3",
ID: "RID-3",
},
err: user.ErrorNotFound,
},
}
for i, tt := range tests {
repo := newUserRepo(t, testUsers)
2015-08-18 05:57:27 +05:30
err := repo.RemoveRemoteIdentity(nil, tt.id, tt.rid)
if tt.err != nil {
if err != tt.err {
t.Errorf("case %d: want=%q, got=%q", i, tt.err, err)
}
} else {
if err != nil {
t.Errorf("case %d: want nil err, got %q", i, err)
}
gotUser, err := repo.GetByRemoteIdentity(nil, tt.rid)
if err == nil {
if gotUser.ID == tt.id {
t.Errorf("case %d: user found.", i)
}
} else if err != user.ErrorNotFound {
t.Errorf("case %d: want %q err, got %q err", i, user.ErrorNotFound, err)
}
gotRIDs, err := repo.GetRemoteIdentities(nil, tt.id)
if err != nil {
t.Errorf("case %d: want nil err, got %q", i, err)
}
if findRemoteIdentity(gotRIDs, tt.rid) != -1 {
t.Errorf("case %d: user.RemoteIdentity found", i)
}
}
}
}
// findRemoteIdentity reports the position of the first element of rids that
// compares equal (==) to rid, or -1 when rids contains no such element.
func findRemoteIdentity(rids []user.RemoteIdentity, rid user.RemoteIdentity) int {
	for idx := 0; idx < len(rids); idx++ {
		if rids[idx] != rid {
			continue
		}
		return idx
	}
	return -1
}
func TestGetByEmail(t *testing.T) {
tests := []struct {
email string
wantEmail string
wantErr error
2015-08-18 05:57:27 +05:30
}{
{
email: "Email-1@example.com",
wantEmail: "Email-1@example.com",
wantErr: nil,
},
{
email: "EMAIL-1@example.com", // Emails should be case insensitive.
wantEmail: "Email-1@example.com",
wantErr: nil,
2015-08-18 05:57:27 +05:30
},
{
email: "NoSuchEmail@example.com",
wantErr: user.ErrorNotFound,
},
}
for i, tt := range tests {
repo := newUserRepo(t, testUsers)
2015-08-18 05:57:27 +05:30
gotUser, gotErr := repo.GetByEmail(nil, tt.email)
if tt.wantErr != nil {
if tt.wantErr != gotErr {
t.Errorf("case %d: wantErr=%q, gotErr=%q", i, tt.wantErr, gotErr)
}
continue
}
if gotErr != nil {
t.Errorf("case %d: want nil err:% q", i, gotErr)
continue
2015-08-18 05:57:27 +05:30
}
if tt.wantEmail != gotUser.Email {
2015-08-18 05:57:27 +05:30
t.Errorf("case %d: want=%q, got=%q", i, tt.email, gotUser.Email)
}
}
}
func TestGetAdminCount(t *testing.T) {
tests := []struct {
addUsers []user.User
want int
}{
{
addUsers: []user.User{
user.User{
ID: "ID-admin",
Email: "Admin@example.com",
Admin: true,
},
},
want: 1,
},
{
want: 0,
},
{
addUsers: []user.User{
user.User{
ID: "ID-admin",
Email: "NotAdmin@example.com",
},
},
want: 0,
},
{
addUsers: []user.User{
user.User{
ID: "ID-admin",
Email: "Admin@example.com",
Admin: true,
},
user.User{
ID: "ID-admin2",
Email: "AnotherAdmin@example.com",
Admin: true,
},
},
want: 2,
},
}
for i, tt := range tests {
repo := newUserRepo(t, testUsers)
2015-08-18 05:57:27 +05:30
for _, addUser := range tt.addUsers {
err := repo.Create(nil, addUser)
if err != nil {
t.Fatalf("case %d: couldn't add user: %q", i, err)
}
}
got, err := repo.GetAdminCount(nil)
if err != nil {
t.Errorf("case %d: couldn't get admin count: %q", i, err)
continue
}
if tt.want != got {
t.Errorf("case %d: want=%d, got=%d", i, tt.want, got)
}
}
}
func TestList(t *testing.T) {
repoUsers := []user.UserWithRemoteIdentities{}
for i := 0; i < 10; i++ {
repoUsers = append(repoUsers, user.UserWithRemoteIdentities{
User: user.User{
ID: fmt.Sprintf("%d", i),
Email: fmt.Sprintf("%d@example.com", i),
},
})
}
tests := []struct {
filter user.UserFilter
maxResults int
expectedIDs [][]string
}{
{
maxResults: 5,
expectedIDs: [][]string{{"0", "1", "2", "3", "4"}, {"5", "6", "7", "8", "9"}},
},
{
maxResults: 3,
expectedIDs: [][]string{{"0", "1", "2"}, {"3", "4", "5"}, {"6", "7", "8"}, {"9"}},
},
{
maxResults: 9,
expectedIDs: [][]string{{"0", "1", "2", "3", "4", "5", "6", "7", "8"}, {"9"}},
},
{
maxResults: 10,
expectedIDs: [][]string{{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"}},
},
}
for i, tt := range tests {
repo := newUserRepo(t, repoUsers)
2015-08-18 05:57:27 +05:30
var tok string
gotIDs := [][]string{}
done := false
for !done {
var users []user.User
var err error
users, tok, err = repo.List(nil, tt.filter, tt.maxResults, tok)
if err != nil {
t.Errorf("case %d: unexpected err: %v", i, err)
done = true
continue
}
ids := []string{}
for _, user := range users {
ids = append(ids, user.ID)
}
gotIDs = append(gotIDs, ids)
if tok == "" {
done = true
}
}
if diff := pretty.Compare(tt.expectedIDs, gotIDs); diff != "" {
t.Errorf("case %d: Compare(want, got) = %v", i,
diff)
}
}
}
func TestListErrorNotFound(t *testing.T) {
repo := newUserRepo(t, nil)
2015-08-18 05:57:27 +05:30
_, _, err := repo.List(nil, user.UserFilter{}, 10, "")
if err != user.ErrorNotFound {
t.Errorf("want=%q, got=%q", user.ErrorNotFound, err)
}
}