package repo

import (
	"fmt"
	"os"
	"reflect"
	"strings"
	"testing"
	"time"

	"github.com/go-gorp/gorp"
	"github.com/kylelemons/godebug/pretty"

	"github.com/coreos/dex/db"
	"github.com/coreos/dex/user"
)

// testUsers is the fixture set that most tests in this file load into a fresh
// repo: two users, each carrying a single remote identity. The second user
// starts out disabled.
//
// NOTE(review): only the first user's CreatedAt is truncated to whole
// seconds; confirm whether the second one is meant to keep full precision.
var testUsers = []user.UserWithRemoteIdentities{
	{
		User: user.User{
			ID:        "ID-1",
			Email:     "Email-1@example.com",
			CreatedAt: time.Now().Truncate(time.Second),
		},
		RemoteIdentities: []user.RemoteIdentity{
			{ConnectorID: "IDPC-1", ID: "RID-1"},
		},
	},
	{
		User: user.User{
			ID:        "ID-2",
			Email:     "Email-2@example.com",
			CreatedAt: time.Now(),
			Disabled:  true,
		},
		RemoteIdentities: []user.RemoteIdentity{
			{ConnectorID: "IDPC-2", ID: "RID-2"},
		},
	},
}

// newUserRepo returns a user.UserRepo seeded with users, aborting the test
// via t.Fatalf if the repo cannot be built. A nil users slice is swapped for
// an empty one before it is handed to db.NewUserRepoFromUsers.
//
// The backing gorp.DbMap depends on the DEX_TEST_DSN environment variable:
// when it is unset or empty the map comes from db.NewMemDB, otherwise it is
// obtained from connect(t).
func newUserRepo(t *testing.T, users []user.UserWithRemoteIdentities) user.UserRepo {
	if users == nil {
		users = []user.UserWithRemoteIdentities{}
	}

	var dbMap *gorp.DbMap
	if dsn := os.Getenv("DEX_TEST_DSN"); dsn != "" {
		dbMap = connect(t)
	} else {
		dbMap = db.NewMemDB()
	}

	userRepo, err := db.NewUserRepoFromUsers(dbMap, users)
	if err != nil {
		t.Fatalf("Unable to add users: %v", err)
	}
	return userRepo
}

// TestNewUser exercises UserRepo.Create against a repo seeded with testUsers.
// Cases that expect an error only verify that Create returns exactly that
// error. Cases that expect success also read the user back with Get and diff
// it against the value that was passed in.
func TestNewUser(t *testing.T) {
	createdAt := time.Now().UTC().Truncate(time.Second)

	cases := []struct {
		newUser user.User
		wantErr error
	}{
		// Minimal user: ID, email and creation time only.
		{
			newUser: user.User{
				ID:        "ID-bob",
				Email:     "bob@example.com",
				CreatedAt: createdAt,
			},
		},
		// Admin flag set.
		{
			newUser: user.User{
				ID:        "ID-admin",
				Email:     "admin@example.com",
				Admin:     true,
				CreatedAt: createdAt,
			},
		},
		// Email already marked as verified.
		{
			newUser: user.User{
				ID:            "ID-verified",
				Email:         "verified@example.com",
				EmailVerified: true,
				CreatedAt:     createdAt,
			},
		},
		// Same email, byte for byte, as the first fixture user.
		{
			newUser: user.User{
				ID:          "ID-same",
				Email:       "Email-1@example.com",
				DisplayName: "Oops Same Email",
				CreatedAt:   createdAt,
			},
			wantErr: user.ErrorDuplicateEmail,
		},
		// Same email as the first fixture user, differing only in case.
		{
			newUser: user.User{
				ID:          "ID-same",
				Email:       "email-1@example.com",
				DisplayName: "A lower case version of the original email",
				CreatedAt:   createdAt,
			},
			wantErr: user.ErrorDuplicateEmail,
		},
		// No ID supplied.
		{
			newUser: user.User{
				Email:       "AnotherEmail@example.com",
				DisplayName: "Can't set your own ID!",
				CreatedAt:   createdAt,
			},
			wantErr: user.ErrorInvalidID,
		},
		// No email supplied.
		{
			newUser: user.User{
				ID:          "ID-noemail",
				DisplayName: "No Email",
				CreatedAt:   createdAt,
			},
			wantErr: user.ErrorInvalidEmail,
		},
	}

	for i, tc := range cases {
		// Each case gets its own freshly seeded repo.
		repo := newUserRepo(t, testUsers)
		err := repo.Create(nil, tc.newUser)

		if tc.wantErr != nil {
			if err != tc.wantErr {
				t.Errorf("case %d: want=%v, got=%v", i, tc.wantErr, err)
			}
			continue
		}

		if err != nil {
			t.Errorf("case %d: want nil err, got %v", i, err)
		}

		gotUser, err := repo.Get(nil, tc.newUser.ID)
		if err != nil {
			t.Errorf("case %d: want nil err, got %v", i, err)
		}

		diff := pretty.Compare(tc.newUser, gotUser)
		if diff != "" {
			t.Errorf("case %d: Compare(want, got) = %v", i,
				diff)
		}
	}
}

