functional: clean up functional tests
Adjust logic and remove panics from functional tests.
This commit is contained in:
parent
bfd63b7514
commit
5052d8007f
9 changed files with 118 additions and 176 deletions
|
@ -28,6 +28,7 @@ script:
|
||||||
- docker cp ${LDAPCONTAINER}:container/service/slapd/assets/certs/ldap.crt /tmp/ldap.crt
|
- docker cp ${LDAPCONTAINER}:container/service/slapd/assets/certs/ldap.crt /tmp/ldap.crt
|
||||||
- sudo sh -c 'echo "127.0.0.1 tlstest.local" >> /etc/hosts'
|
- sudo sh -c 'echo "127.0.0.1 tlstest.local" >> /etc/hosts'
|
||||||
- ./test-functional
|
- ./test-functional
|
||||||
|
- DEX_TEST_DSN="sqlite3://:memory:" ./test-functional
|
||||||
|
|
||||||
deploy:
|
deploy:
|
||||||
provider: script
|
provider: script
|
||||||
|
|
|
@ -36,17 +36,16 @@ func connect(t *testing.T) *gorp.DbMap {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to connect to database: %v", err)
|
t.Fatalf("Unable to connect to database: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = c.DropTablesIfExists(); err != nil {
|
if err = c.DropTablesIfExists(); err != nil {
|
||||||
t.Fatalf("Unable to drop database tables: %v", err)
|
t.Fatalf("Unable to drop database tables: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = db.DropMigrationsTable(c); err != nil {
|
if err = db.DropMigrationsTable(c); err != nil {
|
||||||
panic(fmt.Sprintf("Unable to drop migration table: %v", err))
|
t.Fatalf("Unable to drop migration table: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = db.MigrateToLatest(c); err != nil {
|
if _, err = db.MigrateToLatest(c); err != nil {
|
||||||
panic(fmt.Sprintf("Unable to migrate: %v", err))
|
t.Fatalf("Unable to migrate: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return c
|
return c
|
||||||
|
@ -157,12 +156,13 @@ func TestDBPrivateKeySetRepoSetGet(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
setRepo, err := db.NewPrivateKeySetRepo(connect(t), false, tt.setSecrets...)
|
dbMap := connect(t)
|
||||||
|
setRepo, err := db.NewPrivateKeySetRepo(dbMap, false, tt.setSecrets...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf(err.Error())
|
t.Fatalf(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
getRepo, err := db.NewPrivateKeySetRepo(connect(t), false, tt.getSecrets...)
|
getRepo, err := db.NewPrivateKeySetRepo(dbMap, false, tt.getSecrets...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf(err.Error())
|
t.Fatalf(err.Error())
|
||||||
}
|
}
|
||||||
|
@ -377,9 +377,24 @@ func TestDBRefreshRepoCreate(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
_, err := r.Create(tt.userID, tt.clientID)
|
token, err := r.Create(tt.userID, tt.clientID)
|
||||||
if err != tt.err {
|
if err != nil {
|
||||||
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
|
if tt.err == nil {
|
||||||
|
t.Errorf("case %d: create failed: %v", i, err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if tt.err != nil {
|
||||||
|
t.Errorf("case %d: expected error, didn't get one", i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
userID, err := r.Verify(tt.clientID, token)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("case %d: failed to verify good token: %v", i, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if userID != tt.userID {
|
||||||
|
t.Errorf("case %d: want userID=%s, got userID=%s", i, tt.userID, userID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,9 @@ var connConfigExample = []byte(`[
|
||||||
]`)
|
]`)
|
||||||
|
|
||||||
func TestDexctlCommands(t *testing.T) {
|
func TestDexctlCommands(t *testing.T) {
|
||||||
|
if strings.HasPrefix(dsn, "sqlite3://") {
|
||||||
|
t.Skip("only test dexctl conmand with postgres")
|
||||||
|
}
|
||||||
tempFile, err := ioutil.TempFile("", "dexctl_functional_tests_")
|
tempFile, err := ioutil.TempFile("", "dexctl_functional_tests_")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package repo
|
package repo
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -12,8 +11,6 @@ import (
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
)
|
)
|
||||||
|
|
||||||
var makeTestClientIdentityRepoFromClients func(clients []oidc.ClientIdentity) client.ClientIdentityRepo
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
testClients = []oidc.ClientIdentity{
|
testClients = []oidc.ClientIdentity{
|
||||||
oidc.ClientIdentity{
|
oidc.ClientIdentity{
|
||||||
|
@ -47,34 +44,17 @@ var (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func newClientIdentityRepo(t *testing.T) client.ClientIdentityRepo {
|
||||||
dsn := os.Getenv("DEX_TEST_DSN")
|
dsn := os.Getenv("DEX_TEST_DSN")
|
||||||
if dsn == "" {
|
if dsn == "" {
|
||||||
makeTestClientIdentityRepoFromClients = makeTestClientIdentityRepoMem
|
return client.NewClientIdentityRepo(testClients)
|
||||||
} else {
|
|
||||||
makeTestClientIdentityRepoFromClients = makeTestClientIdentityRepoDB(dsn)
|
|
||||||
}
|
}
|
||||||
}
|
dbMap := connect(t)
|
||||||
|
repo, err := db.NewClientIdentityRepoFromClients(dbMap, testClients)
|
||||||
func makeTestClientIdentityRepoMem(clients []oidc.ClientIdentity) client.ClientIdentityRepo {
|
|
||||||
return client.NewClientIdentityRepo(clients)
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeTestClientIdentityRepoDB(dsn string) func([]oidc.ClientIdentity) client.ClientIdentityRepo {
|
|
||||||
return func(clients []oidc.ClientIdentity) client.ClientIdentityRepo {
|
|
||||||
c := initDB(dsn)
|
|
||||||
|
|
||||||
repo, err := db.NewClientIdentityRepoFromClients(c, clients)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("Unable to add clients: %v", err))
|
t.Fatalf("failed to create client repo from clients: %v", err)
|
||||||
}
|
}
|
||||||
return repo
|
return repo
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeTestClientIdentityRepo() client.ClientIdentityRepo {
|
|
||||||
return makeTestClientIdentityRepoFromClients(testClients)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetSetAdminClient(t *testing.T) {
|
func TestGetSetAdminClient(t *testing.T) {
|
||||||
|
@ -113,12 +93,14 @@ func TestGetSetAdminClient(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Tests:
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
repo := makeTestClientIdentityRepo()
|
repo := newClientIdentityRepo(t)
|
||||||
for _, cid := range startAdmins {
|
for _, cid := range startAdmins {
|
||||||
err := repo.SetDexAdmin(cid, true)
|
err := repo.SetDexAdmin(cid, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
t.Errorf("case %d: failed to set dex admin: %v", i, err)
|
||||||
|
continue Tests
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,7 +112,7 @@ func TestGetSetAdminClient(t *testing.T) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||||
}
|
}
|
||||||
if gotAdmin != tt.wantAdmin {
|
if gotAdmin != tt.wantAdmin {
|
||||||
t.Errorf("case %d: want=%v, got=%v", i, tt.wantAdmin, gotAdmin)
|
t.Errorf("case %d: want=%v, got=%v", i, tt.wantAdmin, gotAdmin)
|
||||||
|
@ -138,12 +120,12 @@ func TestGetSetAdminClient(t *testing.T) {
|
||||||
|
|
||||||
err = repo.SetDexAdmin(tt.cid, tt.setAdmin)
|
err = repo.SetDexAdmin(tt.cid, tt.setAdmin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
gotAdmin, err = repo.IsDexAdmin(tt.cid)
|
gotAdmin, err = repo.IsDexAdmin(tt.cid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||||
}
|
}
|
||||||
if gotAdmin != tt.setAdmin {
|
if gotAdmin != tt.setAdmin {
|
||||||
t.Errorf("case %d: want=%v, got=%v", i, tt.setAdmin, gotAdmin)
|
t.Errorf("case %d: want=%v, got=%v", i, tt.setAdmin, gotAdmin)
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package repo
|
package repo
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -9,28 +8,16 @@ import (
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
)
|
)
|
||||||
|
|
||||||
type connectorConfigRepoFactory func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo
|
func newConnectorConfigRepo(t *testing.T, configs []connector.ConnectorConfig) connector.ConnectorConfigRepo {
|
||||||
|
if os.Getenv("DEX_TEST_DSN") == "" {
|
||||||
var makeTestConnectorConfigRepoFromConfigs connectorConfigRepoFactory
|
return connector.NewConnectorConfigRepoFromConfigs(configs)
|
||||||
|
|
||||||
func init() {
|
|
||||||
if dsn := os.Getenv("DEX_TEST_DSN"); dsn == "" {
|
|
||||||
makeTestConnectorConfigRepoFromConfigs = connector.NewConnectorConfigRepoFromConfigs
|
|
||||||
} else {
|
|
||||||
makeTestConnectorConfigRepoFromConfigs = makeTestConnectorConfigRepoMem(dsn)
|
|
||||||
}
|
}
|
||||||
}
|
dbMap := connect(t)
|
||||||
|
|
||||||
func makeTestConnectorConfigRepoMem(dsn string) connectorConfigRepoFactory {
|
|
||||||
return func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo {
|
|
||||||
dbMap := initDB(dsn)
|
|
||||||
|
|
||||||
repo := db.NewConnectorConfigRepo(dbMap)
|
repo := db.NewConnectorConfigRepo(dbMap)
|
||||||
if err := repo.Set(cfgs); err != nil {
|
if err := repo.Set(configs); err != nil {
|
||||||
panic(fmt.Sprintf("Unable to set connector configs: %v", err))
|
t.Fatalf("Unable to set connector configs: %v", err)
|
||||||
}
|
}
|
||||||
return repo
|
return repo
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnectorConfigRepoGetByID(t *testing.T) {
|
func TestConnectorConfigRepoGetByID(t *testing.T) {
|
||||||
|
@ -63,7 +50,7 @@ func TestConnectorConfigRepoGetByID(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
repo := makeTestConnectorConfigRepoFromConfigs(tt.cfgs)
|
repo := newConnectorConfigRepo(t, tt.cfgs)
|
||||||
if _, err := repo.GetConnectorByID(nil, tt.id); err != tt.err {
|
if _, err := repo.GetConnectorByID(nil, tt.id); err != tt.err {
|
||||||
t.Errorf("case %d: want=%v, got=%v", i, tt.err, err)
|
t.Errorf("case %d: want=%v, got=%v", i, tt.err, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package repo
|
package repo
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -12,8 +11,6 @@ import (
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
)
|
)
|
||||||
|
|
||||||
var makeTestPasswordInfoRepo func() user.PasswordInfoRepo
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
testPWs = []user.PasswordInfo{
|
testPWs = []user.PasswordInfo{
|
||||||
{
|
{
|
||||||
|
@ -23,30 +20,16 @@ var (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func newPasswordInfoRepo(t *testing.T) user.PasswordInfoRepo {
|
||||||
dsn := os.Getenv("DEX_TEST_DSN")
|
if os.Getenv("DEX_TEST_DSN") == "" {
|
||||||
if dsn == "" {
|
|
||||||
makeTestPasswordInfoRepo = makeTestPasswordInfoRepoMem
|
|
||||||
} else {
|
|
||||||
makeTestPasswordInfoRepo = makeTestPasswordInfoRepoDB(dsn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeTestPasswordInfoRepoMem() user.PasswordInfoRepo {
|
|
||||||
return user.NewPasswordInfoRepoFromPasswordInfos(testPWs)
|
return user.NewPasswordInfoRepoFromPasswordInfos(testPWs)
|
||||||
}
|
}
|
||||||
|
dbMap := connect(t)
|
||||||
func makeTestPasswordInfoRepoDB(dsn string) func() user.PasswordInfoRepo {
|
repo := db.NewPasswordInfoRepo(dbMap)
|
||||||
return func() user.PasswordInfoRepo {
|
if err := user.LoadPasswordInfos(repo, testPWs); err != nil {
|
||||||
c := initDB(dsn)
|
t.Fatalf("Unable to add password infos: %v", err)
|
||||||
|
|
||||||
repo := db.NewPasswordInfoRepo(c)
|
|
||||||
err := user.LoadPasswordInfos(repo, testPWs)
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("Unable to add passwordInfos: %v", err))
|
|
||||||
}
|
}
|
||||||
return repo
|
return repo
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreatePasswordInfo(t *testing.T) {
|
func TestCreatePasswordInfo(t *testing.T) {
|
||||||
|
@ -87,7 +70,7 @@ func TestCreatePasswordInfo(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
repo := makeTestPasswordInfoRepo()
|
repo := newPasswordInfoRepo(t)
|
||||||
err := repo.Create(nil, tt.pw)
|
err := repo.Create(nil, tt.pw)
|
||||||
if tt.err != nil {
|
if tt.err != nil {
|
||||||
if err != tt.err {
|
if err != tt.err {
|
||||||
|
@ -142,7 +125,7 @@ func TestUpdatePasswordInfo(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
repo := makeTestPasswordInfoRepo()
|
repo := newPasswordInfoRepo(t)
|
||||||
err := repo.Update(nil, tt.pw)
|
err := repo.Update(nil, tt.pw)
|
||||||
if tt.err != nil {
|
if tt.err != nil {
|
||||||
if err != tt.err {
|
if err != tt.err {
|
||||||
|
|
|
@ -12,48 +12,26 @@ import (
|
||||||
"github.com/coreos/dex/session"
|
"github.com/coreos/dex/session"
|
||||||
)
|
)
|
||||||
|
|
||||||
var makeTestSessionRepo func() (session.SessionRepo, clockwork.FakeClock)
|
func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) {
|
||||||
var makeTestSessionKeyRepo func() (session.SessionKeyRepo, clockwork.FakeClock)
|
clock := clockwork.NewFakeClock()
|
||||||
|
if os.Getenv("DEX_TEST_DSN") == "" {
|
||||||
func init() {
|
return session.NewSessionRepoWithClock(clock), clock
|
||||||
dsn := os.Getenv("DEX_TEST_DSN")
|
|
||||||
if dsn == "" {
|
|
||||||
makeTestSessionRepo = makeTestSessionRepoMem
|
|
||||||
makeTestSessionKeyRepo = makeTestSessionKeyRepoMem
|
|
||||||
} else {
|
|
||||||
makeTestSessionRepo = makeTestSessionRepoDB(dsn)
|
|
||||||
makeTestSessionKeyRepo = makeTestSessionKeyRepoDB(dsn)
|
|
||||||
}
|
}
|
||||||
|
dbMap := connect(t)
|
||||||
|
return db.NewSessionRepoWithClock(dbMap, clock), clock
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeTestSessionRepoMem() (session.SessionRepo, clockwork.FakeClock) {
|
func newSessionKeyRepo(t *testing.T) (session.SessionKeyRepo, clockwork.FakeClock) {
|
||||||
fc := clockwork.NewFakeClock()
|
clock := clockwork.NewFakeClock()
|
||||||
return session.NewSessionRepoWithClock(fc), fc
|
if os.Getenv("DEX_TEST_DSN") == "" {
|
||||||
}
|
return session.NewSessionKeyRepoWithClock(clock), clock
|
||||||
|
|
||||||
func makeTestSessionRepoDB(dsn string) func() (session.SessionRepo, clockwork.FakeClock) {
|
|
||||||
return func() (session.SessionRepo, clockwork.FakeClock) {
|
|
||||||
c := initDB(dsn)
|
|
||||||
fc := clockwork.NewFakeClock()
|
|
||||||
return db.NewSessionRepoWithClock(c, fc), fc
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeTestSessionKeyRepoMem() (session.SessionKeyRepo, clockwork.FakeClock) {
|
|
||||||
fc := clockwork.NewFakeClock()
|
|
||||||
return session.NewSessionKeyRepoWithClock(fc), fc
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeTestSessionKeyRepoDB(dsn string) func() (session.SessionKeyRepo, clockwork.FakeClock) {
|
|
||||||
return func() (session.SessionKeyRepo, clockwork.FakeClock) {
|
|
||||||
c := initDB(dsn)
|
|
||||||
fc := clockwork.NewFakeClock()
|
|
||||||
return db.NewSessionKeyRepoWithClock(c, fc), fc
|
|
||||||
}
|
}
|
||||||
|
dbMap := connect(t)
|
||||||
|
return db.NewSessionKeyRepoWithClock(dbMap, clock), clock
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionKeyRepoPopNoExist(t *testing.T) {
|
func TestSessionKeyRepoPopNoExist(t *testing.T) {
|
||||||
r, _ := makeTestSessionKeyRepo()
|
r, _ := newSessionKeyRepo(t)
|
||||||
|
|
||||||
_, err := r.Pop("123")
|
_, err := r.Pop("123")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -62,7 +40,7 @@ func TestSessionKeyRepoPopNoExist(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionKeyRepoPushPop(t *testing.T) {
|
func TestSessionKeyRepoPushPop(t *testing.T) {
|
||||||
r, _ := makeTestSessionKeyRepo()
|
r, _ := newSessionKeyRepo(t)
|
||||||
|
|
||||||
key := "123"
|
key := "123"
|
||||||
sessionID := "456"
|
sessionID := "456"
|
||||||
|
@ -80,7 +58,7 @@ func TestSessionKeyRepoPushPop(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionKeyRepoExpired(t *testing.T) {
|
func TestSessionKeyRepoExpired(t *testing.T) {
|
||||||
r, fc := makeTestSessionKeyRepo()
|
r, fc := newSessionKeyRepo(t)
|
||||||
|
|
||||||
key := "123"
|
key := "123"
|
||||||
sessionID := "456"
|
sessionID := "456"
|
||||||
|
@ -96,7 +74,7 @@ func TestSessionKeyRepoExpired(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionRepoGetNoExist(t *testing.T) {
|
func TestSessionRepoGetNoExist(t *testing.T) {
|
||||||
r, _ := makeTestSessionRepo()
|
r, _ := newSessionRepo(t)
|
||||||
|
|
||||||
ses, err := r.Get("123")
|
ses, err := r.Get("123")
|
||||||
if ses != nil {
|
if ses != nil {
|
||||||
|
@ -129,7 +107,7 @@ func TestSessionRepoCreateGet(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
r, _ := makeTestSessionRepo()
|
r, _ := newSessionRepo(t)
|
||||||
|
|
||||||
r.Create(tt)
|
r.Create(tt)
|
||||||
|
|
||||||
|
@ -166,7 +144,7 @@ func TestSessionRepoCreateUpdate(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
r, _ := makeTestSessionRepo()
|
r, _ := newSessionRepo(t)
|
||||||
r.Create(tt.initial)
|
r.Create(tt.initial)
|
||||||
|
|
||||||
ses, _ := r.Get(tt.initial.ID)
|
ses, _ := r.Get(tt.initial.ID)
|
||||||
|
@ -186,7 +164,7 @@ func TestSessionRepoCreateUpdate(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionRepoUpdateNoExist(t *testing.T) {
|
func TestSessionRepoUpdateNoExist(t *testing.T) {
|
||||||
r, _ := makeTestSessionRepo()
|
r, _ := newSessionRepo(t)
|
||||||
|
|
||||||
err := r.Update(session.Session{ID: "123", ClientState: "boom"})
|
err := r.Update(session.Session{ID: "123", ClientState: "boom"})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|
|
@ -1,28 +1,38 @@
|
||||||
package repo
|
package repo
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-gorp/gorp"
|
||||||
|
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
"github.com/go-gorp/gorp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func initDB(dsn string) *gorp.DbMap {
|
func connect(t *testing.T) *gorp.DbMap {
|
||||||
|
dsn := os.Getenv("DEX_TEST_DSN")
|
||||||
|
if dsn == "" {
|
||||||
|
t.Fatal("DEX_TEST_DSN environment variable not set")
|
||||||
|
}
|
||||||
c, err := db.NewConnection(db.Config{DSN: dsn})
|
c, err := db.NewConnection(db.Config{DSN: dsn})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("Unable to connect to database: %v", err))
|
t.Fatalf("Unable to connect to database: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = c.DropTablesIfExists(); err != nil {
|
if err = c.DropTablesIfExists(); err != nil {
|
||||||
panic(fmt.Sprintf("Unable to drop database tables: %v", err))
|
t.Fatalf("Unable to drop database tables: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = db.DropMigrationsTable(c); err != nil {
|
if err = db.DropMigrationsTable(c); err != nil {
|
||||||
panic(fmt.Sprintf("Unable to drop migration table: %v", err))
|
t.Fatalf("Unable to drop migration table: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = db.MigrateToLatest(c); err != nil {
|
n, err := db.MigrateToLatest(c)
|
||||||
panic(fmt.Sprintf("Unable to migrate: %v", err))
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to migrate: %v", err)
|
||||||
}
|
}
|
||||||
|
if n == 0 {
|
||||||
|
t.Fatalf("No migrations performed")
|
||||||
|
}
|
||||||
|
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,8 +13,6 @@ import (
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
)
|
)
|
||||||
|
|
||||||
var makeTestUserRepoFromUsers func(users []user.UserWithRemoteIdentities) user.UserRepo
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
testUsers = []user.UserWithRemoteIdentities{
|
testUsers = []user.UserWithRemoteIdentities{
|
||||||
{
|
{
|
||||||
|
@ -47,34 +45,19 @@ var (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func newUserRepo(t *testing.T, users []user.UserWithRemoteIdentities) user.UserRepo {
|
||||||
dsn := os.Getenv("DEX_TEST_DSN")
|
if users == nil {
|
||||||
if dsn == "" {
|
users = []user.UserWithRemoteIdentities{}
|
||||||
makeTestUserRepoFromUsers = makeTestUserRepoMem
|
|
||||||
} else {
|
|
||||||
makeTestUserRepoFromUsers = makeTestUserRepoDB(dsn)
|
|
||||||
}
|
}
|
||||||
}
|
if os.Getenv("DEX_TEST_DSN") == "" {
|
||||||
|
|
||||||
func makeTestUserRepoMem(users []user.UserWithRemoteIdentities) user.UserRepo {
|
|
||||||
return user.NewUserRepoFromUsers(users)
|
return user.NewUserRepoFromUsers(users)
|
||||||
}
|
}
|
||||||
|
dbMap := connect(t)
|
||||||
func makeTestUserRepoDB(dsn string) func([]user.UserWithRemoteIdentities) user.UserRepo {
|
repo, err := db.NewUserRepoFromUsers(dbMap, users)
|
||||||
return func(users []user.UserWithRemoteIdentities) user.UserRepo {
|
|
||||||
c := initDB(dsn)
|
|
||||||
|
|
||||||
repo, err := db.NewUserRepoFromUsers(c, users)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("Unable to add users: %v", err))
|
t.Fatalf("Unable to add users: %v", err)
|
||||||
}
|
}
|
||||||
return repo
|
return repo
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeTestUserRepo() user.UserRepo {
|
|
||||||
return makeTestUserRepoFromUsers(testUsers)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewUser(t *testing.T) {
|
func TestNewUser(t *testing.T) {
|
||||||
|
@ -137,7 +120,7 @@ func TestNewUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
repo := makeTestUserRepo()
|
repo := newUserRepo(t, testUsers)
|
||||||
err := repo.Create(nil, tt.user)
|
err := repo.Create(nil, tt.user)
|
||||||
if tt.err != nil {
|
if tt.err != nil {
|
||||||
if err != tt.err {
|
if err != tt.err {
|
||||||
|
@ -209,7 +192,7 @@ func TestUpdateUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
repo := makeTestUserRepo()
|
repo := newUserRepo(t, testUsers)
|
||||||
err := repo.Update(nil, tt.user)
|
err := repo.Update(nil, tt.user)
|
||||||
if tt.err != nil {
|
if tt.err != nil {
|
||||||
if err != tt.err {
|
if err != tt.err {
|
||||||
|
@ -269,7 +252,7 @@ func TestDisableUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
repo := makeTestUserRepo()
|
repo := newUserRepo(t, testUsers)
|
||||||
err := repo.Disable(nil, tt.id, tt.disable)
|
err := repo.Disable(nil, tt.id, tt.disable)
|
||||||
switch {
|
switch {
|
||||||
case err != tt.err:
|
case err != tt.err:
|
||||||
|
@ -320,7 +303,7 @@ func TestAttachRemoteIdentity(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
repo := makeTestUserRepo()
|
repo := newUserRepo(t, testUsers)
|
||||||
err := repo.AddRemoteIdentity(nil, tt.id, tt.rid)
|
err := repo.AddRemoteIdentity(nil, tt.id, tt.rid)
|
||||||
if tt.err != nil {
|
if tt.err != nil {
|
||||||
if err != tt.err {
|
if err != tt.err {
|
||||||
|
@ -390,7 +373,7 @@ func TestRemoveRemoteIdentity(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
repo := makeTestUserRepo()
|
repo := newUserRepo(t, testUsers)
|
||||||
err := repo.RemoveRemoteIdentity(nil, tt.id, tt.rid)
|
err := repo.RemoveRemoteIdentity(nil, tt.id, tt.rid)
|
||||||
if tt.err != nil {
|
if tt.err != nil {
|
||||||
if err != tt.err {
|
if err != tt.err {
|
||||||
|
@ -502,7 +485,7 @@ func TestGetByEmail(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
repo := makeTestUserRepo()
|
repo := newUserRepo(t, testUsers)
|
||||||
gotUser, gotErr := repo.GetByEmail(nil, tt.email)
|
gotUser, gotErr := repo.GetByEmail(nil, tt.email)
|
||||||
if tt.wantErr != nil {
|
if tt.wantErr != nil {
|
||||||
if tt.wantErr != gotErr {
|
if tt.wantErr != gotErr {
|
||||||
|
@ -566,7 +549,7 @@ func TestGetAdminCount(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
repo := makeTestUserRepo()
|
repo := newUserRepo(t, testUsers)
|
||||||
for _, addUser := range tt.addUsers {
|
for _, addUser := range tt.addUsers {
|
||||||
err := repo.Create(nil, addUser)
|
err := repo.Create(nil, addUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -621,7 +604,7 @@ func TestList(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
repo := makeTestUserRepoFromUsers(repoUsers)
|
repo := newUserRepo(t, repoUsers)
|
||||||
var tok string
|
var tok string
|
||||||
gotIDs := [][]string{}
|
gotIDs := [][]string{}
|
||||||
done := false
|
done := false
|
||||||
|
@ -651,7 +634,7 @@ func TestList(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestListErrorNotFound(t *testing.T) {
|
func TestListErrorNotFound(t *testing.T) {
|
||||||
repo := makeTestUserRepoFromUsers(nil)
|
repo := newUserRepo(t, nil)
|
||||||
_, _, err := repo.List(nil, user.UserFilter{}, 10, "")
|
_, _, err := repo.List(nil, user.UserFilter{}, 10, "")
|
||||||
if err != user.ErrorNotFound {
|
if err != user.ErrorNotFound {
|
||||||
t.Errorf("want=%q, got=%q", user.ErrorNotFound, err)
|
t.Errorf("want=%q, got=%q", user.ErrorNotFound, err)
|
||||||
|
|
Reference in a new issue