154 lines
3.1 KiB
Go
154 lines
3.1 KiB
Go
|
package repo
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"net/url"
|
||
|
"os"
|
||
|
"testing"
|
||
|
|
||
|
"github.com/coreos/go-oidc/oidc"
|
||
|
|
||
|
"github.com/coreos/dex/client"
|
||
|
"github.com/coreos/dex/db"
|
||
|
)
|
||
|
|
||
|
// makeTestClientIdentityRepoFromClients builds a client.ClientIdentityRepo
// seeded with the given clients. It is assigned in init: the
// database-backed constructor when DEX_TEST_DSN is set, otherwise
// makeTestClientIdentityRepoMem.
var makeTestClientIdentityRepoFromClients func([]oidc.ClientIdentity) client.ClientIdentityRepo
|
||
|
|
||
|
var (
|
||
|
testClients = []oidc.ClientIdentity{
|
||
|
oidc.ClientIdentity{
|
||
|
Credentials: oidc.ClientCredentials{
|
||
|
ID: "client1",
|
||
|
Secret: "secret-1",
|
||
|
},
|
||
|
Metadata: oidc.ClientMetadata{
|
||
|
RedirectURLs: []url.URL{
|
||
|
url.URL{
|
||
|
Scheme: "https",
|
||
|
Host: "client1.example.com/callback",
|
||
|
},
|
||
|
},
|
||
|
},
|
||
|
},
|
||
|
oidc.ClientIdentity{
|
||
|
Credentials: oidc.ClientCredentials{
|
||
|
ID: "client2",
|
||
|
Secret: "secret-2",
|
||
|
},
|
||
|
Metadata: oidc.ClientMetadata{
|
||
|
RedirectURLs: []url.URL{
|
||
|
url.URL{
|
||
|
Scheme: "https",
|
||
|
Host: "client2.example.com/callback",
|
||
|
},
|
||
|
},
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
)
|
||
|
|
||
|
func init() {
|
||
|
dsn := os.Getenv("DEX_TEST_DSN")
|
||
|
if dsn == "" {
|
||
|
makeTestClientIdentityRepoFromClients = makeTestClientIdentityRepoMem
|
||
|
} else {
|
||
|
makeTestClientIdentityRepoFromClients = makeTestClientIdentityRepoDB(dsn)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func makeTestClientIdentityRepoMem(clients []oidc.ClientIdentity) client.ClientIdentityRepo {
|
||
|
return client.NewClientIdentityRepo(clients)
|
||
|
}
|
||
|
|
||
|
func makeTestClientIdentityRepoDB(dsn string) func([]oidc.ClientIdentity) client.ClientIdentityRepo {
|
||
|
return func(clients []oidc.ClientIdentity) client.ClientIdentityRepo {
|
||
|
c := initDB(dsn)
|
||
|
|
||
|
repo, err := db.NewClientIdentityRepoFromClients(c, clients)
|
||
|
if err != nil {
|
||
|
panic(fmt.Sprintf("Unable to add clients: %v", err))
|
||
|
}
|
||
|
return repo
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
func makeTestClientIdentityRepo() client.ClientIdentityRepo {
|
||
|
return makeTestClientIdentityRepoFromClients(testClients)
|
||
|
}
|
||
|
|
||
|
func TestGetSetAdminClient(t *testing.T) {
|
||
|
startAdmins := []string{"client2"}
|
||
|
tests := []struct {
|
||
|
// client ID
|
||
|
cid string
|
||
|
|
||
|
// initial state of client
|
||
|
wantAdmin bool
|
||
|
|
||
|
// final state of client
|
||
|
setAdmin bool
|
||
|
|
||
|
wantErr bool
|
||
|
}{
|
||
|
{
|
||
|
cid: "client1",
|
||
|
wantAdmin: false,
|
||
|
setAdmin: true,
|
||
|
},
|
||
|
{
|
||
|
cid: "client1",
|
||
|
wantAdmin: false,
|
||
|
setAdmin: false,
|
||
|
},
|
||
|
{
|
||
|
cid: "client2",
|
||
|
wantAdmin: true,
|
||
|
setAdmin: true,
|
||
|
},
|
||
|
{
|
||
|
cid: "client2",
|
||
|
wantAdmin: true,
|
||
|
setAdmin: false,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for i, tt := range tests {
|
||
|
repo := makeTestClientIdentityRepo()
|
||
|
for _, cid := range startAdmins {
|
||
|
err := repo.SetDexAdmin(cid, true)
|
||
|
if err != nil {
|
||
|
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
gotAdmin, err := repo.IsDexAdmin(tt.cid)
|
||
|
if tt.wantErr {
|
||
|
if err == nil {
|
||
|
t.Errorf("case %d: want non-nil err", i)
|
||
|
}
|
||
|
continue
|
||
|
}
|
||
|
if err != nil {
|
||
|
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||
|
}
|
||
|
if gotAdmin != tt.wantAdmin {
|
||
|
t.Errorf("case %d: want=%v, got=%v", i, tt.wantAdmin, gotAdmin)
|
||
|
}
|
||
|
|
||
|
err = repo.SetDexAdmin(tt.cid, tt.setAdmin)
|
||
|
if err != nil {
|
||
|
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||
|
}
|
||
|
|
||
|
gotAdmin, err = repo.IsDexAdmin(tt.cid)
|
||
|
if err != nil {
|
||
|
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||
|
}
|
||
|
if gotAdmin != tt.setAdmin {
|
||
|
t.Errorf("case %d: want=%v, got=%v", i, tt.setAdmin, gotAdmin)
|
||
|
}
|
||
|
|
||
|
}
|
||
|
}
|