// TestUpdateUser exercises UserRepo.Update against a repo seeded with
// testUsers. Error cases only check the returned error. Successful cases read
// the user back with Get and expect it to equal the update with its email
// lower-cased.
func TestUpdateUser(t *testing.T) {
	cases := []struct {
		update  user.User
		wantErr error
	}{
		// Change the email of an existing user.
		{update: user.User{ID: "ID-1", Email: "Email-1.1@example.com"}},
		// Re-submit the email the fixture user already has.
		{update: user.User{ID: "ID-1", Email: "Email-1@example.com"}},
		// Empty email.
		{
			update:  user.User{ID: "ID-1", Email: ""},
			wantErr: user.ErrorInvalidEmail,
		},
		// Update on a user that is not in the repo.
		{
			update:  user.User{ID: "NonExistent", Email: "GoodEmail@email.com"},
			wantErr: user.ErrorNotFound,
		},
		// Update to an email that belongs to another fixture user.
		{
			update:  user.User{ID: "ID-2", Email: "Email-1@example.com"},
			wantErr: user.ErrorDuplicateEmail,
		},
	}

	for i, tc := range cases {
		// Each case gets its own freshly seeded repo.
		repo := newUserRepo(t, testUsers)
		err := repo.Update(nil, tc.update)

		if tc.wantErr != nil {
			if err != tc.wantErr {
				t.Errorf("case %d: want=%q, got=%q", i, tc.wantErr, err)
			}
			continue
		}

		if err != nil {
			t.Errorf("case %d: want nil err, got %q", i, err)
		}

		gotUser, err := repo.Get(nil, tc.update.ID)
		if err != nil {
			t.Errorf("case %d: want nil err, got %q", i, err)
		}

		// The stored email is expected in lower case.
		wantUser := tc.update
		wantUser.Email = strings.ToLower(wantUser.Email)

		diff := pretty.Compare(wantUser, gotUser)
		if diff != "" {
			t.Errorf("case %d: Compare(want, got) = %v", i,
				diff)
		}
	}
}

