*: remove in memory client repo

The DB implementation expects secrets to be base64 encoded blobs.
Because of this a bunch of tests broke moving to sqlite.

A lot of this commit is fixing those tests.
This commit is contained in:
Eric Chiang 2016-02-09 15:06:07 -08:00
parent 72d1ecab64
commit b572b8dd6c
14 changed files with 231 additions and 435 deletions

View file

@ -1,16 +1,10 @@
package client package client
import ( import (
"encoding/base64"
"encoding/json"
"errors" "errors"
"io"
"io/ioutil"
"net/url" "net/url"
"reflect" "reflect"
"sort"
pcrypto "github.com/coreos/dex/pkg/crypto"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
) )
@ -46,146 +40,6 @@ type ClientIdentityRepo interface {
IsDexAdmin(clientID string) (bool, error) IsDexAdmin(clientID string) (bool, error)
} }
func NewClientIdentityRepo(cs []oidc.ClientIdentity) ClientIdentityRepo {
cr := memClientIdentityRepo{
idents: make(map[string]oidc.ClientIdentity, len(cs)),
admins: make(map[string]bool),
}
for _, c := range cs {
c := c
cr.idents[c.Credentials.ID] = c
}
return &cr
}
type memClientIdentityRepo struct {
idents map[string]oidc.ClientIdentity
admins map[string]bool
}
func (cr *memClientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.ClientCredentials, error) {
if _, ok := cr.idents[id]; ok {
return nil, errors.New("client ID already exists")
}
secret, err := pcrypto.RandBytes(32)
if err != nil {
return nil, err
}
cc := oidc.ClientCredentials{
ID: id,
Secret: base64.URLEncoding.EncodeToString(secret),
}
cr.idents[id] = oidc.ClientIdentity{
Metadata: meta,
Credentials: cc,
}
return &cc, nil
}
func (cr *memClientIdentityRepo) Metadata(clientID string) (*oidc.ClientMetadata, error) {
ci, ok := cr.idents[clientID]
if !ok {
return nil, ErrorNotFound
}
return &ci.Metadata, nil
}
func (cr *memClientIdentityRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) {
ci, ok := cr.idents[creds.ID]
ok = ok && ci.Credentials.Secret == creds.Secret
return ok, nil
}
func (cr *memClientIdentityRepo) All() ([]oidc.ClientIdentity, error) {
cs := make(sortableClientIdentities, 0, len(cr.idents))
for _, ci := range cr.idents {
ci := ci
cs = append(cs, ci)
}
sort.Sort(cs)
return cs, nil
}
func (cr *memClientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error {
cr.admins[clientID] = isAdmin
return nil
}
func (cr *memClientIdentityRepo) IsDexAdmin(clientID string) (bool, error) {
return cr.admins[clientID], nil
}
type sortableClientIdentities []oidc.ClientIdentity
func (s sortableClientIdentities) Len() int {
return len([]oidc.ClientIdentity(s))
}
func (s sortableClientIdentities) Less(i, j int) bool {
return s[i].Credentials.ID < s[j].Credentials.ID
}
func (s sortableClientIdentities) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}
func NewClientIdentityRepoFromReader(r io.Reader) (ClientIdentityRepo, error) {
b, err := ioutil.ReadAll(r)
if err != nil {
return nil, err
}
var cs []clientIdentity
if err = json.Unmarshal(b, &cs); err != nil {
return nil, err
}
ocs := make([]oidc.ClientIdentity, len(cs))
for i, c := range cs {
ocs[i] = oidc.ClientIdentity(c)
}
return NewClientIdentityRepo(ocs), nil
}
type clientIdentity oidc.ClientIdentity
func (ci *clientIdentity) UnmarshalJSON(data []byte) error {
c := struct {
ID string `json:"id"`
Secret string `json:"secret"`
RedirectURLs []string `json:"redirectURLs"`
}{}
if err := json.Unmarshal(data, &c); err != nil {
return err
}
ci.Credentials = oidc.ClientCredentials{
ID: c.ID,
Secret: c.Secret,
}
ci.Metadata = oidc.ClientMetadata{
RedirectURIs: make([]url.URL, len(c.RedirectURLs)),
}
for i, us := range c.RedirectURLs {
up, err := url.Parse(us)
if err != nil {
return err
}
ci.Metadata.RedirectURIs[i] = *up
}
return nil
}
// ValidRedirectURL returns the passed in URL if it is present in the redirectURLs list, and returns an error otherwise. // ValidRedirectURL returns the passed in URL if it is present in the redirectURLs list, and returns an error otherwise.
// If nil is passed in as the rURL and there is only one URL in redirectURLs, // If nil is passed in as the rURL and there is only one URL in redirectURLs,
// that URL will be returned. If nil is passed but theres >1 URL in the slice, // that URL will be returned. If nil is passed but theres >1 URL in the slice,

