package manager

import (
	"net/url"
	"testing"
	"time"

	"github.com/coreos/go-oidc/jose"
	"github.com/jonboulle/clockwork"
	"github.com/kylelemons/godebug/pretty"

	"github.com/coreos/dex/connector"
	"github.com/coreos/dex/db"
	"github.com/coreos/dex/user"
)

// testFixtures bundles the repositories, the manager under test and the
// clock that makeTestFixtures wires together for a single test case.
type testFixtures struct {
	// ur is the user repository handed to the manager.
	ur user.UserRepo

	// pwr is the password-info repository handed to the manager.
	pwr user.PasswordInfoRepo

	// ccr is the connector-config repository handed to the manager.
	ccr connector.ConnectorConfigRepo

	// mgr is the UserManager built on top of the repositories above.
	mgr *UserManager

	// clock is the clock installed on mgr.
	clock clockwork.Clock
}

// makeTestFixtures builds a fresh set of fixtures on top of a new in-memory
// database, so every caller gets state that is independent of other callers.
//
// The database is seeded with:
//   - user "ID-1" (Email-1@example.com, email not verified) with remote
//     identity {local, "1"} and password value "password-1";
//   - user "ID-2" (Email-2@example.com, email verified) with remote
//     identity {local, "2"} and password value "password-2";
//   - a single local connector config with ID "local".
//
// The returned manager uses a fake clock, which is also exposed as f.clock.
// Any failure while seeding panics, since the tests cannot proceed without
// working fixtures.
func makeTestFixtures() *testFixtures {
	f := &testFixtures{}
	f.clock = clockwork.NewFakeClock()

	// All repositories share this one in-memory database.
	dbMap := db.NewMemDB()
	f.ur = func() user.UserRepo {
		repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{
			{
				User: user.User{
					ID:    "ID-1",
					Email: "Email-1@example.com",
				},
				RemoteIdentities: []user.RemoteIdentity{
					{
						ConnectorID: "local",
						ID:          "1",
					},
				},
			}, {
				User: user.User{
					ID:            "ID-2",
					Email:         "Email-2@example.com",
					EmailVerified: true,
				},
				RemoteIdentities: []user.RemoteIdentity{
					{
						ConnectorID: "local",
						ID:          "2",
					},
				},
			},
		})
		if err != nil {
			panic("Failed to create user repo: " + err.Error())
		}
		return repo
	}()

	f.pwr = func() user.PasswordInfoRepo {
		repo, err := db.NewPasswordInfoRepoFromPasswordInfos(dbMap, []user.PasswordInfo{
			{
				UserID:   "ID-1",
				Password: []byte("password-1"),
			},
			{
				UserID:   "ID-2",
				Password: []byte("password-2"),
			},
		})
		if err != nil {
			// Name the repo that actually failed so the panic is not
			// mistaken for a user repo problem.
			panic("Failed to create password info repo: " + err.Error())
		}
		return repo
	}()

	f.ccr = func() connector.ConnectorConfigRepo {
		repo := db.NewConnectorConfigRepo(dbMap)
		c := []connector.ConnectorConfig{
			&connector.LocalConnectorConfig{ID: "local"},
		}
		if err := repo.Set(c); err != nil {
			panic("Failed to set connector configs: " + err.Error())
		}
		return repo
	}()

	f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, db.TransactionFactory(dbMap), ManagerOptions{})
	// Swap in the fake clock so tests can compare against f.clock.Now().
	f.mgr.Clock = f.clock
	return f
}