// TestDisableUser exercises UserRepo.Disable against a repo seeded with
// testUsers. Every case checks the returned error. When no error is expected,
// the user is read back with Get and its Disabled flag must match the
// requested value.
func TestDisableUser(t *testing.T) {
	cases := []struct {
		id      string
		disable bool
		wantErr error
	}{
		// Fixture user that starts enabled.
		{id: "ID-1"},
		{id: "ID-1", disable: true},
		// Fixture user that starts disabled.
		{id: "ID-2"},
		{id: "ID-2", disable: true},
		// Unknown ID, in both directions.
		{id: "NO SUCH ID", wantErr: user.ErrorNotFound},
		{id: "NO SUCH ID", wantErr: user.ErrorNotFound, disable: true},
		// Empty ID.
		{id: "", wantErr: user.ErrorInvalidID},
	}

	for i, tc := range cases {
		// Each case gets its own freshly seeded repo.
		repo := newUserRepo(t, testUsers)

		if err := repo.Disable(nil, tc.id, tc.disable); err != tc.wantErr {
			t.Errorf("case %d: want=%q, got=%q", i, tc.wantErr, err)
			continue
		}
		if tc.wantErr != nil {
			// The expected error was returned; nothing to read back.
			continue
		}

		gotUser, err := repo.Get(nil, tc.id)
		if err != nil {
			t.Fatalf("case %d: want nil err, got %q", i, err)
		}

		if gotUser.Disabled != tc.disable {
			t.Errorf("case %d: disabled status want=%v got=%v",
				i, tc.disable, gotUser.Disabled)
		}
	}
}

// TestAttachRemoteIdentity exercises UserRepo.AddRemoteIdentity against a
// repo seeded with testUsers. Error cases only check the returned error. On
// success the test verifies that looking the user up by the new remote
// identity yields the same user as looking it up by ID, and that the identity
// is among those returned by GetRemoteIdentities.
func TestAttachRemoteIdentity(t *testing.T) {
	cases := []struct {
		id      string
		rid     user.RemoteIdentity
		wantErr error
	}{
		// A new identity for an existing user.
		{
			id:  "ID-1",
			rid: user.RemoteIdentity{ConnectorID: "IDPC-1", ID: "RID-1.1"},
		},
		// An identity the fixtures already attach to user ID-2.
		{
			id:      "ID-1",
			rid:     user.RemoteIdentity{ConnectorID: "IDPC-2", ID: "RID-2"},
			wantErr: user.ErrorDuplicateRemoteIdentity,
		},
		// A user that is not in the repo.
		{
			id:      "NoSuchUser",
			rid:     user.RemoteIdentity{ConnectorID: "IDPC-3", ID: "RID-3"},
			wantErr: user.ErrorNotFound,
		},
	}

	for i, tc := range cases {
		// Each case gets its own freshly seeded repo.
		repo := newUserRepo(t, testUsers)
		err := repo.AddRemoteIdentity(nil, tc.id, tc.rid)

		if tc.wantErr != nil {
			if err != tc.wantErr {
				t.Errorf("case %d: want=%q, got=%q", i, tc.wantErr, err)
			}
			continue
		}

		if err != nil {
			t.Errorf("case %d: want nil err, got %q", i, err)
		}

		gotUser, err := repo.GetByRemoteIdentity(nil, tc.rid)
		if err != nil {
			t.Errorf("case %d: want nil err, got %q", i, err)
		}

		wantUser, err := repo.Get(nil, tc.id)
		if err != nil {
			t.Errorf("case %d: want nil err, got %q", i, err)
		}

		gotRIDs, err := repo.GetRemoteIdentities(nil, tc.id)
		if err != nil {
			t.Errorf("case %d: want nil err, got %q", i, err)
		}

		if idx := findRemoteIdentity(gotRIDs, tc.rid); idx == -1 {
			t.Errorf("case %d: user.RemoteIdentity not found", i)
		}

		if same := reflect.DeepEqual(wantUser, gotUser); !same {
			t.Errorf("case %d: want=%#v, got=%#v", i,
				wantUser, gotUser)
		}
	}
}