View file

@ -1,190 +0,0 @@
package client
import (
"encoding/json"
"net/url"
"reflect"
"sort"
"testing"
"github.com/coreos/go-oidc/oidc"
)
func TestMemClientIdentityRepoNew(t *testing.T) {
tests := []struct {
id string
meta oidc.ClientMetadata
}{
{
id: "foo",
meta: oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{
Scheme: "https",
Host: "example.com",
},
},
},
},
{
id: "bar",
meta: oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{Scheme: "https", Host: "example.com/foo"},
url.URL{Scheme: "https", Host: "example.com/bar"},
},
},
},
}
for i, tt := range tests {
cr := NewClientIdentityRepo(nil)
creds, err := cr.New(tt.id, tt.meta)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
}
if creds.ID != tt.id {
t.Errorf("case %d: expected non-empty Client ID", i)
}
if creds.Secret == "" {
t.Errorf("case %d: expected non-empty Client Secret", i)
}
all, err := cr.All()
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
}
if len(all) != 1 {
t.Errorf("case %d: expected repo to contain newly created Client", i)
}
wantURLs := tt.meta.RedirectURIs
gotURLs := all[0].Metadata.RedirectURIs
if !reflect.DeepEqual(wantURLs, gotURLs) {
t.Errorf("case %d: redirect url mismatch, want=%v, got=%v", i, wantURLs, gotURLs)
}
}
}
func TestMemClientIdentityRepoNewDuplicate(t *testing.T) {
cr := NewClientIdentityRepo(nil)
meta1 := oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{Scheme: "https", Host: "foo.example.com"},
},
}
if _, err := cr.New("foo", meta1); err != nil {
t.Errorf("unexpected error: %v", err)
}
meta2 := oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{Scheme: "https", Host: "bar.example.com"},
},
}
if _, err := cr.New("foo", meta2); err == nil {
t.Errorf("expected non-nil error")
}
}
func TestMemClientIdentityRepoAll(t *testing.T) {
tests := []struct {
ids []string
}{
{
ids: nil,
},
{
ids: []string{"foo"},
},
{
ids: []string{"foo", "bar"},
},
}
for i, tt := range tests {
cs := make([]oidc.ClientIdentity, len(tt.ids))
for i, s := range tt.ids {
cs[i] = oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{
ID: s,
},
}
}
cr := NewClientIdentityRepo(cs)
all, err := cr.All()
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
}
want := sortableClientIdentities(cs)
sort.Sort(want)
got := sortableClientIdentities(all)
sort.Sort(got)
if len(got) != len(want) {
t.Errorf("case %d: wrong length: %d", i, len(got))
}
if !reflect.DeepEqual(want, got) {
t.Errorf("case %d: want=%#v, got=%#v", i, want, got)
}
}
}
func TestClientIdentityUnmarshalJSON(t *testing.T) {
for i, test := range []struct {
json string
expectedID string
expectedSecret string
expectedURLs []string
}{
{
json: `{"id":"12345","secret":"rosebud","redirectURLs":["https://redirectone.com", "https://redirecttwo.com"]}`,
expectedID: "12345",
expectedSecret: "rosebud",
expectedURLs: []string{
"https://redirectone.com",
"https://redirecttwo.com",
},
},
} {
var actual clientIdentity
err := json.Unmarshal([]byte(test.json), &actual)
if err != nil {
t.Errorf("case %d: error unmarshalling: %v", i, err)
continue
}
if actual.Credentials.ID != test.expectedID {
t.Errorf("case %d: actual.Credentials.ID == %v, want %v", i, actual.Credentials.ID, test.expectedID)
}
if actual.Credentials.Secret != test.expectedSecret {
t.Errorf("case %d: actual.Credentials.Secret == %v, want %v", i, actual.Credentials.Secret, test.expectedSecret)
}
expectedURLs := test.expectedURLs
sort.Strings(expectedURLs)
actualURLs := make([]string, 0)
for _, u := range actual.Metadata.RedirectURIs {
actualURLs = append(actualURLs, u.String())
}
sort.Strings(actualURLs)
if len(actualURLs) != len(expectedURLs) {
t.Errorf("case %d: len(actualURLs) == %v, want %v", i, len(actualURLs), len(expectedURLs))
}
for ui, actualURL := range actualURLs {
if actualURL != expectedURLs[ui] {
t.Errorf("case %d: actualURLs[%d] == %q, want %q", i, ui, actualURL, expectedURLs[ui])
}
}
}
}