// TestRegisterWithRemoteIdentity registers users through
// UserManager.RegisterWithRemoteIdentity against fresh fixtures and checks
// that either the expected error is returned or the new user can be read
// back both by ID and by the remote identity it was registered with.
func TestRegisterWithRemoteIdentity(t *testing.T) {
	tests := []struct {
		email         string
		emailVerified bool
		rid           user.RemoteIdentity
		err           error
	}{
		{
			// new email and new remote identity
			email:         "email@example.com",
			emailVerified: false,
			rid: user.RemoteIdentity{
				ConnectorID: "local",
				ID:          "1234",
			},
			err: nil,
		},
		{
			// empty email
			emailVerified: false,
			rid: user.RemoteIdentity{
				ConnectorID: "local",
				ID:          "1234",
			},
			err: user.ErrorInvalidEmail,
		},
		{
			// remote identity {local, "1"} already belongs to fixture user ID-1
			email:         "email@example.com",
			emailVerified: false,
			rid: user.RemoteIdentity{
				ConnectorID: "local",
				ID:          "1",
			},
			err: user.ErrorDuplicateRemoteIdentity,
		},
		{
			// connector ID that is not in the fixture connector configs
			email:         "anotheremail@example.com",
			emailVerified: false,
			rid: user.RemoteIdentity{
				ConnectorID: "idonotexist",
				ID:          "1",
			},
			err: connector.ErrorNotFound,
		},
	}

	for i, tt := range tests {
		f := makeTestFixtures()
		userID, err := f.mgr.RegisterWithRemoteIdentity(
			tt.email,
			tt.emailVerified,
			tt.rid)

		if tt.err != nil {
			if tt.err != err {
				// %v rather than %q so a nil err prints as <nil>.
				t.Errorf("case %d: want=%v, got=%v", i, tt.err, err)
			}
			continue
		}
		// The registration error must be checked here: err is reused by
		// the lookups below and would otherwise be silently overwritten.
		if err != nil {
			t.Errorf("case %d: unexpected err: %v", i, err)
			continue
		}

		usr, err := f.ur.Get(nil, userID)
		if err != nil {
			t.Errorf("case %d: err != nil: %q", i, err)
			// Nothing meaningful to compare against a failed lookup.
			continue
		}

		if usr.Email != tt.email {
			t.Errorf("case %d: user.Email: want=%q, got=%q", i, tt.email, usr.Email)
		}
		if usr.EmailVerified != tt.emailVerified {
			t.Errorf("case %d: user.EmailVerified: want=%v, got=%v", i, tt.emailVerified, usr.EmailVerified)
		}

		ridUSR, err := f.ur.GetByRemoteIdentity(nil, tt.rid)
		if err != nil {
			t.Errorf("case %d: err != nil: %q", i, err)
			continue
		}
		if diff := pretty.Compare(usr, ridUSR); diff != "" {
			t.Errorf("case %d: Compare(want, got) = %v", i, diff)
		}
	}
}

// TestRegisterWithPassword registers users through
// UserManager.RegisterWithPassword on the "local" connector and checks that
// either the expected error is returned or that the new user is stored with
// an unverified email, is reachable through its remote identity, and that
// the stored password info authenticates the plaintext and rejects a
// different one.
func TestRegisterWithPassword(t *testing.T) {
	const connID = "local"

	tests := []struct {
		email     string
		plaintext string
		err       error
	}{
		{
			// valid email and password
			email:     "email@example.com",
			plaintext: "secretpassword123",
			err:       nil,
		},
		{
			// empty email
			plaintext: "secretpassword123",
			err:       user.ErrorInvalidEmail,
		},
		{
			// empty password
			email: "email@example.com",
			err:   user.ErrorInvalidPassword,
		},
	}

	for i, tt := range tests {
		f := makeTestFixtures()
		userID, err := f.mgr.RegisterWithPassword(
			tt.email,
			tt.plaintext,
			connID)

		if tt.err != nil {
			if tt.err != err {
				// %v rather than %q so a nil err prints as <nil>.
				t.Errorf("case %d: want=%v, got=%v", i, tt.err, err)
			}
			continue
		}
		// The registration error must be checked here: err is reused by
		// the lookups below and would otherwise be silently overwritten.
		if err != nil {
			t.Errorf("case %d: unexpected err: %v", i, err)
			continue
		}

		usr, err := f.ur.Get(nil, userID)
		if err != nil {
			t.Errorf("case %d: err != nil: %q", i, err)
			// Nothing meaningful to compare against a failed lookup.
			continue
		}

		if usr.Email != tt.email {
			t.Errorf("case %d: user.Email: want=%q, got=%q", i, tt.email, usr.Email)
		}
		if usr.EmailVerified {
			t.Errorf("case %d: user.EmailVerified: want=%v, got=%v", i, false, usr.EmailVerified)
		}

		// The remote identity is looked up with the user's own ID on the
		// connector used for registration.
		ridUSR, err := f.ur.GetByRemoteIdentity(nil, user.RemoteIdentity{
			ID:          userID,
			ConnectorID: connID,
		})
		if err != nil {
			t.Errorf("case %d: err != nil: %q", i, err)
			continue
		}
		if diff := pretty.Compare(usr, ridUSR); diff != "" {
			t.Errorf("case %d: Compare(want, got) = %v", i, diff)
			continue
		}

		pwi, err := f.pwr.Get(nil, userID)
		if err != nil {
			t.Errorf("case %d: err != nil: %q", i, err)
			continue
		}
		ident, err := pwi.Authenticate(tt.plaintext)
		if err != nil {
			t.Errorf("case %d: err != nil: %q", i, err)
			continue
		}
		if ident.ID != userID {
			t.Errorf("case %d: ident.ID: want=%q, got=%q", i, userID, ident.ID)
			continue
		}

		// A different plaintext must not authenticate.
		_, err = pwi.Authenticate(tt.plaintext + "WRONG")
		if err == nil {
			t.Errorf("case %d: want non-nil err", i)
		}
	}
}