// TestRemoveRemoteIdentity exercises UserRepo.RemoveRemoteIdentity against a
// repo seeded with testUsers. Error cases only check the returned error. On
// success the test verifies that the removed identity no longer resolves to
// the user and is absent from the result of GetRemoteIdentities.
func TestRemoveRemoteIdentity(t *testing.T) {
	cases := []struct {
		id      string
		rid     user.RemoteIdentity
		wantErr error
	}{
		// The identity the fixtures attach to user ID-1.
		{
			id:  "ID-1",
			rid: user.RemoteIdentity{ConnectorID: "IDPC-1", ID: "RID-1"},
		},
		// An identity the fixtures attach to a different user.
		{
			id:      "ID-1",
			rid:     user.RemoteIdentity{ConnectorID: "IDPC-2", ID: "RID-2"},
			wantErr: user.ErrorNotFound,
		},
		// A user that is not in the repo.
		{
			id:      "NoSuchUser",
			rid:     user.RemoteIdentity{ConnectorID: "IDPC-3", ID: "RID-3"},
			wantErr: user.ErrorNotFound,
		},
	}

	for i, tc := range cases {
		// Each case gets its own freshly seeded repo.
		repo := newUserRepo(t, testUsers)
		err := repo.RemoveRemoteIdentity(nil, tc.id, tc.rid)

		if tc.wantErr != nil {
			if err != tc.wantErr {
				t.Errorf("case %d: want=%q, got=%q", i, tc.wantErr, err)
			}
			continue
		}

		if err != nil {
			t.Errorf("case %d: want nil err, got %q", i, err)
		}

		// A lookup by the removed identity may fail with ErrorNotFound or
		// resolve to some other user, but must not return this user.
		gotUser, err := repo.GetByRemoteIdentity(nil, tc.rid)
		switch {
		case err == nil && gotUser.ID == tc.id:
			t.Errorf("case %d: user found.", i)
		case err != nil && err != user.ErrorNotFound:
			t.Errorf("case %d: want %q err, got %q err", i, user.ErrorNotFound, err)
		}

		gotRIDs, err := repo.GetRemoteIdentities(nil, tc.id)
		if err != nil {
			t.Errorf("case %d: want nil err, got %q", i, err)
		}

		if idx := findRemoteIdentity(gotRIDs, tc.rid); idx != -1 {
			t.Errorf("case %d: user.RemoteIdentity found", i)
		}
	}
}

// findRemoteIdentity returns the index of the first element of rids that is
// equal (==) to rid, or -1 when no element matches.
func findRemoteIdentity(rids []user.RemoteIdentity, rid user.RemoteIdentity) int {
	for idx := 0; idx < len(rids); idx++ {
		if rids[idx] == rid {
			return idx
		}
	}
	return -1
}

// TestGetByEmail exercises UserRepo.GetByEmail against a repo seeded with
// testUsers. Lookups are expected to ignore the case of the queried address
// and to return the user with its email in lower case; an unknown address is
// expected to yield user.ErrorNotFound.
func TestGetByEmail(t *testing.T) {
	tests := []struct {
		email     string
		wantEmail string
		wantErr   error
	}{
		{
			// Exactly the address the fixture user was created with.
			email:     "Email-1@example.com",
			wantEmail: "email-1@example.com",
			wantErr:   nil,
		},
		{
			email:     "EMAIL-1@example.com", // Emails should be case insensitive.
			wantEmail: "email-1@example.com",
			wantErr:   nil,
		},
		{
			// No fixture user has this address.
			email:   "NoSuchEmail@example.com",
			wantErr: user.ErrorNotFound,
		},
	}

	for i, tt := range tests {
		// Each case gets its own freshly seeded repo.
		repo := newUserRepo(t, testUsers)
		gotUser, gotErr := repo.GetByEmail(nil, tt.email)
		if tt.wantErr != nil {
			if tt.wantErr != gotErr {
				t.Errorf("case %d: wantErr=%q, gotErr=%q", i, tt.wantErr, gotErr)
			}
			continue
		}

		if gotErr != nil {
			t.Errorf("case %d: want nil err, got %q", i, gotErr)
			continue
		}

		if tt.wantEmail != gotUser.Email {
			// Report the value actually compared against (wantEmail), not
			// the query string, so the failure message is accurate.
			t.Errorf("case %d: want=%q, got=%q", i, tt.wantEmail, gotUser.Email)
		}
	}
}

