package repo

import (
	"os"
	"testing"
	"time"

	"github.com/jonboulle/clockwork"
	"github.com/kylelemons/godebug/pretty"

	"github.com/coreos/dex/db"
	"github.com/coreos/dex/session"
)

func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) {
	clock := clockwork.NewFakeClock()
	if os.Getenv("DEX_TEST_DSN") == "" {
		return db.NewSessionRepoWithClock(db.NewMemDB(), clock), clock
	}
	dbMap := connect(t)
	return db.NewSessionRepoWithClock(dbMap, clock), clock
}

func newSessionKeyRepo(t *testing.T) (session.SessionKeyRepo, clockwork.FakeClock) {
	clock := clockwork.NewFakeClock()
	if os.Getenv("DEX_TEST_DSN") == "" {
		return db.NewSessionKeyRepoWithClock(db.NewMemDB(), clock), clock
	}
	dbMap := connect(t)
	return db.NewSessionKeyRepoWithClock(dbMap, clock), clock
}

func TestSessionKeyRepoPopNoExist(t *testing.T) {
	r, _ := newSessionKeyRepo(t)

	_, err := r.Pop("123")
	if err == nil {
		t.Fatalf("Expected error, got nil")
	}
}

func TestSessionKeyRepoPushPop(t *testing.T) {
	r, _ := newSessionKeyRepo(t)

	key := "123"
	sessionID := "456"

	r.Push(session.SessionKey{Key: key, SessionID: sessionID}, time.Second)

	got, err := r.Pop(key)
	if err != nil {
		t.Fatalf("Expected nil error: %v", err)
	}

	if got != sessionID {
		t.Fatalf("Incorrect sessionID: want=%s got=%s", sessionID, got)
	}
}

func TestSessionKeyRepoExpired(t *testing.T) {
	r, fc := newSessionKeyRepo(t)

	key := "123"
	sessionID := "456"

	r.Push(session.SessionKey{Key: key, SessionID: sessionID}, time.Second)

	fc.Advance(2 * time.Second)

	_, err := r.Pop(key)
	if err == nil {
		t.Fatalf("Expected error, got nil")
	}
}

func TestSessionRepoGetNoExist(t *testing.T) {
	r, _ := newSessionRepo(t)

	ses, err := r.Get("123")
	if ses != nil {
		t.Fatalf("Expected nil, got %#v", ses)
	}
	if err == nil {
		t.Fatalf("Expected non-nil error")
	}
}

func TestSessionRepoCreateGet(t *testing.T) {
	tests := []session.Session{
		session.Session{
			ID:          "123",
			ClientState: "blargh",
			ExpiresAt:   time.Unix(123, 0).UTC(),
		},
		session.Session{
			ID:          "456",
			ClientState: "argh",
			ExpiresAt:   time.Unix(456, 0).UTC(),
			Register:    true,
		},
		session.Session{
			ID:          "789",
			ClientState: "blargh",
			ExpiresAt:   time.Unix(789, 0).UTC(),
			Nonce:       "oncenay",
		},
		session.Session{
			ID:          "anID",
			ClientState: "blargh",
			ExpiresAt:   time.Unix(789, 0).UTC(),
			Nonce:       "oncenay",
			Groups:      []string{"group1", "group2"},
		},
	}

	for i, tt := range tests {
		r, _ := newSessionRepo(t)

		r.Create(tt)

		ses, _ := r.Get(tt.ID)
		if ses == nil {
			t.Fatalf("case %d: Expected non-nil Session", i)
		}

		if diff := pretty.Compare(tt, ses); diff != "" {
			t.Errorf("case %d: Compare(want, got) = %v", i, diff)
		}

	}
}

func TestSessionRepoCreateUpdate(t *testing.T) {
	tests := []struct {
		initial session.Session
		update  session.Session
	}{
		{
			initial: session.Session{
				ID:          "123",
				ClientState: "blargh",
				ExpiresAt:   time.Unix(123, 0).UTC(),
			},
			update: session.Session{
				ID:          "123",
				ClientState: "boom",
				ExpiresAt:   time.Unix(123, 0).UTC(),
				Register:    true,
			},
		},
	}

	for i, tt := range tests {
		r, _ := newSessionRepo(t)
		r.Create(tt.initial)

		ses, _ := r.Get(tt.initial.ID)
		if diff := pretty.Compare(tt.initial, ses); diff != "" {
			t.Errorf("case %d: Compare(want, got) = %v", i, diff)
		}

		r.Update(tt.update)
		ses, _ = r.Get(tt.initial.ID)
		if ses == nil {
			t.Fatalf("Expected non-nil Session")
		}
		if diff := pretty.Compare(tt.update, ses); diff != "" {
			t.Errorf("case %d: Compare(want, got) = %v", i, diff)
		}
	}
}

func TestSessionRepoUpdateNoExist(t *testing.T) {
	r, _ := newSessionRepo(t)

	err := r.Update(session.Session{ID: "123", ClientState: "boom"})
	if err == nil {
		t.Fatalf("Expected non-nil error")
	}
}