// TestVerifyEmail feeds email-verification claims for the fixture users to
// UserManager.VerifyEmail and checks that invalid claims are rejected and
// that, on success, the returned URL matches the callback carried in the
// claims.
func TestVerifyEmail(t *testing.T) {
	now := time.Now()
	issuer, err := url.Parse("http://example.com")
	if err != nil {
		t.Fatalf("parsing issuer URL: %v", err)
	}
	clientID := "myclient"
	callback := "http://client.example.com/callback"
	expires := time.Hour * 3

	// makeClaims builds verification claims for usr's ID and email, valid
	// from now for the duration of expires.
	makeClaims := func(usr user.User) jose.Claims {
		return map[string]interface{}{
			"iss": issuer.String(),
			"aud": clientID,
			user.ClaimEmailVerificationCallback: callback,
			user.ClaimEmailVerificationEmail:    usr.Email,
			"exp": float64(now.Add(expires).Unix()),
			"sub": usr.ID,
			"iat": float64(now.Unix()),
		}
	}

	tests := []struct {
		evClaims jose.Claims
		wantErr  bool
	}{
		{
			// happy path
			evClaims: makeClaims(user.User{ID: "ID-1", Email: "Email-1@example.com"}),
		},
		{
			// non-matching email
			evClaims: makeClaims(user.User{ID: "ID-1", Email: "Email-2@example.com"}),
			wantErr:  true,
		},
		{
			// already verified email
			evClaims: makeClaims(user.User{ID: "ID-2", Email: "Email-2@example.com"}),
			wantErr:  true,
		},
		{
			// non-existent user.
			evClaims: makeClaims(user.User{ID: "ID-UNKNOWN", Email: "noone@example.com"}),
			wantErr:  true,
		},
	}

	for i, tt := range tests {
		f := makeTestFixtures()
		cb, err := f.mgr.VerifyEmail(user.EmailVerification{Claims: tt.evClaims})
		if tt.wantErr {
			if err == nil {
				t.Errorf("case %d: want non-nil err", i)
			}
			continue
		}

		if err != nil {
			t.Errorf("case %d: want err=nil got=%q", i, err)
			continue
		}

		// The callback in the claims is the expectation; the URL returned
		// by VerifyEmail is the value under test.
		wantCB := tt.evClaims[user.ClaimEmailVerificationCallback]
		if gotCB := cb.String(); gotCB != wantCB {
			t.Errorf("case %d: want=%q, got=%q", i, wantCB, gotCB)
		}
	}
}

// TestChangePassword feeds password-reset claims to
// UserManager.ChangePassword and checks that invalid requests are rejected
// and that, on success, the returned callback URL (possibly nil) matches the
// callback carried in the claims.
func TestChangePassword(t *testing.T) {
	now := time.Now()
	issuer, err := url.Parse("http://example.com")
	if err != nil {
		t.Fatalf("parsing issuer URL: %v", err)
	}
	clientID := "myclient"
	callback := "http://client.example.com/callback"
	expires := time.Hour * 3
	// Matches the password value seeded for fixture user ID-1.
	password := "password-1"

	// makeClaims builds password-reset claims for the given user ID and
	// callback, carrying password and valid from now for expires.
	makeClaims := func(usrID, cbURL string) jose.Claims {
		return map[string]interface{}{
			"iss": issuer.String(),
			"aud": clientID,
			user.ClaimPasswordResetCallback: cbURL,
			user.ClaimPasswordResetPassword: password,
			"exp": float64(now.Add(expires).Unix()),
			"sub": usrID,
			"iat": float64(now.Unix()),
		}
	}

	tests := []struct {
		pwrClaims   jose.Claims
		newPassword string
		wantErr     bool
	}{
		{
			// happy path
			pwrClaims:   makeClaims("ID-1", callback),
			newPassword: "password-1.1",
		},
		{
			// happy path with no callback
			pwrClaims:   makeClaims("ID-1", ""),
			newPassword: "password-1.1",
		},
		{
			// password in the claims differs from the one seeded for ID-2
			pwrClaims:   makeClaims("ID-2", callback),
			newPassword: "password-1.1",
			wantErr:     true,
		},
		{
			// user doesn't exist
			pwrClaims:   makeClaims("ID-123", callback),
			newPassword: "password-1.1",
			wantErr:     true,
		},
	}

	for i, tt := range tests {
		f := makeTestFixtures()
		cb, err := f.mgr.ChangePassword(user.PasswordReset{Claims: tt.pwrClaims}, tt.newPassword)
		if tt.wantErr {
			if err == nil {
				t.Errorf("case %d: want non-nil err", i)
			}
			continue
		}

		if err != nil {
			t.Errorf("case %d: want err=nil got=%q", i, err)
			continue
		}

		// A nil URL is treated as the empty callback.
		var cbString string
		if cb != nil {
			cbString = cb.String()
		}
		wantCB := tt.pwrClaims[user.ClaimPasswordResetCallback]
		if cbString != wantCB {
			// Report cbString, not cb.String(): cb may be nil here and
			// dereferencing it would panic instead of failing the case.
			t.Errorf("case %d: want=%q, got=%q", i, wantCB, cbString)
		}
	}
}

