forked from mystiq/dex
180 lines
3.8 KiB
Go
180 lines
3.8 KiB
Go
package repo
|
|
|
|
import (
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/jonboulle/clockwork"
|
|
"github.com/kylelemons/godebug/pretty"
|
|
|
|
"github.com/coreos/dex/db"
|
|
"github.com/coreos/dex/session"
|
|
)
|
|
|
|
func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) {
|
|
clock := clockwork.NewFakeClock()
|
|
if os.Getenv("DEX_TEST_DSN") == "" {
|
|
return db.NewSessionRepoWithClock(db.NewMemDB(), clock), clock
|
|
}
|
|
dbMap := connect(t)
|
|
return db.NewSessionRepoWithClock(dbMap, clock), clock
|
|
}
|
|
|
|
func newSessionKeyRepo(t *testing.T) (session.SessionKeyRepo, clockwork.FakeClock) {
|
|
clock := clockwork.NewFakeClock()
|
|
if os.Getenv("DEX_TEST_DSN") == "" {
|
|
return db.NewSessionKeyRepoWithClock(db.NewMemDB(), clock), clock
|
|
}
|
|
dbMap := connect(t)
|
|
return db.NewSessionKeyRepoWithClock(dbMap, clock), clock
|
|
}
|
|
|
|
func TestSessionKeyRepoPopNoExist(t *testing.T) {
|
|
r, _ := newSessionKeyRepo(t)
|
|
|
|
_, err := r.Pop("123")
|
|
if err == nil {
|
|
t.Fatalf("Expected error, got nil")
|
|
}
|
|
}
|
|
|
|
func TestSessionKeyRepoPushPop(t *testing.T) {
|
|
r, _ := newSessionKeyRepo(t)
|
|
|
|
key := "123"
|
|
sessionID := "456"
|
|
|
|
r.Push(session.SessionKey{Key: key, SessionID: sessionID}, time.Second)
|
|
|
|
got, err := r.Pop(key)
|
|
if err != nil {
|
|
t.Fatalf("Expected nil error: %v", err)
|
|
}
|
|
|
|
if got != sessionID {
|
|
t.Fatalf("Incorrect sessionID: want=%s got=%s", sessionID, got)
|
|
}
|
|
}
|
|
|
|
func TestSessionKeyRepoExpired(t *testing.T) {
|
|
r, fc := newSessionKeyRepo(t)
|
|
|
|
key := "123"
|
|
sessionID := "456"
|
|
|
|
r.Push(session.SessionKey{Key: key, SessionID: sessionID}, time.Second)
|
|
|
|
fc.Advance(2 * time.Second)
|
|
|
|
_, err := r.Pop(key)
|
|
if err == nil {
|
|
t.Fatalf("Expected error, got nil")
|
|
}
|
|
}
|
|
|
|
func TestSessionRepoGetNoExist(t *testing.T) {
|
|
r, _ := newSessionRepo(t)
|
|
|
|
ses, err := r.Get("123")
|
|
if ses != nil {
|
|
t.Fatalf("Expected nil, got %#v", ses)
|
|
}
|
|
if err == nil {
|
|
t.Fatalf("Expected non-nil error")
|
|
}
|
|
}
|
|
|
|
func TestSessionRepoCreateGet(t *testing.T) {
|
|
tests := []session.Session{
|
|
session.Session{
|
|
ID: "123",
|
|
ClientState: "blargh",
|
|
ExpiresAt: time.Unix(123, 0).UTC(),
|
|
},
|
|
session.Session{
|
|
ID: "456",
|
|
ClientState: "argh",
|
|
ExpiresAt: time.Unix(456, 0).UTC(),
|
|
Register: true,
|
|
},
|
|
session.Session{
|
|
ID: "789",
|
|
ClientState: "blargh",
|
|
ExpiresAt: time.Unix(789, 0).UTC(),
|
|
Nonce: "oncenay",
|
|
},
|
|
session.Session{
|
|
ID: "anID",
|
|
ClientState: "blargh",
|
|
ExpiresAt: time.Unix(789, 0).UTC(),
|
|
Nonce: "oncenay",
|
|
Groups: []string{"group1", "group2"},
|
|
},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
r, _ := newSessionRepo(t)
|
|
|
|
r.Create(tt)
|
|
|
|
ses, _ := r.Get(tt.ID)
|
|
if ses == nil {
|
|
t.Fatalf("case %d: Expected non-nil Session", i)
|
|
}
|
|
|
|
if diff := pretty.Compare(tt, ses); diff != "" {
|
|
t.Errorf("case %d: Compare(want, got) = %v", i, diff)
|
|
}
|
|
|
|
}
|
|
}
|
|
|
|
func TestSessionRepoCreateUpdate(t *testing.T) {
|
|
tests := []struct {
|
|
initial session.Session
|
|
update session.Session
|
|
}{
|
|
{
|
|
initial: session.Session{
|
|
ID: "123",
|
|
ClientState: "blargh",
|
|
ExpiresAt: time.Unix(123, 0).UTC(),
|
|
},
|
|
update: session.Session{
|
|
ID: "123",
|
|
ClientState: "boom",
|
|
ExpiresAt: time.Unix(123, 0).UTC(),
|
|
Register: true,
|
|
},
|
|
},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
r, _ := newSessionRepo(t)
|
|
r.Create(tt.initial)
|
|
|
|
ses, _ := r.Get(tt.initial.ID)
|
|
if diff := pretty.Compare(tt.initial, ses); diff != "" {
|
|
t.Errorf("case %d: Compare(want, got) = %v", i, diff)
|
|
}
|
|
|
|
r.Update(tt.update)
|
|
ses, _ = r.Get(tt.initial.ID)
|
|
if ses == nil {
|
|
t.Fatalf("Expected non-nil Session")
|
|
}
|
|
if diff := pretty.Compare(tt.update, ses); diff != "" {
|
|
t.Errorf("case %d: Compare(want, got) = %v", i, diff)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSessionRepoUpdateNoExist(t *testing.T) {
|
|
r, _ := newSessionRepo(t)
|
|
|
|
err := r.Update(session.Session{ID: "123", ClientState: "boom"})
|
|
if err == nil {
|
|
t.Fatalf("Expected non-nil error")
|
|
}
|
|
}
|