package repo

import (
	"os"
	"testing"
	"time"

	"github.com/jonboulle/clockwork"
	"github.com/kylelemons/godebug/pretty"

	"github.com/coreos/dex/db"
	"github.com/coreos/dex/session"
)

// Constructors for the repositories under test. init assigns either the
// in-memory or the DB-backed variants; every call yields a newly built repo
// together with the fake clock it was constructed with.
var (
	makeTestSessionRepo    func() (session.SessionRepo, clockwork.FakeClock)
	makeTestSessionKeyRepo func() (session.SessionKeyRepo, clockwork.FakeClock)
)

func init() {
	dsn := os.Getenv("DEX_TEST_DSN")
	if dsn == "" {
		makeTestSessionRepo = makeTestSessionRepoMem
		makeTestSessionKeyRepo = makeTestSessionKeyRepoMem
	} else {
		makeTestSessionRepo = makeTestSessionRepoDB(dsn)
		makeTestSessionKeyRepo = makeTestSessionKeyRepoDB(dsn)
	}
}

// makeTestSessionRepoMem builds an in-memory session repo driven by a new
// fake clock and returns both, so tests can advance time explicitly.
func makeTestSessionRepoMem() (session.SessionRepo, clockwork.FakeClock) {
	clock := clockwork.NewFakeClock()
	repo := session.NewSessionRepoWithClock(clock)
	return repo, clock
}

// makeTestSessionRepoDB returns a constructor for DB-backed session repos
// connected via dsn. Each invocation of the returned function calls
// initDB(dsn) (NOTE(review): presumably this prepares a clean test database —
// confirm against initDB) and pairs the resulting repo with a new fake clock.
func makeTestSessionRepoDB(dsn string) func() (session.SessionRepo, clockwork.FakeClock) {
	newRepo := func() (session.SessionRepo, clockwork.FakeClock) {
		conn := initDB(dsn)
		clock := clockwork.NewFakeClock()
		repo := db.NewSessionRepoWithClock(conn, clock)
		return repo, clock
	}
	return newRepo
}

// makeTestSessionKeyRepoMem builds an in-memory session-key repo driven by a
// new fake clock and returns both, so tests can advance time explicitly.
func makeTestSessionKeyRepoMem() (session.SessionKeyRepo, clockwork.FakeClock) {
	clock := clockwork.NewFakeClock()
	repo := session.NewSessionKeyRepoWithClock(clock)
	return repo, clock
}

// makeTestSessionKeyRepoDB returns a constructor for DB-backed session-key
// repos connected via dsn. Each invocation of the returned function calls
// initDB(dsn) (NOTE(review): presumably this prepares a clean test database —
// confirm against initDB) and pairs the resulting repo with a new fake clock.
func makeTestSessionKeyRepoDB(dsn string) func() (session.SessionKeyRepo, clockwork.FakeClock) {
	newRepo := func() (session.SessionKeyRepo, clockwork.FakeClock) {
		conn := initDB(dsn)
		clock := clockwork.NewFakeClock()
		repo := db.NewSessionKeyRepoWithClock(conn, clock)
		return repo, clock
	}
	return newRepo
}

// TestSessionKeyRepoPopNoExist verifies that popping a key that was never
// pushed reports an error.
func TestSessionKeyRepoPopNoExist(t *testing.T) {
	repo, _ := makeTestSessionKeyRepo()

	if _, err := repo.Pop("123"); err == nil {
		t.Fatalf("Expected error, got nil")
	}
}

// TestSessionKeyRepoPushPop verifies that a key pushed with a one-second TTL
// can be popped immediately (the fake clock is never advanced) and resolves
// to the session ID it was pushed with.
func TestSessionKeyRepoPushPop(t *testing.T) {
	r, _ := makeTestSessionKeyRepo()

	key := "123"
	sessionID := "456"

	// Fail fast on a Push error; otherwise it would surface only as a
	// confusing Pop failure below.
	if err := r.Push(session.SessionKey{Key: key, SessionID: sessionID}, time.Second); err != nil {
		t.Fatalf("Unexpected error from Push: %v", err)
	}

	got, err := r.Pop(key)
	if err != nil {
		t.Fatalf("Expected nil error: %v", err)
	}

	if got != sessionID {
		t.Fatalf("Incorrect sessionID: want=%s got=%s", sessionID, got)
	}
}

