package functional import ( "fmt" "net/url" "os" "testing" "time" "github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/oidc" "github.com/go-gorp/gorp" "github.com/kylelemons/godebug/pretty" "github.com/coreos/dex/client" "github.com/coreos/dex/db" "github.com/coreos/dex/refresh" "github.com/coreos/dex/session" ) var ( dsn string ) func init() { dsn = os.Getenv("DEX_TEST_DSN") if dsn == "" { fmt.Println("Unable to proceed with empty env var DEX_TEST_DSN") os.Exit(1) } } func connect(t *testing.T) *gorp.DbMap { c, err := db.NewConnection(db.Config{DSN: dsn}) if err != nil { t.Fatalf("Unable to connect to database: %v", err) } if err = c.DropTablesIfExists(); err != nil { t.Fatalf("Unable to drop database tables: %v", err) } if err = db.DropMigrationsTable(c); err != nil { panic(fmt.Sprintf("Unable to drop migration table: %v", err)) } db.MigrateToLatest(c) return c } func TestDBSessionKeyRepoPushPop(t *testing.T) { r := db.NewSessionKeyRepo(connect(t)) key := "123" sessionID := "456" r.Push(session.SessionKey{Key: key, SessionID: sessionID}, time.Second) got, err := r.Pop(key) if err != nil { t.Fatalf("Expected nil error: %v", err) } if got != sessionID { t.Fatalf("Incorrect sessionID: want=%s got=%s", sessionID, got) } // attempting to Pop a second time must fail if _, err := r.Pop(key); err == nil { t.Fatalf("Second call to Pop succeeded, expected non-nil error") } } func TestDBSessionRepoCreateUpdate(t *testing.T) { r := db.NewSessionRepo(connect(t)) // postgres stores its time type with a lower precision // than we generate here. Stripping off nanoseconds gives // us a predictable value to use in comparisions. now := time.Now().Round(time.Second).UTC() ses := session.Session{ ID: "AAA", State: session.SessionStateIdentified, CreatedAt: now, ExpiresAt: now.Add(time.Minute), ClientID: "ZZZ", ClientState: "foo", RedirectURL: url.URL{ Scheme: "http", Host: "example.com", Path: "/callback", }, Identity: oidc.Identity{ ID: "YYY", Name: "Elroy", Email: "elroy@example.com", ExpiresAt: now.Add(time.Minute), }, } if err := r.Create(ses); err != nil { t.Fatalf("Unexpected error: %v", err) } got, err := r.Get(ses.ID) if err != nil { t.Fatalf("Unexpected error: %v", err) } if diff := pretty.Compare(ses, got); diff != "" { t.Fatalf("Retrieved incorrect Session: Compare(want,got): %v", diff) } } func TestDBPrivateKeySetRepoSetGet(t *testing.T) { r, err := db.NewPrivateKeySetRepo(connect(t), "roflroflroflroflroflroflroflrofl") if err != nil { t.Fatalf(err.Error()) } k1, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("Unable to generate RSA key: %v", err) } k2, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("Unable to generate RSA key: %v", err) } ks := key.NewPrivateKeySet([]*key.PrivateKey{k1, k2}, time.Now().Add(time.Minute)) if err := r.Set(ks); err != nil { t.Fatalf("Unexpected error: %v", err) } got, err := r.Get() if err != nil { t.Fatalf("Unexpected error: %v", err) } if diff := pretty.Compare(ks, got); diff != "" { t.Fatalf("Retrieved incorrect KeySet: Compare(want,got): %v", diff) } } func TestDBClientIdentityRepoMetadata(t *testing.T) { r := db.NewClientIdentityRepo(connect(t)) cm := oidc.ClientMetadata{ RedirectURLs: []url.URL{ url.URL{Scheme: "http", Host: "127.0.0.1:5556", Path: "/cb"}, url.URL{Scheme: "https", Host: "example.com", Path: "/callback"}, }, } _, err := r.New("foo", cm) if err != nil { t.Fatalf(err.Error()) } got, err := r.Metadata("foo") if err != nil { t.Fatalf(err.Error()) } if diff := pretty.Compare(cm, *got); diff != "" { t.Fatalf("Retrieved incorrect ClientMetadata: Compare(want,got): %v", diff) } } func TestDBClientIdentityRepoMetadataNoExist(t *testing.T) { r := db.NewClientIdentityRepo(connect(t)) got, err := r.Metadata("noexist") if err != client.ErrorNotFound { t.Errorf("want==%q, got==%q", client.ErrorNotFound, err) } if got != nil { t.Fatalf("Retrieved incorrect ClientMetadata: want=nil got=%#v", got) } } func TestDBClientIdentityRepoNewDuplicate(t *testing.T) { r := db.NewClientIdentityRepo(connect(t)) meta1 := oidc.ClientMetadata{ RedirectURLs: []url.URL{ url.URL{Scheme: "http", Host: "foo.example.com"}, }, } if _, err := r.New("foo", meta1); err != nil { t.Fatalf("unexpected error: %v", err) } meta2 := oidc.ClientMetadata{ RedirectURLs: []url.URL{ url.URL{Scheme: "http", Host: "bar.example.com"}, }, } if _, err := r.New("foo", meta2); err == nil { t.Fatalf("expected non-nil error") } } func TestDBClientIdentityRepoAuthenticate(t *testing.T) { r := db.NewClientIdentityRepo(connect(t)) cm := oidc.ClientMetadata{ RedirectURLs: []url.URL{ url.URL{Scheme: "http", Host: "127.0.0.1:5556", Path: "/cb"}, }, } cc, err := r.New("baz", cm) if err != nil { t.Fatalf(err.Error()) } if cc.ID != "baz" { t.Fatalf("Returned ClientCredentials has incorrect ID: want=baz got=%s", cc.ID) } ok, err := r.Authenticate(*cc) if err != nil { t.Fatalf("Unexpected error: %v", err) } else if !ok { t.Fatalf("Authentication failed for good creds") } creds := []oidc.ClientCredentials{ // completely made up oidc.ClientCredentials{ID: "foo", Secret: "bar"}, // good client ID, bad secret oidc.ClientCredentials{ID: cc.ID, Secret: "bar"}, // bad client ID, good secret oidc.ClientCredentials{ID: "foo", Secret: cc.Secret}, // good client ID, secret with some fluff on the end oidc.ClientCredentials{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)}, } for i, c := range creds { ok, err := r.Authenticate(c) if err != nil { t.Errorf("case %d: unexpected error: %v", i, err) } else if ok { t.Errorf("case %d: authentication succeeded for bad creds", i) } } } func TestDBClientIdentityAll(t *testing.T) { r := db.NewClientIdentityRepo(connect(t)) cm := oidc.ClientMetadata{ RedirectURLs: []url.URL{ url.URL{Scheme: "http", Host: "127.0.0.1:5556", Path: "/cb"}, }, } _, err := r.New("foo", cm) if err != nil { t.Fatalf(err.Error()) } got, err := r.All() if err != nil { t.Fatalf(err.Error()) } count := len(got) if count != 1 { t.Fatalf("Retrieved incorrect number of ClientIdentities: want=1 got=%d", count) } if diff := pretty.Compare(cm, got[0].Metadata); diff != "" { t.Fatalf("Retrieved incorrect ClientMetadata: Compare(want,got): %v", diff) } cm = oidc.ClientMetadata{ RedirectURLs: []url.URL{ url.URL{Scheme: "http", Host: "foo.com", Path: "/cb"}, }, } _, err = r.New("bar", cm) if err != nil { t.Fatalf(err.Error()) } got, err = r.All() if err != nil { t.Fatalf(err.Error()) } count = len(got) if count != 2 { t.Fatalf("Retrieved incorrect number of ClientIdentities: want=2 got=%d", count) } } func TestDBRefreshRepoCreate(t *testing.T) { r := db.NewRefreshTokenRepo(connect(t)) tests := []struct { userID string clientID string err error }{ { "", "client-foo", refresh.ErrorInvalidUserID, }, { "user-foo", "", refresh.ErrorInvalidClientID, }, { "user-foo", "client-foo", nil, }, } for i, tt := range tests { _, err := r.Create(tt.userID, tt.clientID) if err != tt.err { t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err) } } } func TestDBRefreshRepoVerify(t *testing.T) { r := db.NewRefreshTokenRepo(connect(t)) token, err := r.Create("user-foo", "client-foo") if err != nil { t.Fatalf("Unexpected error: %v", err) } tests := []struct { token string creds oidc.ClientCredentials err error expected string }{ { "invalid-token-foo", oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, refresh.ErrorInvalidToken, "", }, { token, oidc.ClientCredentials{ID: "invalid-client", Secret: "secret-foo"}, refresh.ErrorInvalidClientID, "", }, { token, oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, nil, "user-foo", }, } for i, tt := range tests { result, err := r.Verify(tt.creds.ID, tt.token) if err != tt.err { t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err) } if result != tt.expected { t.Errorf("Case #%d: expected: %v, got: %v", i, tt.expected, result) } } } func TestDBRefreshRepoRevoke(t *testing.T) { r := db.NewRefreshTokenRepo(connect(t)) token, err := r.Create("user-foo", "client-foo") if err != nil { t.Fatalf("Unexpected error: %v", err) } tests := []struct { token string userID string err error }{ { "invalid-token-foo", "user-foo", refresh.ErrorInvalidToken, }, { token, "invalid-user", refresh.ErrorInvalidUserID, }, { token, "user-foo", nil, }, } for i, tt := range tests { if err := r.Revoke(tt.userID, tt.token); err != tt.err { t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err) } } }