// TestCreateUser creates users through UserManager.CreateUser and checks
// that invalid input is rejected and that, on success, the stored user, its
// password info and its remote identity on the connector all match what was
// passed in.
func TestCreateUser(t *testing.T) {
	tests := []struct {
		usr      user.User
		hashedPW user.Password
		localID  string // defaults to "local"

		wantErr bool
	}{
		{
			// regular user
			usr: user.User{
				DisplayName: "Bob Exampleson",
				Email:       "bob@example.com",
			},
			hashedPW: user.Password("I am a hash"),
		},
		{
			// admin user
			usr: user.User{
				DisplayName: "Al Adminson",
				Email:       "al@example.com",
				Admin:       true,
			},
			hashedPW: user.Password("I am a hash"),
		},
		{
			// missing email
			usr: user.User{
				DisplayName: "Ed Emailless",
			},
			hashedPW: user.Password("I am a hash"),
			wantErr:  true,
		},
		{
			// connector ID that is not in the fixture connector configs
			usr: user.User{
				DisplayName: "Eric Exampleson",
				Email:       "eric@example.com",
			},
			hashedPW: user.Password("I am a hash"),
			localID:  "abadlocalid",
			wantErr:  true,
		},
	}

	for i, tt := range tests {
		f := makeTestFixtures()
		localID := "local"
		if tt.localID != "" {
			localID = tt.localID
		}
		id, err := f.mgr.CreateUser(tt.usr, tt.hashedPW, localID)
		if tt.wantErr {
			if err == nil {
				t.Errorf("case %d: want non-nil err", i)
			}
			continue
		}

		// Check err before id so a failed create is reported once, with
		// its actual cause.
		if err != nil {
			t.Errorf("case %d: unexpected err: %v", i, err)
			continue
		}
		if id == "" {
			t.Errorf("case %d: want non-empty id", i)
			continue
		}

		gotUsr, err := f.ur.Get(nil, id)
		if err != nil {
			t.Errorf("case %d: unexpected err: %v", i, err)
			// Nothing meaningful to compare against a failed lookup.
			continue
		}

		// The manager is expected to assign the ID and stamp CreatedAt
		// from its clock, which is the fixture clock.
		tt.usr.ID = id
		tt.usr.CreatedAt = f.clock.Now()
		if diff := pretty.Compare(tt.usr, gotUsr); diff != "" {
			t.Errorf("case %d: Compare(want, got) = %v", i, diff)
		}

		pwi, err := f.pwr.Get(nil, id)
		if err != nil {
			t.Errorf("case %d: unexpected err: %v", i, err)
			continue
		}

		if string(pwi.Password) != string(tt.hashedPW) {
			t.Errorf("case %d: want=%q, got=%q", i, tt.hashedPW, pwi.Password)
		}

		// Look the user up on the connector it was created with.
		ridUser, err := f.ur.GetByRemoteIdentity(nil, user.RemoteIdentity{
			ID:          id,
			ConnectorID: localID,
		})
		if err != nil {
			t.Errorf("case %d: err != nil: %q", i, err)
			continue
		}
		if diff := pretty.Compare(gotUsr, ridUser); diff != "" {
			t.Errorf("case %d: Compare(want, got) = %v", i, diff)
		}
	}
}