// TestSessionKeyRepoExpired verifies that a key pushed with a one-second TTL
// can no longer be popped once the fake clock has advanced two seconds.
func TestSessionKeyRepoExpired(t *testing.T) {
	r, fc := makeTestSessionKeyRepo()

	key := "123"
	sessionID := "456"

	// The Push must succeed: if it failed silently, the Pop below would error
	// for the wrong reason and the test would pass without exercising expiry.
	if err := r.Push(session.SessionKey{Key: key, SessionID: sessionID}, time.Second); err != nil {
		t.Fatalf("Unexpected error from Push: %v", err)
	}

	// Move past the one-second TTL.
	fc.Advance(2 * time.Second)

	_, err := r.Pop(key)
	if err == nil {
		t.Fatalf("Expected error, got nil")
	}
}

// TestSessionRepoGetNoExist verifies that looking up an unknown session ID
// yields a nil session together with a non-nil error.
func TestSessionRepoGetNoExist(t *testing.T) {
	repo, _ := makeTestSessionRepo()

	got, err := repo.Get("123")
	switch {
	case got != nil:
		t.Fatalf("Expected nil, got %#v", got)
	case err == nil:
		t.Fatalf("Expected non-nil error")
	}
}

// TestSessionRepoCreateGet verifies that a session stored with Create is
// returned unchanged by a subsequent Get on the same ID. Each case uses a
// newly constructed repo.
func TestSessionRepoCreateGet(t *testing.T) {
	tests := []session.Session{
		{
			ID:          "123",
			ClientState: "blargh",
			ExpiresAt:   time.Unix(123, 0).UTC(),
		},
		{
			ID:          "456",
			ClientState: "argh",
			ExpiresAt:   time.Unix(456, 0).UTC(),
			Register:    true,
		},
		{
			ID:          "789",
			ClientState: "blargh",
			ExpiresAt:   time.Unix(789, 0).UTC(),
			Nonce:       "oncenay",
		},
	}

	for i, tt := range tests {
		r, _ := makeTestSessionRepo()

		// Surface repo errors directly rather than letting them masquerade as
		// a missing or mismatched session below.
		if err := r.Create(tt); err != nil {
			t.Fatalf("case %d: Unexpected error from Create: %v", i, err)
		}

		ses, err := r.Get(tt.ID)
		if err != nil {
			t.Fatalf("case %d: Unexpected error from Get: %v", i, err)
		}
		if ses == nil {
			t.Fatalf("case %d: Expected non-nil Session", i)
		}

		if diff := pretty.Compare(tt, ses); diff != "" {
			t.Errorf("case %d: Compare(want, got) = %v", i, diff)
		}
	}
}

// TestSessionRepoCreateUpdate verifies that after Create, Get returns the
// initial session, and after Update with the same ID, Get returns the
// updated session. Each case uses a newly constructed repo.
func TestSessionRepoCreateUpdate(t *testing.T) {
	tests := []struct {
		initial session.Session
		update  session.Session
	}{
		{
			initial: session.Session{
				ID:          "123",
				ClientState: "blargh",
				ExpiresAt:   time.Unix(123, 0).UTC(),
			},
			update: session.Session{
				ID:          "123",
				ClientState: "boom",
				ExpiresAt:   time.Unix(123, 0).UTC(),
				Register:    true,
			},
		},
	}

	for i, tt := range tests {
		r, _ := makeTestSessionRepo()
		if err := r.Create(tt.initial); err != nil {
			t.Fatalf("case %d: Unexpected error from Create: %v", i, err)
		}

		ses, err := r.Get(tt.initial.ID)
		if err != nil {
			t.Fatalf("case %d: Unexpected error from Get after Create: %v", i, err)
		}
		if diff := pretty.Compare(tt.initial, ses); diff != "" {
			t.Errorf("case %d: Compare(want, got) = %v", i, diff)
		}

		// A failed Update would otherwise only show up as a diff against the
		// stale initial session.
		if err := r.Update(tt.update); err != nil {
			t.Fatalf("case %d: Unexpected error from Update: %v", i, err)
		}

		ses, err = r.Get(tt.initial.ID)
		if err != nil {
			t.Fatalf("case %d: Unexpected error from Get after Update: %v", i, err)
		}
		if ses == nil {
			t.Fatalf("case %d: Expected non-nil Session", i)
		}
		if diff := pretty.Compare(tt.update, ses); diff != "" {
			t.Errorf("case %d: Compare(want, got) = %v", i, diff)
		}
	}
}

// TestSessionRepoUpdateNoExist verifies that updating a session that was
// never created reports an error.
func TestSessionRepoUpdateNoExist(t *testing.T) {
	repo, _ := makeTestSessionRepo()

	missing := session.Session{ID: "123", ClientState: "boom"}
	if err := repo.Update(missing); err == nil {
		t.Fatalf("Expected non-nil error")
	}
}