// TestGetAdminCount exercises UserRepo.GetAdminCount. Each case starts from a
// repo seeded with testUsers (none of which has Admin set), creates the
// case's extra users, and then compares the reported admin count with the
// expected one.
func TestGetAdminCount(t *testing.T) {
	cases := []struct {
		addUsers []user.User
		want     int
	}{
		// One admin added.
		{
			addUsers: []user.User{
				{ID: "ID-admin", Email: "Admin@example.com", Admin: true},
			},
			want: 1,
		},
		// Nothing added.
		{
			want: 0,
		},
		// One non-admin added.
		{
			addUsers: []user.User{
				{ID: "ID-admin", Email: "NotAdmin@example.com"},
			},
			want: 0,
		},
		// Two admins added.
		{
			addUsers: []user.User{
				{ID: "ID-admin", Email: "Admin@example.com", Admin: true},
				{ID: "ID-admin2", Email: "AnotherAdmin@example.com", Admin: true},
			},
			want: 2,
		},
	}

	for i, tc := range cases {
		repo := newUserRepo(t, testUsers)

		for _, extra := range tc.addUsers {
			if err := repo.Create(nil, extra); err != nil {
				t.Fatalf("case %d: couldn't add user: %q", i, err)
			}
		}

		got, err := repo.GetAdminCount(nil)
		if err != nil {
			t.Errorf("case %d: couldn't get admin count: %q", i, err)
			continue
		}

		if got != tc.want {
			t.Errorf("case %d: want=%d, got=%d", i, tc.want, got)
		}
	}
}

// TestList exercises paging through UserRepo.List. The repo is seeded with
// ten users whose IDs are "0" through "9"; each case pages through them with
// a different maxResults, feeding the token returned by one call into the
// next until an empty token comes back, and then compares the IDs collected
// per page with the expected pages.
func TestList(t *testing.T) {
	const userCount = 10

	repoUsers := []user.UserWithRemoteIdentities{}
	for n := 0; n < userCount; n++ {
		seeded := user.UserWithRemoteIdentities{
			User: user.User{
				ID:    fmt.Sprintf("%d", n),
				Email: fmt.Sprintf("%d@example.com", n),
			},
		}
		repoUsers = append(repoUsers, seeded)
	}

	cases := []struct {
		filter      user.UserFilter
		maxResults  int
		expectedIDs [][]string
	}{
		{
			maxResults:  5,
			expectedIDs: [][]string{{"0", "1", "2", "3", "4"}, {"5", "6", "7", "8", "9"}},
		},
		{
			maxResults:  3,
			expectedIDs: [][]string{{"0", "1", "2"}, {"3", "4", "5"}, {"6", "7", "8"}, {"9"}},
		},
		{
			maxResults:  9,
			expectedIDs: [][]string{{"0", "1", "2", "3", "4", "5", "6", "7", "8"}, {"9"}},
		},
		{
			maxResults:  10,
			expectedIDs: [][]string{{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"}},
		},
	}

	for i, tc := range cases {
		repo := newUserRepo(t, repoUsers)

		var nextToken string
		gotIDs := [][]string{}
		for {
			page, tok, err := repo.List(nil, tc.filter, tc.maxResults, nextToken)
			if err != nil {
				t.Errorf("case %d: unexpected err: %v", i, err)
				break
			}
			nextToken = tok

			pageIDs := []string{}
			for _, u := range page {
				pageIDs = append(pageIDs, u.ID)
			}
			gotIDs = append(gotIDs, pageIDs)

			// An empty token marks the last page.
			if nextToken == "" {
				break
			}
		}

		diff := pretty.Compare(tc.expectedIDs, gotIDs)
		if diff != "" {
			t.Errorf("case %d: Compare(want, got) = %v", i,
				diff)
		}
	}
}

// TestListErrorNotFound checks that List on a repo created without any users
// returns user.ErrorNotFound.
func TestListErrorNotFound(t *testing.T) {
	emptyRepo := newUserRepo(t, nil)

	if _, _, err := emptyRepo.List(nil, user.UserFilter{}, 10, ""); err != user.ErrorNotFound {
		t.Errorf("want=%q, got=%q", user.ErrorNotFound, err)
	}
}