View file

@ -1,11 +1,13 @@
package repo package repo
import ( import (
"encoding/base64"
"net/url" "net/url"
"os" "os"
"testing" "testing"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/go-gorp/gorp"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
@ -16,7 +18,7 @@ var (
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "client1", ID: "client1",
Secret: "secret-1", Secret: base64.URLEncoding.EncodeToString([]byte("secret-1")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -30,7 +32,7 @@ var (
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "client2", ID: "client2",
Secret: "secret-2", Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -46,10 +48,12 @@ var (
func newClientIdentityRepo(t *testing.T) client.ClientIdentityRepo { func newClientIdentityRepo(t *testing.T) client.ClientIdentityRepo {
dsn := os.Getenv("DEX_TEST_DSN") dsn := os.Getenv("DEX_TEST_DSN")
var dbMap *gorp.DbMap
if dsn == "" { if dsn == "" {
return client.NewClientIdentityRepo(testClients) dbMap = db.NewMemDB()
} else {
dbMap = connect(t)
} }
dbMap := connect(t)
repo, err := db.NewClientIdentityRepoFromClients(dbMap, testClients) repo, err := db.NewClientIdentityRepoFromClients(dbMap, testClients)
if err != nil { if err != nil {
t.Fatalf("failed to create client repo from clients: %v", err) t.Fatalf("failed to create client repo from clients: %v", err)

View file

@ -1,7 +1,9 @@
package integration package integration
import ( import (
"encoding/base64"
"net/http" "net/http"
"net/url"
"reflect" "reflect"
"testing" "testing"
@ -13,7 +15,12 @@ func TestClientCreate(t *testing.T) {
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "72de74a9", ID: "72de74a9",
Secret: "XXX", Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
{Scheme: "https://", Host: "authn.example.com", Path: "/callback"},
},
}, },
} }
cis := []oidc.ClientIdentity{ci} cis := []oidc.ClientIdentity{ci}
@ -54,7 +61,7 @@ func TestClientCreate(t *testing.T) {
call := svc.Clients.Create(newClientInput) call := svc.Clients.Create(newClientInput)
newClient, err := call.Do() newClient, err := call.Do()
if err != nil { if err != nil {
t.Errorf("Call to create client API failed: %v", err) t.Fatalf("Call to create client API failed: %v", err)
} }
if newClient.Id == "" { if newClient.Id == "" {

View file

@ -1,6 +1,7 @@
package integration package integration
import ( import (
"encoding/base64"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -21,7 +22,7 @@ var (
testIssuerURL = url.URL{Scheme: "https", Host: "auth.example.com"} testIssuerURL = url.URL{Scheme: "https", Host: "auth.example.com"}
testClientID = "XXX" testClientID = "XXX"
testClientSecret = "yyy" testClientSecret = base64.URLEncoding.EncodeToString([]byte("yyy"))
testRedirectURL = url.URL{Scheme: "https", Host: "client.example.com", Path: "/redirect"} testRedirectURL = url.URL{Scheme: "https", Host: "client.example.com", Path: "/redirect"}
testResetPasswordURL = url.URL{Scheme: "https", Host: "auth.example.com", Path: "/resetPassword"} testResetPasswordURL = url.URL{Scheme: "https", Host: "auth.example.com", Path: "/resetPassword"}
testPrivKey, _ = key.GeneratePrivateKey() testPrivKey, _ = key.GeneratePrivateKey()

View file

@ -1,6 +1,7 @@
package integration package integration
import ( import (
"encoding/base64"
"fmt" "fmt"
"html/template" "html/template"
"net/http" "net/http"
@ -8,7 +9,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/coreos/dex/client"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
phttp "github.com/coreos/dex/pkg/http" phttp "github.com/coreos/dex/pkg/http"
@ -23,6 +23,7 @@ import (
) )
func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) { func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) {
dbMap := db.NewMemDB()
k, err := key.GeneratePrivateKey() k, err := key.GeneratePrivateKey()
if err != nil { if err != nil {
return nil, fmt.Errorf("Unable to generate private key: %v", err) return nil, fmt.Errorf("Unable to generate private key: %v", err)
@ -33,12 +34,16 @@ func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
clientIdentityRepo, err := db.NewClientIdentityRepoFromClients(dbMap, cis)
if err != nil {
return nil, err
}
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) sm := manager.NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap))
srv := &server.Server{ srv := &server.Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km, KeyManager: km,
ClientIdentityRepo: client.NewClientIdentityRepo(cis), ClientIdentityRepo: clientIdentityRepo,
SessionManager: sm, SessionManager: sm,
} }
@ -114,14 +119,18 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "72de74a9", ID: "72de74a9",
Secret: "XXX", Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
}, },
} }
cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) dbMap := db.NewMemDB()
cir, err := db.NewClientIdentityRepoFromClients(dbMap, []oidc.ClientIdentity{ci})
if err != nil {
t.Fatalf("Failed to create client identity repo: " + err.Error())
}
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"} issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) sm := manager.NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap))
k, err := key.GeneratePrivateKey() k, err := key.GeneratePrivateKey()
if err != nil { if err != nil {
@ -253,7 +262,7 @@ func TestHTTPClientCredsToken(t *testing.T) {
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "72de74a9", ID: "72de74a9",
Secret: "XXX", Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
}, },
} }
cis := []oidc.ClientIdentity{ci} cis := []oidc.ClientIdentity{ci}

View file

@ -1,6 +1,7 @@
package integration package integration
import ( import (
"encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -15,6 +16,7 @@ import (
"google.golang.org/api/googleapi" "google.golang.org/api/googleapi"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/db"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/server" "github.com/coreos/dex/server"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
@ -97,7 +99,8 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
_, _, um := makeUserObjects(userUsers, userPasswords) _, _, um := makeUserObjects(userUsers, userPasswords)
cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ cir := func() client.ClientIdentityRepo {
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: testClientID, ID: testClientID,
@ -112,7 +115,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: userBadClientID, ID: userBadClientID,
Secret: "secret", Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -121,6 +124,11 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
}, },
}, },
}) })
if err != nil {
panic("Failed to create client identity repo: " + err.Error())
}
return repo
}()
cir.SetDexAdmin(testClientID, true) cir.SetDexAdmin(testClientID, true)

View file

@ -1,13 +1,16 @@
package server package server
import ( import (
"encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"time" "time"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/db"
"github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
@ -26,9 +29,18 @@ func TestClientToken(t *testing.T) {
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: validClientID, ID: validClientID,
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
{Scheme: "https", Host: "authn.example.com", Path: "/callback"},
},
}, },
} }
repo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
}
privKey, err := key.GeneratePrivateKey() privKey, err := key.GeneratePrivateKey()
if err != nil { if err != nil {
@ -102,7 +114,7 @@ func TestClientToken(t *testing.T) {
// empty repo // empty repo
{ {
keys: []key.PublicKey{pubKey}, keys: []key.PublicKey{pubKey},
repo: client.NewClientIdentityRepo(nil), repo: db.NewClientIdentityRepo(db.NewMemDB()),
header: fmt.Sprintf("BEARER %s", validJWT), header: fmt.Sprintf("BEARER %s", validJWT),
wantCode: http.StatusUnauthorized, wantCode: http.StatusUnauthorized,
}, },

View file

@ -1,6 +1,7 @@
package server package server
import ( import (
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -9,12 +10,14 @@ import (
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"reflect" "reflect"
"sort"
"strings" "strings"
"testing" "testing"
"github.com/coreos/dex/client" "github.com/coreos/dex/db"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/kylelemons/godebug/pretty"
) )
func makeBody(s string) io.ReadCloser { func makeBody(s string) io.ReadCloser {
@ -24,7 +27,7 @@ func makeBody(s string) io.ReadCloser {
func TestCreateInvalidRequest(t *testing.T) { func TestCreateInvalidRequest(t *testing.T) {
u := &url.URL{Scheme: "http", Host: "example.com", Path: "clients"} u := &url.URL{Scheme: "http", Host: "example.com", Path: "clients"}
h := http.Header{"Content-Type": []string{"application/json"}} h := http.Header{"Content-Type": []string{"application/json"}}
repo := client.NewClientIdentityRepo(nil) repo := db.NewClientIdentityRepo(db.NewMemDB())
res := &clientResource{repo: repo} res := &clientResource{repo: repo}
tests := []struct { tests := []struct {
req *http.Request req *http.Request
@ -115,7 +118,7 @@ func TestCreateInvalidRequest(t *testing.T) {
} }
func TestCreate(t *testing.T) { func TestCreate(t *testing.T) {
repo := client.NewClientIdentityRepo(nil) repo := db.NewClientIdentityRepo(db.NewMemDB())
res := &clientResource{repo: repo} res := &clientResource{repo: repo}
tests := [][]string{ tests := [][]string{
[]string{"http://example.com"}, []string{"http://example.com"},
@ -168,6 +171,11 @@ func TestCreate(t *testing.T) {
} }
func TestList(t *testing.T) { func TestList(t *testing.T) {
b64Encode := func(s string) string {
return base64.URLEncoding.EncodeToString([]byte(s))
}
tests := []struct { tests := []struct {
cs []oidc.ClientIdentity cs []oidc.ClientIdentity
want []*schema.Client want []*schema.Client
@ -181,7 +189,7 @@ func TestList(t *testing.T) {
{ {
cs: []oidc.ClientIdentity{ cs: []oidc.ClientIdentity{
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ID: "foo", Secret: "bar"}, Credentials: oidc.ClientCredentials{ID: "foo", Secret: b64Encode("bar")},
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
url.URL{Scheme: "http", Host: "example.com"}, url.URL{Scheme: "http", Host: "example.com"},
@ -200,7 +208,7 @@ func TestList(t *testing.T) {
{ {
cs: []oidc.ClientIdentity{ cs: []oidc.ClientIdentity{
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ID: "foo", Secret: "bar"}, Credentials: oidc.ClientCredentials{ID: "foo", Secret: b64Encode("bar")},
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
url.URL{Scheme: "http", Host: "example.com"}, url.URL{Scheme: "http", Host: "example.com"},
@ -208,7 +216,7 @@ func TestList(t *testing.T) {
}, },
}, },
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ID: "biz", Secret: "bang"}, Credentials: oidc.ClientCredentials{ID: "biz", Secret: b64Encode("bang")},
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
url.URL{Scheme: "https", Host: "example.com", Path: "one/two/three"}, url.URL{Scheme: "https", Host: "example.com", Path: "one/two/three"},
@ -230,7 +238,11 @@ func TestList(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
repo := client.NewClientIdentityRepo(tt.cs) repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), tt.cs)
if err != nil {
t.Errorf("case %d: failed to create client identity repo: %v", i, err)
continue
}
res := &clientResource{repo: repo} res := &clientResource{repo: repo}
r, err := http.NewRequest("GET", "http://example.com/clients", nil) r, err := http.NewRequest("GET", "http://example.com/clients", nil)
@ -248,9 +260,17 @@ func TestList(t *testing.T) {
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Errorf("case %d: unexpected error=%v", i, err) t.Errorf("case %d: unexpected error=%v", i, err)
} }
sort.Sort(byClientId(tt.want))
sort.Sort(byClientId(resp.Clients))
if !reflect.DeepEqual(tt.want, resp.Clients) { if diff := pretty.Compare(tt.want, resp.Clients); diff != "" {
t.Errorf("case %d: invalid response body, want=%#v, got=%#v", i, tt.want, resp.Clients) t.Errorf("case %d: invalid response body: %s", i, diff)
} }
} }
} }
type byClientId []*schema.Client
func (b byClientId) Len() int { return len(b) }
func (b byClientId) Less(i, j int) bool { return b[i].Id < b[j].Id }
func (b byClientId) Swap(i, j int) { b[i], b[j] = b[j], b[i] }

View file

@ -12,9 +12,9 @@ import (
"time" "time"
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
"github.com/coreos/go-oidc/oidc"
"github.com/coreos/pkg/health" "github.com/coreos/pkg/health"
"github.com/coreos/dex/client"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/email" "github.com/coreos/dex/email"
@ -113,10 +113,14 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
return fmt.Errorf("unable to read clients from file %s: %v", cfg.ClientsFile, err) return fmt.Errorf("unable to read clients from file %s: %v", cfg.ClientsFile, err)
} }
defer cf.Close() defer cf.Close()
ciRepo, err := client.NewClientIdentityRepoFromReader(cf) var clients []oidc.ClientIdentity
if err != nil { if err := json.NewDecoder(cf).Decode(&clients); err != nil {
return fmt.Errorf("unable to read client identities from file %s: %v", cfg.ClientsFile, err) return fmt.Errorf("unable to read client identities from file %s: %v", cfg.ClientsFile, err)
} }
ciRepo, err := db.NewClientIdentityRepoFromClients(dbMap, clients)
if err != nil {
return fmt.Errorf("failed to create client identity repo: %v", err)
}
f, err := os.Open(cfg.ConnectorsFile) f, err := os.Open(cfg.ConnectorsFile)
if err != nil { if err != nil {

View file

@ -1,6 +1,7 @@
package server package server
import ( import (
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -77,11 +78,12 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
srv := &Server{ srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())), SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{ ClientIdentityRepo: func() client.ClientIdentityRepo {
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: "XXX",
Secret: "secrete", Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -89,7 +91,12 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
}, },
}, },
}, },
}), })
if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
}
return repo
}(),
} }
tests := []struct { tests := []struct {
@ -200,11 +207,12 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
srv := &Server{ srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())), SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{ ClientIdentityRepo: func() client.ClientIdentityRepo {
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: "XXX",
Secret: "secrete", Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -213,7 +221,12 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
}, },
}, },
}, },
}), })
if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
}
return repo
}(),
} }
tests := []struct { tests := []struct {

View file

@ -21,6 +21,8 @@ import (
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
) )
var clientTestSecret = base64.URLEncoding.EncodeToString([]byte("secrete"))
type StaticKeyManager struct { type StaticKeyManager struct {
key.PrivateKeyManager key.PrivateKeyManager
expiresAt time.Time expiresAt time.Time
@ -180,7 +182,7 @@ func TestServerLogin(t *testing.T) {
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: "XXX",
Secret: "secrete", Secret: clientTestSecret,
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -192,7 +194,13 @@ func TestServerLogin(t *testing.T) {
}, },
}, },
} }
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) ciRepo := func() client.ClientIdentityRepo {
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
}
return repo
}()
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil}, signer: &StaticSigner{sig: []byte("beer"), err: nil},
@ -236,13 +244,20 @@ func TestServerLogin(t *testing.T) {
} }
func TestServerLoginUnrecognizedSessionKey(t *testing.T) { func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ ciRepo := func() client.ClientIdentityRepo {
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", Secret: "secrete", ID: "XXX", Secret: clientTestSecret,
}, },
}, },
}) })
if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
}
return repo
}()
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: &StaticSigner{sig: nil, err: errors.New("fail")}, signer: &StaticSigner{sig: nil, err: errors.New("fail")},
} }
@ -269,7 +284,7 @@ func TestServerLoginDisabledUser(t *testing.T) {
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: "XXX",
Secret: "secrete", Secret: clientTestSecret,
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -281,7 +296,13 @@ func TestServerLoginDisabledUser(t *testing.T) {
}, },
}, },
} }
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) ciRepo := func() client.ClientIdentityRepo {
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
}
return repo
}()
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil}, signer: &StaticSigner{sig: []byte("beer"), err: nil},
@ -337,10 +358,16 @@ func TestServerCodeToken(t *testing.T) {
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: "XXX",
Secret: "secrete", Secret: clientTestSecret,
}, },
} }
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) ciRepo := func() client.ClientIdentityRepo {
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
}
return repo
}()
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil}, signer: &StaticSigner{sig: []byte("beer"), err: nil},
} }
@ -417,10 +444,16 @@ func TestServerTokenUnrecognizedKey(t *testing.T) {
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: "XXX",
Secret: "secrete", Secret: clientTestSecret,
}, },
} }
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) ciRepo := func() client.ClientIdentityRepo {
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
}
return repo
}()
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil}, signer: &StaticSigner{sig: []byte("beer"), err: nil},
} }
@ -460,7 +493,7 @@ func TestServerTokenFail(t *testing.T) {
keyFixture := "goodkey" keyFixture := "goodkey"
ccFixture := oidc.ClientCredentials{ ccFixture := oidc.ClientCredentials{
ID: "XXX", ID: "XXX",
Secret: "secrete", Secret: clientTestSecret,
} }
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
@ -536,9 +569,13 @@ func TestServerTokenFail(t *testing.T) {
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: tt.signer, signer: tt.signer,
} }
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ ciRepo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
oidc.ClientIdentity{Credentials: ccFixture}, oidc.ClientIdentity{Credentials: ccFixture},
}) })
if err != nil {
t.Errorf("case %d: failed to create client identity repo: %v", i, err)
continue
}
_, err = sm.AttachUser(sessionID, "testid-1") _, err = sm.AttachUser(sessionID, "testid-1")
if err != nil { if err != nil {
@ -589,11 +626,11 @@ func TestServerRefreshToken(t *testing.T) {
credXXX := oidc.ClientCredentials{ credXXX := oidc.ClientCredentials{
ID: "XXX", ID: "XXX",
Secret: "secret", Secret: clientTestSecret,
} }
credYYY := oidc.ClientCredentials{ credYYY := oidc.ClientCredentials{
ID: "YYY", ID: "YYY",
Secret: "secret", Secret: clientTestSecret,
} }
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
@ -694,10 +731,14 @@ func TestServerRefreshToken(t *testing.T) {
signer: tt.signer, signer: tt.signer,
} }
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ ciRepo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
oidc.ClientIdentity{Credentials: credXXX}, oidc.ClientIdentity{Credentials: credXXX},
oidc.ClientIdentity{Credentials: credYYY}, oidc.ClientIdentity{Credentials: credYYY},
}) })
if err != nil {
t.Errorf("case %d: failed to create client identity repo: %v", i, err)
continue
}
userRepo, err := makeNewUserRepo() userRepo, err := makeNewUserRepo()
if err != nil { if err != nil {
@ -743,10 +784,13 @@ func TestServerRefreshToken(t *testing.T) {
signer: signerFixture, signer: signerFixture,
} }
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ ciRepo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
oidc.ClientIdentity{Credentials: credXXX}, oidc.ClientIdentity{Credentials: credXXX},
oidc.ClientIdentity{Credentials: credYYY}, oidc.ClientIdentity{Credentials: credYYY},
}) })
if err != nil {
t.Fatalf("failed to create client identity repo: %v", err)
}
userRepo, err := makeNewUserRepo() userRepo, err := makeNewUserRepo()
if err != nil { if err != nil {

View file

@ -1,6 +1,7 @@
package server package server
import ( import (
"encoding/base64"
"fmt" "fmt"
"net/url" "net/url"
"time" "time"
@ -26,7 +27,6 @@ const (
var ( var (
testIssuerURL = url.URL{Scheme: "http", Host: "server.example.com"} testIssuerURL = url.URL{Scheme: "http", Host: "server.example.com"}
testClientID = "XXX" testClientID = "XXX"
testClientSecret = "secrete"
testRedirectURL = url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"} testRedirectURL = url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"}
@ -133,11 +133,11 @@ func makeTestFixtures() (*testFixtures, error) {
return nil, err return nil, err
} }
clientIdentityRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ clientIdentityRepo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: "XXX",
Secret: testClientSecret, Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -146,6 +146,9 @@ func makeTestFixtures() (*testFixtures, error) {
}, },
}, },
}) })
if err != nil {
return nil, err
}
km := key.NewPrivateKeyManager() km := key.NewPrivateKeyManager()
err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey}, time.Now().Add(time.Minute))) err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey}, time.Now().Add(time.Minute)))

View file

@ -1,6 +1,7 @@
package api package api
import ( import (
"encoding/base64"
"net/url" "net/url"
"testing" "testing"
"time" "time"
@ -148,7 +149,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: "XXX",
Secret: "secrete", Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -156,7 +157,13 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
}, },
}, },
} }
cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) cir := func() client.ClientIdentityRepo {
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
if err != nil {
panic("Failed to create client identity repo: " + err.Error())
}
return repo
}()
emailer := &testEmailer{} emailer := &testEmailer{}
api := NewUsersAPI(mgr, cir, emailer, "local") api := NewUsersAPI(mgr, cir, emailer, "local")