542 lines
12 KiB
Go
542 lines
12 KiB
Go
package functional
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"fmt"
|
|
"net/url"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/coreos/go-oidc/key"
|
|
"github.com/coreos/go-oidc/oidc"
|
|
"github.com/go-gorp/gorp"
|
|
"github.com/kylelemons/godebug/pretty"
|
|
|
|
"github.com/coreos/dex/client"
|
|
"github.com/coreos/dex/db"
|
|
"github.com/coreos/dex/refresh"
|
|
"github.com/coreos/dex/session"
|
|
)
|
|
|
|
func connect(t *testing.T) *gorp.DbMap {
|
|
dsn := os.Getenv("DEX_TEST_DSN")
|
|
if dsn == "" {
|
|
t.Fatal("Unable to proceed with empty env var DEX_TEST_DSN")
|
|
}
|
|
c, err := db.NewConnection(db.Config{DSN: dsn})
|
|
if err != nil {
|
|
t.Fatalf("Unable to connect to database: %v", err)
|
|
}
|
|
if err = c.DropTablesIfExists(); err != nil {
|
|
t.Fatalf("Unable to drop database tables: %v", err)
|
|
}
|
|
|
|
if err = db.DropMigrationsTable(c); err != nil {
|
|
t.Fatalf("Unable to drop migration table: %v", err)
|
|
}
|
|
|
|
if _, err = db.MigrateToLatest(c); err != nil {
|
|
t.Fatalf("Unable to migrate: %v", err)
|
|
}
|
|
|
|
return c
|
|
}
|
|
|
|
func TestDBSessionKeyRepoPushPop(t *testing.T) {
|
|
r := db.NewSessionKeyRepo(connect(t))
|
|
|
|
key := "123"
|
|
sessionID := "456"
|
|
|
|
r.Push(session.SessionKey{Key: key, SessionID: sessionID}, time.Second)
|
|
|
|
got, err := r.Pop(key)
|
|
if err != nil {
|
|
t.Fatalf("Expected nil error: %v", err)
|
|
}
|
|
if got != sessionID {
|
|
t.Fatalf("Incorrect sessionID: want=%s got=%s", sessionID, got)
|
|
}
|
|
|
|
// attempting to Pop a second time must fail
|
|
if _, err := r.Pop(key); err == nil {
|
|
t.Fatalf("Second call to Pop succeeded, expected non-nil error")
|
|
}
|
|
}
|
|
|
|
func TestDBSessionRepoCreateUpdate(t *testing.T) {
|
|
r := db.NewSessionRepo(connect(t))
|
|
|
|
// postgres stores its time type with a lower precision
|
|
// than we generate here. Stripping off nanoseconds gives
|
|
// us a predictable value to use in comparisions.
|
|
now := time.Now().Round(time.Second).UTC()
|
|
|
|
ses := session.Session{
|
|
ID: "AAA",
|
|
State: session.SessionStateIdentified,
|
|
CreatedAt: now,
|
|
ExpiresAt: now.Add(time.Minute),
|
|
ClientID: "ZZZ",
|
|
ClientState: "foo",
|
|
RedirectURL: url.URL{
|
|
Scheme: "http",
|
|
Host: "example.com",
|
|
Path: "/callback",
|
|
},
|
|
Identity: oidc.Identity{
|
|
ID: "YYY",
|
|
Name: "Elroy",
|
|
Email: "elroy@example.com",
|
|
ExpiresAt: now.Add(time.Minute),
|
|
},
|
|
}
|
|
|
|
if err := r.Create(ses); err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
got, err := r.Get(ses.ID)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
if diff := pretty.Compare(ses, got); diff != "" {
|
|
t.Fatalf("Retrieved incorrect Session: Compare(want,got): %v", diff)
|
|
}
|
|
}
|
|
|
|
func TestDBPrivateKeySetRepoSetGet(t *testing.T) {
|
|
s1 := []byte("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
|
|
s2 := []byte("oooooooooooooooooooooooooooooooo")
|
|
s3 := []byte("wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww")
|
|
|
|
keys := []*key.PrivateKey{}
|
|
for i := 0; i < 2; i++ {
|
|
k, err := key.GeneratePrivateKey()
|
|
if err != nil {
|
|
t.Fatalf("Unable to generate RSA key: %v", err)
|
|
}
|
|
keys = append(keys, k)
|
|
}
|
|
|
|
ks := key.NewPrivateKeySet(
|
|
[]*key.PrivateKey{keys[0], keys[1]}, time.Now().Add(time.Minute))
|
|
|
|
tests := []struct {
|
|
setSecrets [][]byte
|
|
getSecrets [][]byte
|
|
wantErr bool
|
|
}{
|
|
{
|
|
// same secrets used to encrypt, decrypt
|
|
setSecrets: [][]byte{s1, s2},
|
|
getSecrets: [][]byte{s1, s2},
|
|
},
|
|
{
|
|
// setSecrets got rotated, but getSecrets didn't yet.
|
|
setSecrets: [][]byte{s2, s3},
|
|
getSecrets: [][]byte{s1, s2},
|
|
},
|
|
{
|
|
// getSecrets doesn't have s3
|
|
setSecrets: [][]byte{s3},
|
|
getSecrets: [][]byte{s1, s2},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
dbMap := connect(t)
|
|
setRepo, err := db.NewPrivateKeySetRepo(dbMap, false, tt.setSecrets...)
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
|
|
getRepo, err := db.NewPrivateKeySetRepo(dbMap, false, tt.getSecrets...)
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
|
|
if err := setRepo.Set(ks); err != nil {
|
|
t.Fatalf("case %d: Unexpected error: %v", i, err)
|
|
}
|
|
|
|
got, err := getRepo.Get()
|
|
if tt.wantErr {
|
|
if err == nil {
|
|
t.Errorf("case %d: want err, got nil", i)
|
|
}
|
|
continue
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("case %d: Unexpected error: %v", i, err)
|
|
}
|
|
|
|
if diff := pretty.Compare(ks, got); diff != "" {
|
|
t.Fatalf("case %d:Retrieved incorrect KeySet: Compare(want,got): %v", i, diff)
|
|
}
|
|
|
|
}
|
|
}
|
|
|
|
func TestDBClientIdentityRepoMetadata(t *testing.T) {
|
|
r := db.NewClientIdentityRepo(connect(t))
|
|
|
|
cm := oidc.ClientMetadata{
|
|
RedirectURIs: []url.URL{
|
|
url.URL{Scheme: "http", Host: "127.0.0.1:5556", Path: "/cb"},
|
|
url.URL{Scheme: "https", Host: "example.com", Path: "/callback"},
|
|
},
|
|
}
|
|
|
|
_, err := r.New("foo", cm)
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
|
|
got, err := r.Metadata("foo")
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
|
|
if diff := pretty.Compare(cm, *got); diff != "" {
|
|
t.Fatalf("Retrieved incorrect ClientMetadata: Compare(want,got): %v", diff)
|
|
}
|
|
}
|
|
|
|
func TestDBClientIdentityRepoMetadataNoExist(t *testing.T) {
|
|
r := db.NewClientIdentityRepo(connect(t))
|
|
|
|
got, err := r.Metadata("noexist")
|
|
if err != client.ErrorNotFound {
|
|
t.Errorf("want==%q, got==%q", client.ErrorNotFound, err)
|
|
}
|
|
if got != nil {
|
|
t.Fatalf("Retrieved incorrect ClientMetadata: want=nil got=%#v", got)
|
|
}
|
|
}
|
|
|
|
func TestDBClientIdentityRepoNewDuplicate(t *testing.T) {
|
|
r := db.NewClientIdentityRepo(connect(t))
|
|
|
|
meta1 := oidc.ClientMetadata{
|
|
RedirectURIs: []url.URL{
|
|
url.URL{Scheme: "http", Host: "foo.example.com"},
|
|
},
|
|
}
|
|
|
|
if _, err := r.New("foo", meta1); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
meta2 := oidc.ClientMetadata{
|
|
RedirectURIs: []url.URL{
|
|
url.URL{Scheme: "http", Host: "bar.example.com"},
|
|
},
|
|
}
|
|
|
|
if _, err := r.New("foo", meta2); err == nil {
|
|
t.Fatalf("expected non-nil error")
|
|
}
|
|
}
|
|
|
|
func TestDBClientIdentityRepoAuthenticate(t *testing.T) {
|
|
r := db.NewClientIdentityRepo(connect(t))
|
|
|
|
cm := oidc.ClientMetadata{
|
|
RedirectURIs: []url.URL{
|
|
url.URL{Scheme: "http", Host: "127.0.0.1:5556", Path: "/cb"},
|
|
},
|
|
}
|
|
|
|
cc, err := r.New("baz", cm)
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
|
|
if cc.ID != "baz" {
|
|
t.Fatalf("Returned ClientCredentials has incorrect ID: want=baz got=%s", cc.ID)
|
|
}
|
|
|
|
ok, err := r.Authenticate(*cc)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
} else if !ok {
|
|
t.Fatalf("Authentication failed for good creds")
|
|
}
|
|
|
|
creds := []oidc.ClientCredentials{
|
|
// completely made up
|
|
oidc.ClientCredentials{ID: "foo", Secret: "bar"},
|
|
|
|
// good client ID, bad secret
|
|
oidc.ClientCredentials{ID: cc.ID, Secret: "bar"},
|
|
|
|
// bad client ID, good secret
|
|
oidc.ClientCredentials{ID: "foo", Secret: cc.Secret},
|
|
|
|
// good client ID, secret with some fluff on the end
|
|
oidc.ClientCredentials{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)},
|
|
}
|
|
for i, c := range creds {
|
|
ok, err := r.Authenticate(c)
|
|
if err != nil {
|
|
t.Errorf("case %d: unexpected error: %v", i, err)
|
|
} else if ok {
|
|
t.Errorf("case %d: authentication succeeded for bad creds", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDBClientIdentityAll(t *testing.T) {
|
|
r := db.NewClientIdentityRepo(connect(t))
|
|
|
|
cm := oidc.ClientMetadata{
|
|
RedirectURIs: []url.URL{
|
|
url.URL{Scheme: "http", Host: "127.0.0.1:5556", Path: "/cb"},
|
|
},
|
|
}
|
|
|
|
_, err := r.New("foo", cm)
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
|
|
got, err := r.All()
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
count := len(got)
|
|
if count != 1 {
|
|
t.Fatalf("Retrieved incorrect number of ClientIdentities: want=1 got=%d", count)
|
|
}
|
|
|
|
if diff := pretty.Compare(cm, got[0].Metadata); diff != "" {
|
|
t.Fatalf("Retrieved incorrect ClientMetadata: Compare(want,got): %v", diff)
|
|
}
|
|
|
|
cm = oidc.ClientMetadata{
|
|
RedirectURIs: []url.URL{
|
|
url.URL{Scheme: "http", Host: "foo.com", Path: "/cb"},
|
|
},
|
|
}
|
|
_, err = r.New("bar", cm)
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
|
|
got, err = r.All()
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
count = len(got)
|
|
if count != 2 {
|
|
t.Fatalf("Retrieved incorrect number of ClientIdentities: want=2 got=%d", count)
|
|
}
|
|
}
|
|
|
|
// buildRefreshToken combines the token ID and token payload to create a new token.
|
|
// used in the tests to created a refresh token.
|
|
func buildRefreshToken(tokenID int64, tokenPayload []byte) string {
|
|
return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
|
|
}
|
|
|
|
func TestDBRefreshRepoCreate(t *testing.T) {
|
|
r := db.NewRefreshTokenRepo(connect(t))
|
|
|
|
tests := []struct {
|
|
userID string
|
|
clientID string
|
|
err error
|
|
}{
|
|
{
|
|
"",
|
|
"client-foo",
|
|
refresh.ErrorInvalidUserID,
|
|
},
|
|
{
|
|
"user-foo",
|
|
"",
|
|
refresh.ErrorInvalidClientID,
|
|
},
|
|
{
|
|
"user-foo",
|
|
"client-foo",
|
|
nil,
|
|
},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
token, err := r.Create(tt.userID, tt.clientID)
|
|
if err != nil {
|
|
if tt.err == nil {
|
|
t.Errorf("case %d: create failed: %v", i, err)
|
|
}
|
|
continue
|
|
}
|
|
if tt.err != nil {
|
|
t.Errorf("case %d: expected error, didn't get one", i)
|
|
continue
|
|
}
|
|
userID, err := r.Verify(tt.clientID, token)
|
|
if err != nil {
|
|
t.Errorf("case %d: failed to verify good token: %v", i, err)
|
|
continue
|
|
}
|
|
if userID != tt.userID {
|
|
t.Errorf("case %d: want userID=%s, got userID=%s", i, tt.userID, userID)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDBRefreshRepoVerify(t *testing.T) {
|
|
r := db.NewRefreshTokenRepo(connect(t))
|
|
|
|
token, err := r.Create("user-foo", "client-foo")
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
tokenWithBadID := "404" + token[1:]
|
|
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
|
|
|
|
tests := []struct {
|
|
token string
|
|
creds oidc.ClientCredentials
|
|
err error
|
|
expected string
|
|
}{
|
|
{
|
|
"invalid-token-format",
|
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
refresh.ErrorInvalidToken,
|
|
"",
|
|
},
|
|
{
|
|
"b/invalid-base64-encoded-format",
|
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
refresh.ErrorInvalidToken,
|
|
"",
|
|
},
|
|
{
|
|
"1/invalid-base64-encoded-format",
|
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
refresh.ErrorInvalidToken,
|
|
"",
|
|
},
|
|
{
|
|
token + "corrupted-token-payload",
|
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
refresh.ErrorInvalidToken,
|
|
"",
|
|
},
|
|
{
|
|
// The token's ID content is invalid.
|
|
tokenWithBadID,
|
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
refresh.ErrorInvalidToken,
|
|
"",
|
|
},
|
|
{
|
|
// The token's payload content is invalid.
|
|
tokenWithBadPayload,
|
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
refresh.ErrorInvalidToken,
|
|
"",
|
|
},
|
|
{
|
|
token,
|
|
oidc.ClientCredentials{ID: "invalid-client", Secret: "secret-foo"},
|
|
refresh.ErrorInvalidClientID,
|
|
"",
|
|
},
|
|
{
|
|
token,
|
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
|
|
nil,
|
|
"user-foo",
|
|
},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
result, err := r.Verify(tt.creds.ID, tt.token)
|
|
if err != tt.err {
|
|
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
|
|
}
|
|
if result != tt.expected {
|
|
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.expected, result)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDBRefreshRepoRevoke(t *testing.T) {
|
|
r := db.NewRefreshTokenRepo(connect(t))
|
|
|
|
token, err := r.Create("user-foo", "client-foo")
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
tokenWithBadID := "404" + token[1:]
|
|
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
|
|
|
|
tests := []struct {
|
|
token string
|
|
userID string
|
|
err error
|
|
}{
|
|
{
|
|
"invalid-token-format",
|
|
"user-foo",
|
|
refresh.ErrorInvalidToken,
|
|
},
|
|
{
|
|
"1/invalid-base64-encoded-format",
|
|
"user-foo",
|
|
refresh.ErrorInvalidToken,
|
|
},
|
|
{
|
|
token + "corrupted-token-payload",
|
|
"user-foo",
|
|
refresh.ErrorInvalidToken,
|
|
},
|
|
{
|
|
// The token's ID is invalid.
|
|
tokenWithBadID,
|
|
"user-foo",
|
|
refresh.ErrorInvalidToken,
|
|
},
|
|
{
|
|
// The token's payload is invalid.
|
|
tokenWithBadPayload,
|
|
"user-foo",
|
|
refresh.ErrorInvalidToken,
|
|
},
|
|
{
|
|
token,
|
|
"invalid-user",
|
|
refresh.ErrorInvalidUserID,
|
|
},
|
|
{
|
|
token,
|
|
"user-foo",
|
|
nil,
|
|
},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
if err := r.Revoke(tt.userID, tt.token); err != tt.err {
|
|
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
|
|
}
|
|
}
|
|
}
|