forked from mystiq/dex
84bc8073de
All repo tests build their tables by applying all the migrations. This way we know our migrations are functional.
416 lines
8.8 KiB
Go
416 lines
8.8 KiB
Go
package functional
|
|
|
|
import (
|
|
"fmt"
|
|
"net/url"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/coopernurse/gorp"
|
|
"github.com/coreos/go-oidc/key"
|
|
"github.com/coreos/go-oidc/oidc"
|
|
"github.com/kylelemons/godebug/pretty"
|
|
|
|
"github.com/coreos/dex/client"
|
|
"github.com/coreos/dex/db"
|
|
"github.com/coreos/dex/refresh"
|
|
"github.com/coreos/dex/session"
|
|
)
|
|
|
|
// dsn is the data source name of the database the functional tests run
// against. It is populated from the DEX_TEST_DSN environment variable by init.
var dsn string
|
|
|
|
func init() {
|
|
dsn = os.Getenv("DEX_TEST_DSN")
|
|
if dsn == "" {
|
|
fmt.Println("Unable to proceed with empty env var DEX_TEST_DSN")
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
func connect(t *testing.T) *gorp.DbMap {
|
|
c, err := db.NewConnection(db.Config{DSN: dsn})
|
|
if err != nil {
|
|
t.Fatalf("Unable to connect to database: %v", err)
|
|
}
|
|
|
|
if err = c.DropTablesIfExists(); err != nil {
|
|
t.Fatalf("Unable to drop database tables: %v", err)
|
|
}
|
|
|
|
if err = db.DropMigrationsTable(c); err != nil {
|
|
panic(fmt.Sprintf("Unable to drop migration table: %v", err))
|
|
}
|
|
|
|
db.MigrateToLatest(c)
|
|
|
|
return c
|
|
}
|
|
|
|
func TestDBSessionKeyRepoPushPop(t *testing.T) {
|
|
r := db.NewSessionKeyRepo(connect(t))
|
|
|
|
key := "123"
|
|
sessionID := "456"
|
|
|
|
r.Push(session.SessionKey{Key: key, SessionID: sessionID}, time.Second)
|
|
|
|
got, err := r.Pop(key)
|
|
if err != nil {
|
|
t.Fatalf("Expected nil error: %v", err)
|
|
}
|
|
if got != sessionID {
|
|
t.Fatalf("Incorrect sessionID: want=%s got=%s", sessionID, got)
|
|
}
|
|
|
|
// attempting to Pop a second time must fail
|
|
if _, err := r.Pop(key); err == nil {
|
|
t.Fatalf("Second call to Pop succeeded, expected non-nil error")
|
|
}
|
|
}
|
|
|
|
func TestDBSessionRepoCreateUpdate(t *testing.T) {
|
|
r := db.NewSessionRepo(connect(t))
|
|
|
|
// postgres stores its time type with a lower precision
|
|
// than we generate here. Stripping off nanoseconds gives
|
|
// us a predictable value to use in comparisions.
|
|
now := time.Now().Round(time.Second).UTC()
|
|
|
|
ses := session.Session{
|
|
ID: "AAA",
|
|
State: session.SessionStateIdentified,
|
|
CreatedAt: now,
|
|
ExpiresAt: now.Add(time.Minute),
|
|
ClientID: "ZZZ",
|
|
ClientState: "foo",
|
|
RedirectURL: url.URL{
|
|
Scheme: "http",
|
|
Host: "example.com",
|
|
Path: "/callback",
|
|
},
|
|
Identity: oidc.Identity{
|
|
ID: "YYY",
|
|
Name: "Elroy",
|
|
Email: "elroy@example.com",
|
|
ExpiresAt: now.Add(time.Minute),
|
|
},
|
|
}
|
|
|
|
if err := r.Create(ses); err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
got, err := r.Get(ses.ID)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
if diff := pretty.Compare(ses, got); diff != "" {
|
|
t.Fatalf("Retrieved incorrect Session: Compare(want,got): %v", diff)
|
|
}
|
|
}
|
|
|
|
func TestDBPrivateKeySetRepoSetGet(t *testing.T) {
|
|
r, err := db.NewPrivateKeySetRepo(connect(t), "roflroflroflroflroflroflroflrofl")
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
|
|
k1, err := key.GeneratePrivateKey()
|
|
if err != nil {
|
|
t.Fatalf("Unable to generate RSA key: %v", err)
|
|
}
|
|
|
|
k2, err := key.GeneratePrivateKey()
|
|
if err != nil {
|
|
t.Fatalf("Unable to generate RSA key: %v", err)
|
|
}
|
|
|
|
ks := key.NewPrivateKeySet([]*key.PrivateKey{k1, k2}, time.Now().Add(time.Minute))
|
|
if err := r.Set(ks); err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
got, err := r.Get()
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
if diff := pretty.Compare(ks, got); diff != "" {
|
|
t.Fatalf("Retrieved incorrect KeySet: Compare(want,got): %v", diff)
|
|
}
|
|
}
|
|
|
|
func TestDBClientIdentityRepoMetadata(t *testing.T) {
|
|
r := db.NewClientIdentityRepo(connect(t))
|
|
|
|
cm := oidc.ClientMetadata{
|
|
RedirectURLs: []url.URL{
|
|
url.URL{Scheme: "http", Host: "127.0.0.1:5556", Path: "/cb"},
|
|
url.URL{Scheme: "https", Host: "example.com", Path: "/callback"},
|
|
},
|
|
}
|
|
|
|
_, err := r.New("foo", cm)
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
|
|
got, err := r.Metadata("foo")
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
|
|
if diff := pretty.Compare(cm, *got); diff != "" {
|
|
t.Fatalf("Retrieved incorrect ClientMetadata: Compare(want,got): %v", diff)
|
|
}
|
|
}
|
|
|
|
func TestDBClientIdentityRepoMetadataNoExist(t *testing.T) {
|
|
r := db.NewClientIdentityRepo(connect(t))
|
|
|
|
got, err := r.Metadata("noexist")
|
|
if err != client.ErrorNotFound {
|
|
t.Errorf("want==%q, got==%q", client.ErrorNotFound, err)
|
|
}
|
|
if got != nil {
|
|
t.Fatalf("Retrieved incorrect ClientMetadata: want=nil got=%#v", got)
|
|
}
|
|
}
|
|
|
|
func TestDBClientIdentityRepoNewDuplicate(t *testing.T) {
|
|
r := db.NewClientIdentityRepo(connect(t))
|
|
|
|
meta1 := oidc.ClientMetadata{
|
|
RedirectURLs: []url.URL{
|
|
url.URL{Scheme: "http", Host: "foo.example.com"},
|
|
},
|
|
}
|
|
|
|
if _, err := r.New("foo", meta1); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
meta2 := oidc.ClientMetadata{
|
|
RedirectURLs: []url.URL{
|
|
url.URL{Scheme: "http", Host: "bar.example.com"},
|
|
},
|
|
}
|
|
|
|
if _, err := r.New("foo", meta2); err == nil {
|
|
t.Fatalf("expected non-nil error")
|
|
}
|
|
}
|
|
|
|
func TestDBClientIdentityRepoAuthenticate(t *testing.T) {
|
|
r := db.NewClientIdentityRepo(connect(t))
|
|
|
|
cm := oidc.ClientMetadata{
|
|
RedirectURLs: []url.URL{
|
|
url.URL{Scheme: "http", Host: "127.0.0.1:5556", Path: "/cb"},
|
|
},
|
|
}
|
|
|
|
cc, err := r.New("baz", cm)
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
|
|
if cc.ID != "baz" {
|
|
t.Fatalf("Returned ClientCredentials has incorrect ID: want=baz got=%s", cc.ID)
|
|
}
|
|
|
|
ok, err := r.Authenticate(*cc)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
} else if !ok {
|
|
t.Fatalf("Authentication failed for good creds")
|
|
}
|
|
|
|
creds := []oidc.ClientCredentials{
|
|
// completely made up
|
|
oidc.ClientCredentials{ID: "foo", Secret: "bar"},
|
|
|
|
// good client ID, bad secret
|
|
oidc.ClientCredentials{ID: cc.ID, Secret: "bar"},
|
|
|
|
// bad client ID, good secret
|
|
oidc.ClientCredentials{ID: "foo", Secret: cc.Secret},
|
|
|
|
// good client ID, secret with some fluff on the end
|
|
oidc.ClientCredentials{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)},
|
|
}
|
|
for i, c := range creds {
|
|
ok, err := r.Authenticate(c)
|
|
if err != nil {
|
|
t.Errorf("case %d: unexpected error: %v", i, err)
|
|
} else if ok {
|
|
t.Errorf("case %d: authentication succeeded for bad creds", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDBClientIdentityAll(t *testing.T) {
|
|
r := db.NewClientIdentityRepo(connect(t))
|
|
|
|
cm := oidc.ClientMetadata{
|
|
RedirectURLs: []url.URL{
|
|
url.URL{Scheme: "http", Host: "127.0.0.1:5556", Path: "/cb"},
|
|
},
|
|
}
|
|
|
|
_, err := r.New("foo", cm)
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
|
|
got, err := r.All()
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
count := len(got)
|
|
if count != 1 {
|
|
t.Fatalf("Retrieved incorrect number of ClientIdentities: want=1 got=%d", count)
|
|
}
|
|
|
|
if diff := pretty.Compare(cm, got[0].Metadata); diff != "" {
|
|
t.Fatalf("Retrieved incorrect ClientMetadata: Compare(want,got): %v", diff)
|
|
}
|
|
|
|
cm = oidc.ClientMetadata{
|
|
RedirectURLs: []url.URL{
|
|
url.URL{Scheme: "http", Host: "foo.com", Path: "/cb"},
|
|
},
|
|
}
|
|
_, err = r.New("bar", cm)
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
|
|
got, err = r.All()
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
count = len(got)
|
|
if count != 2 {
|
|
t.Fatalf("Retrieved incorrect number of ClientIdentities: want=2 got=%d", count)
|
|
}
|
|
}
|
|
|
|
func TestDBRefreshRepoCreate(t *testing.T) {
|
|
r := db.NewRefreshTokenRepo(connect(t))
|
|
|
|
tests := []struct {
|
|
userID string
|
|
clientID string
|
|
err error
|
|
}{
|
|
{
|
|
"",
|
|
"client-foo",
|
|
refresh.ErrorInvalidUserID,
|
|
},
|
|
{
|
|
"user-foo",
|
|
"",
|
|
refresh.ErrorInvalidClientID,
|
|
},
|
|
{
|
|
"user-foo",
|
|
"client-foo",
|
|
nil,
|
|
},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
_, err := r.Create(tt.userID, tt.clientID)
|
|
if err != tt.err {
|
|
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDBRefreshRepoVerify(t *testing.T) {
|
|
r := db.NewRefreshTokenRepo(connect(t))
|
|
|
|
token, err := r.Create("user-foo", "client-foo")
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
tests := []struct {
|
|
token string
|
|
creds oidc.ClientCredentials
|
|
err error
|
|
expected string
|
|
}{
|
|
{
|
|
"invalid-token-foo",
|
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
refresh.ErrorInvalidToken,
|
|
"",
|
|
},
|
|
{
|
|
token,
|
|
oidc.ClientCredentials{ID: "invalid-client", Secret: "secret-foo"},
|
|
refresh.ErrorInvalidClientID,
|
|
"",
|
|
},
|
|
{
|
|
token,
|
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
nil,
|
|
"user-foo",
|
|
},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
result, err := r.Verify(tt.creds.ID, tt.token)
|
|
if err != tt.err {
|
|
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
|
|
}
|
|
if result != tt.expected {
|
|
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.expected, result)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDBRefreshRepoRevoke(t *testing.T) {
|
|
r := db.NewRefreshTokenRepo(connect(t))
|
|
|
|
token, err := r.Create("user-foo", "client-foo")
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
tests := []struct {
|
|
token string
|
|
userID string
|
|
err error
|
|
}{
|
|
{
|
|
"invalid-token-foo",
|
|
"user-foo",
|
|
refresh.ErrorInvalidToken,
|
|
},
|
|
{
|
|
token,
|
|
"invalid-user",
|
|
refresh.ErrorInvalidUserID,
|
|
},
|
|
{
|
|
token,
|
|
"user-foo",
|
|
nil,
|
|
},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
if err := r.Revoke(tt.userID, tt.token); err != tt.err {
|
|
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
|
|
}
|
|
}
|
|
}
|