diff --git a/client/client.go b/client/client.go index ad932404..e8d8b194 100644 --- a/client/client.go +++ b/client/client.go @@ -1,16 +1,10 @@ package client import ( - "encoding/base64" - "encoding/json" "errors" - "io" - "io/ioutil" "net/url" "reflect" - "sort" - pcrypto "github.com/coreos/dex/pkg/crypto" "github.com/coreos/go-oidc/oidc" ) @@ -46,146 +40,6 @@ type ClientIdentityRepo interface { IsDexAdmin(clientID string) (bool, error) } -func NewClientIdentityRepo(cs []oidc.ClientIdentity) ClientIdentityRepo { - cr := memClientIdentityRepo{ - idents: make(map[string]oidc.ClientIdentity, len(cs)), - admins: make(map[string]bool), - } - - for _, c := range cs { - c := c - cr.idents[c.Credentials.ID] = c - } - - return &cr -} - -type memClientIdentityRepo struct { - idents map[string]oidc.ClientIdentity - admins map[string]bool -} - -func (cr *memClientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.ClientCredentials, error) { - if _, ok := cr.idents[id]; ok { - return nil, errors.New("client ID already exists") - } - - secret, err := pcrypto.RandBytes(32) - if err != nil { - return nil, err - } - - cc := oidc.ClientCredentials{ - ID: id, - Secret: base64.URLEncoding.EncodeToString(secret), - } - - cr.idents[id] = oidc.ClientIdentity{ - Metadata: meta, - Credentials: cc, - } - - return &cc, nil -} - -func (cr *memClientIdentityRepo) Metadata(clientID string) (*oidc.ClientMetadata, error) { - ci, ok := cr.idents[clientID] - if !ok { - return nil, ErrorNotFound - } - return &ci.Metadata, nil -} - -func (cr *memClientIdentityRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) { - ci, ok := cr.idents[creds.ID] - ok = ok && ci.Credentials.Secret == creds.Secret - return ok, nil -} - -func (cr *memClientIdentityRepo) All() ([]oidc.ClientIdentity, error) { - cs := make(sortableClientIdentities, 0, len(cr.idents)) - for _, ci := range cr.idents { - ci := ci - cs = append(cs, ci) - } - sort.Sort(cs) - return cs, nil -} - -func (cr *memClientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error { - cr.admins[clientID] = isAdmin - return nil -} - -func (cr *memClientIdentityRepo) IsDexAdmin(clientID string) (bool, error) { - return cr.admins[clientID], nil -} - -type sortableClientIdentities []oidc.ClientIdentity - -func (s sortableClientIdentities) Len() int { - return len([]oidc.ClientIdentity(s)) -} - -func (s sortableClientIdentities) Less(i, j int) bool { - return s[i].Credentials.ID < s[j].Credentials.ID -} - -func (s sortableClientIdentities) Swap(i, j int) { - s[i], s[j] = s[j], s[i] -} - -func NewClientIdentityRepoFromReader(r io.Reader) (ClientIdentityRepo, error) { - b, err := ioutil.ReadAll(r) - if err != nil { - return nil, err - } - - var cs []clientIdentity - if err = json.Unmarshal(b, &cs); err != nil { - return nil, err - } - - ocs := make([]oidc.ClientIdentity, len(cs)) - for i, c := range cs { - ocs[i] = oidc.ClientIdentity(c) - } - - return NewClientIdentityRepo(ocs), nil -} - -type clientIdentity oidc.ClientIdentity - -func (ci *clientIdentity) UnmarshalJSON(data []byte) error { - c := struct { - ID string `json:"id"` - Secret string `json:"secret"` - RedirectURLs []string `json:"redirectURLs"` - }{} - - if err := json.Unmarshal(data, &c); err != nil { - return err - } - - ci.Credentials = oidc.ClientCredentials{ - ID: c.ID, - Secret: c.Secret, - } - ci.Metadata = oidc.ClientMetadata{ - RedirectURIs: make([]url.URL, len(c.RedirectURLs)), - } - - for i, us := range c.RedirectURLs { - up, err := url.Parse(us) - if err != nil { - return err - } - ci.Metadata.RedirectURIs[i] = *up - } - - return nil -} - // ValidRedirectURL returns the passed in URL if it is present in the redirectURLs list, and returns an error otherwise. // If nil is passed in as the rURL and there is only one URL in redirectURLs, // that URL will be returned. If nil is passed but theres >1 URL in the slice, diff --git a/client/client_test.go b/client/client_test.go deleted file mode 100644 index 666e1f05..00000000 --- a/client/client_test.go +++ /dev/null @@ -1,190 +0,0 @@ -package client - -import ( - "encoding/json" - "net/url" - "reflect" - "sort" - "testing" - - "github.com/coreos/go-oidc/oidc" -) - -func TestMemClientIdentityRepoNew(t *testing.T) { - tests := []struct { - id string - meta oidc.ClientMetadata - }{ - { - id: "foo", - meta: oidc.ClientMetadata{ - RedirectURIs: []url.URL{ - url.URL{ - Scheme: "https", - Host: "example.com", - }, - }, - }, - }, - { - id: "bar", - meta: oidc.ClientMetadata{ - RedirectURIs: []url.URL{ - url.URL{Scheme: "https", Host: "example.com/foo"}, - url.URL{Scheme: "https", Host: "example.com/bar"}, - }, - }, - }, - } - - for i, tt := range tests { - cr := NewClientIdentityRepo(nil) - creds, err := cr.New(tt.id, tt.meta) - if err != nil { - t.Errorf("case %d: unexpected error: %v", i, err) - } - - if creds.ID != tt.id { - t.Errorf("case %d: expected non-empty Client ID", i) - } - - if creds.Secret == "" { - t.Errorf("case %d: expected non-empty Client Secret", i) - } - - all, err := cr.All() - if err != nil { - t.Errorf("case %d: unexpected error: %v", i, err) - } - if len(all) != 1 { - t.Errorf("case %d: expected repo to contain newly created Client", i) - } - - wantURLs := tt.meta.RedirectURIs - gotURLs := all[0].Metadata.RedirectURIs - if !reflect.DeepEqual(wantURLs, gotURLs) { - t.Errorf("case %d: redirect url mismatch, want=%v, got=%v", i, wantURLs, gotURLs) - } - } -} - -func TestMemClientIdentityRepoNewDuplicate(t *testing.T) { - cr := NewClientIdentityRepo(nil) - - meta1 := oidc.ClientMetadata{ - RedirectURIs: []url.URL{ - url.URL{Scheme: "https", Host: "foo.example.com"}, - }, - } - - if _, err := cr.New("foo", meta1); err != nil { - t.Errorf("unexpected error: %v", err) - } - - meta2 := oidc.ClientMetadata{ - RedirectURIs: []url.URL{ - url.URL{Scheme: "https", Host: "bar.example.com"}, - }, - } - - if _, err := cr.New("foo", meta2); err == nil { - t.Errorf("expected non-nil error") - } -} - -func TestMemClientIdentityRepoAll(t *testing.T) { - tests := []struct { - ids []string - }{ - { - ids: nil, - }, - { - ids: []string{"foo"}, - }, - { - ids: []string{"foo", "bar"}, - }, - } - - for i, tt := range tests { - cs := make([]oidc.ClientIdentity, len(tt.ids)) - for i, s := range tt.ids { - cs[i] = oidc.ClientIdentity{ - Credentials: oidc.ClientCredentials{ - ID: s, - }, - } - } - - cr := NewClientIdentityRepo(cs) - - all, err := cr.All() - if err != nil { - t.Errorf("case %d: unexpected error: %v", i, err) - } - - want := sortableClientIdentities(cs) - sort.Sort(want) - got := sortableClientIdentities(all) - sort.Sort(got) - - if len(got) != len(want) { - t.Errorf("case %d: wrong length: %d", i, len(got)) - } - - if !reflect.DeepEqual(want, got) { - t.Errorf("case %d: want=%#v, got=%#v", i, want, got) - } - } -} - -func TestClientIdentityUnmarshalJSON(t *testing.T) { - for i, test := range []struct { - json string - expectedID string - expectedSecret string - expectedURLs []string - }{ - { - json: `{"id":"12345","secret":"rosebud","redirectURLs":["https://redirectone.com", "https://redirecttwo.com"]}`, - expectedID: "12345", - expectedSecret: "rosebud", - expectedURLs: []string{ - "https://redirectone.com", - "https://redirecttwo.com", - }, - }, - } { - var actual clientIdentity - err := json.Unmarshal([]byte(test.json), &actual) - if err != nil { - t.Errorf("case %d: error unmarshalling: %v", i, err) - continue - } - - if actual.Credentials.ID != test.expectedID { - t.Errorf("case %d: actual.Credentials.ID == %v, want %v", i, actual.Credentials.ID, test.expectedID) - } - - if actual.Credentials.Secret != test.expectedSecret { - t.Errorf("case %d: actual.Credentials.Secret == %v, want %v", i, actual.Credentials.Secret, test.expectedSecret) - } - expectedURLs := test.expectedURLs - sort.Strings(expectedURLs) - - actualURLs := make([]string, 0) - for _, u := range actual.Metadata.RedirectURIs { - actualURLs = append(actualURLs, u.String()) - } - sort.Strings(actualURLs) - if len(actualURLs) != len(expectedURLs) { - t.Errorf("case %d: len(actualURLs) == %v, want %v", i, len(actualURLs), len(expectedURLs)) - } - for ui, actualURL := range actualURLs { - if actualURL != expectedURLs[ui] { - t.Errorf("case %d: actualURLs[%d] == %q, want %q", i, ui, actualURL, expectedURLs[ui]) - } - } - } -} diff --git a/functional/repo/client_repo_test.go b/functional/repo/client_repo_test.go index 0f77fd39..6c580d68 100644 --- a/functional/repo/client_repo_test.go +++ b/functional/repo/client_repo_test.go @@ -1,11 +1,13 @@ package repo import ( + "encoding/base64" "net/url" "os" "testing" "github.com/coreos/go-oidc/oidc" + "github.com/go-gorp/gorp" "github.com/coreos/dex/client" "github.com/coreos/dex/db" @@ -16,7 +18,7 @@ var ( oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "client1", - Secret: "secret-1", + Secret: base64.URLEncoding.EncodeToString([]byte("secret-1")), }, Metadata: oidc.ClientMetadata{ RedirectURIs: []url.URL{ @@ -30,7 +32,7 @@ var ( oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "client2", - Secret: "secret-2", + Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")), }, Metadata: oidc.ClientMetadata{ RedirectURIs: []url.URL{ @@ -46,10 +48,12 @@ var ( func newClientIdentityRepo(t *testing.T) client.ClientIdentityRepo { dsn := os.Getenv("DEX_TEST_DSN") + var dbMap *gorp.DbMap if dsn == "" { - return client.NewClientIdentityRepo(testClients) + dbMap = db.NewMemDB() + } else { + dbMap = connect(t) } - dbMap := connect(t) repo, err := db.NewClientIdentityRepoFromClients(dbMap, testClients) if err != nil { t.Fatalf("failed to create client repo from clients: %v", err) diff --git a/integration/client_api_test.go b/integration/client_api_test.go index 46a9b376..3b80eb33 100644 --- a/integration/client_api_test.go +++ b/integration/client_api_test.go @@ -1,7 +1,9 @@ package integration import ( + "encoding/base64" "net/http" + "net/url" "reflect" "testing" @@ -13,7 +15,12 @@ func TestClientCreate(t *testing.T) { ci := oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "72de74a9", - Secret: "XXX", + Secret: base64.URLEncoding.EncodeToString([]byte("XXX")), + }, + Metadata: oidc.ClientMetadata{ + RedirectURIs: []url.URL{ + {Scheme: "https://", Host: "authn.example.com", Path: "/callback"}, + }, }, } cis := []oidc.ClientIdentity{ci} @@ -54,7 +61,7 @@ func TestClientCreate(t *testing.T) { call := svc.Clients.Create(newClientInput) newClient, err := call.Do() if err != nil { - t.Errorf("Call to create client API failed: %v", err) + t.Fatalf("Call to create client API failed: %v", err) } if newClient.Id == "" { diff --git a/integration/common_test.go b/integration/common_test.go index fc8b0d04..eed147ef 100644 --- a/integration/common_test.go +++ b/integration/common_test.go @@ -1,6 +1,7 @@ package integration import ( + "encoding/base64" "fmt" "io/ioutil" "net/http" @@ -21,7 +22,7 @@ var ( testIssuerURL = url.URL{Scheme: "https", Host: "auth.example.com"} testClientID = "XXX" - testClientSecret = "yyy" + testClientSecret = base64.URLEncoding.EncodeToString([]byte("yyy")) testRedirectURL = url.URL{Scheme: "https", Host: "client.example.com", Path: "/redirect"} testResetPasswordURL = url.URL{Scheme: "https", Host: "auth.example.com", Path: "/resetPassword"} testPrivKey, _ = key.GeneratePrivateKey() diff --git a/integration/oidc_test.go b/integration/oidc_test.go index cacdd526..8c092953 100644 --- a/integration/oidc_test.go +++ b/integration/oidc_test.go @@ -1,6 +1,7 @@ package integration import ( + "encoding/base64" "fmt" "html/template" "net/http" @@ -8,7 +9,6 @@ import ( "testing" "time" - "github.com/coreos/dex/client" "github.com/coreos/dex/connector" "github.com/coreos/dex/db" phttp "github.com/coreos/dex/pkg/http" @@ -23,6 +23,7 @@ import ( ) func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) { + dbMap := db.NewMemDB() k, err := key.GeneratePrivateKey() if err != nil { return nil, fmt.Errorf("Unable to generate private key: %v", err) @@ -33,12 +34,16 @@ func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) { if err != nil { return nil, err } + clientIdentityRepo, err := db.NewClientIdentityRepoFromClients(dbMap, cis) + if err != nil { + return nil, err + } - sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) + sm := manager.NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap)) srv := &server.Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, KeyManager: km, - ClientIdentityRepo: client.NewClientIdentityRepo(cis), + ClientIdentityRepo: clientIdentityRepo, SessionManager: sm, } @@ -114,14 +119,18 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) { ci := oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "72de74a9", - Secret: "XXX", + Secret: base64.URLEncoding.EncodeToString([]byte("XXX")), }, } - cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) + dbMap := db.NewMemDB() + cir, err := db.NewClientIdentityRepoFromClients(dbMap, []oidc.ClientIdentity{ci}) + if err != nil { + t.Fatalf("Failed to create client identity repo: " + err.Error()) + } issuerURL := url.URL{Scheme: "http", Host: "server.example.com"} - sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) + sm := manager.NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap)) k, err := key.GeneratePrivateKey() if err != nil { @@ -253,7 +262,7 @@ func TestHTTPClientCredsToken(t *testing.T) { ci := oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "72de74a9", - Secret: "XXX", + Secret: base64.URLEncoding.EncodeToString([]byte("XXX")), }, } cis := []oidc.ClientIdentity{ci} diff --git a/integration/user_api_test.go b/integration/user_api_test.go index 6e7eaeaf..9293e3ec 100644 --- a/integration/user_api_test.go +++ b/integration/user_api_test.go @@ -1,6 +1,7 @@ package integration import ( + "encoding/base64" "fmt" "net/http" "net/http/httptest" @@ -15,6 +16,7 @@ import ( "google.golang.org/api/googleapi" "github.com/coreos/dex/client" + "github.com/coreos/dex/db" schema "github.com/coreos/dex/schema/workerschema" "github.com/coreos/dex/server" "github.com/coreos/dex/user" @@ -97,30 +99,36 @@ func makeUserAPITestFixtures() *userAPITestFixtures { _, _, um := makeUserObjects(userUsers, userPasswords) - cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ - oidc.ClientIdentity{ - Credentials: oidc.ClientCredentials{ - ID: testClientID, - Secret: testClientSecret, - }, - Metadata: oidc.ClientMetadata{ - RedirectURIs: []url.URL{ - testRedirectURL, + cir := func() client.ClientIdentityRepo { + repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ + oidc.ClientIdentity{ + Credentials: oidc.ClientCredentials{ + ID: testClientID, + Secret: testClientSecret, + }, + Metadata: oidc.ClientMetadata{ + RedirectURIs: []url.URL{ + testRedirectURL, + }, }, }, - }, - oidc.ClientIdentity{ - Credentials: oidc.ClientCredentials{ - ID: userBadClientID, - Secret: "secret", - }, - Metadata: oidc.ClientMetadata{ - RedirectURIs: []url.URL{ - testRedirectURL, + oidc.ClientIdentity{ + Credentials: oidc.ClientCredentials{ + ID: userBadClientID, + Secret: base64.URLEncoding.EncodeToString([]byte("secret")), + }, + Metadata: oidc.ClientMetadata{ + RedirectURIs: []url.URL{ + testRedirectURL, + }, }, }, - }, - }) + }) + if err != nil { + panic("Failed to create client identity repo: " + err.Error()) + } + return repo + }() cir.SetDexAdmin(testClientID, true) diff --git a/server/auth_middleware_test.go b/server/auth_middleware_test.go index 693f971e..0b0f1b2a 100644 --- a/server/auth_middleware_test.go +++ b/server/auth_middleware_test.go @@ -1,13 +1,16 @@ package server import ( + "encoding/base64" "fmt" "net/http" "net/http/httptest" + "net/url" "testing" "time" "github.com/coreos/dex/client" + "github.com/coreos/dex/db" "github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/oidc" @@ -25,10 +28,19 @@ func TestClientToken(t *testing.T) { validClientID := "valid-client" ci := oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ - ID: validClientID, + ID: validClientID, + Secret: base64.URLEncoding.EncodeToString([]byte("secret")), + }, + Metadata: oidc.ClientMetadata{ + RedirectURIs: []url.URL{ + {Scheme: "https", Host: "authn.example.com", Path: "/callback"}, + }, }, } - repo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) + repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci}) + if err != nil { + t.Fatalf("Failed to create client identity repo: %v", err) + } privKey, err := key.GeneratePrivateKey() if err != nil { @@ -102,7 +114,7 @@ func TestClientToken(t *testing.T) { // empty repo { keys: []key.PublicKey{pubKey}, - repo: client.NewClientIdentityRepo(nil), + repo: db.NewClientIdentityRepo(db.NewMemDB()), header: fmt.Sprintf("BEARER %s", validJWT), wantCode: http.StatusUnauthorized, }, diff --git a/server/client_resource_test.go b/server/client_resource_test.go index 9557966c..ab527120 100644 --- a/server/client_resource_test.go +++ b/server/client_resource_test.go @@ -1,6 +1,7 @@ package server import ( + "encoding/base64" "encoding/json" "fmt" "io" @@ -9,12 +10,14 @@ import ( "net/http/httptest" "net/url" "reflect" + "sort" "strings" "testing" - "github.com/coreos/dex/client" + "github.com/coreos/dex/db" schema "github.com/coreos/dex/schema/workerschema" "github.com/coreos/go-oidc/oidc" + "github.com/kylelemons/godebug/pretty" ) func makeBody(s string) io.ReadCloser { @@ -24,7 +27,7 @@ func makeBody(s string) io.ReadCloser { func TestCreateInvalidRequest(t *testing.T) { u := &url.URL{Scheme: "http", Host: "example.com", Path: "clients"} h := http.Header{"Content-Type": []string{"application/json"}} - repo := client.NewClientIdentityRepo(nil) + repo := db.NewClientIdentityRepo(db.NewMemDB()) res := &clientResource{repo: repo} tests := []struct { req *http.Request @@ -115,7 +118,7 @@ func TestCreateInvalidRequest(t *testing.T) { } func TestCreate(t *testing.T) { - repo := client.NewClientIdentityRepo(nil) + repo := db.NewClientIdentityRepo(db.NewMemDB()) res := &clientResource{repo: repo} tests := [][]string{ []string{"http://example.com"}, @@ -168,6 +171,11 @@ func TestCreate(t *testing.T) { } func TestList(t *testing.T) { + + b64Encode := func(s string) string { + return base64.URLEncoding.EncodeToString([]byte(s)) + } + tests := []struct { cs []oidc.ClientIdentity want []*schema.Client @@ -181,7 +189,7 @@ func TestList(t *testing.T) { { cs: []oidc.ClientIdentity{ oidc.ClientIdentity{ - Credentials: oidc.ClientCredentials{ID: "foo", Secret: "bar"}, + Credentials: oidc.ClientCredentials{ID: "foo", Secret: b64Encode("bar")}, Metadata: oidc.ClientMetadata{ RedirectURIs: []url.URL{ url.URL{Scheme: "http", Host: "example.com"}, @@ -200,7 +208,7 @@ func TestList(t *testing.T) { { cs: []oidc.ClientIdentity{ oidc.ClientIdentity{ - Credentials: oidc.ClientCredentials{ID: "foo", Secret: "bar"}, + Credentials: oidc.ClientCredentials{ID: "foo", Secret: b64Encode("bar")}, Metadata: oidc.ClientMetadata{ RedirectURIs: []url.URL{ url.URL{Scheme: "http", Host: "example.com"}, @@ -208,7 +216,7 @@ func TestList(t *testing.T) { }, }, oidc.ClientIdentity{ - Credentials: oidc.ClientCredentials{ID: "biz", Secret: "bang"}, + Credentials: oidc.ClientCredentials{ID: "biz", Secret: b64Encode("bang")}, Metadata: oidc.ClientMetadata{ RedirectURIs: []url.URL{ url.URL{Scheme: "https", Host: "example.com", Path: "one/two/three"}, @@ -230,7 +238,11 @@ func TestList(t *testing.T) { } for i, tt := range tests { - repo := client.NewClientIdentityRepo(tt.cs) + repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), tt.cs) + if err != nil { + t.Errorf("case %d: failed to create client identity repo: %v", i, err) + continue + } res := &clientResource{repo: repo} r, err := http.NewRequest("GET", "http://example.com/clients", nil) @@ -248,9 +260,17 @@ func TestList(t *testing.T) { if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { t.Errorf("case %d: unexpected error=%v", i, err) } + sort.Sort(byClientId(tt.want)) + sort.Sort(byClientId(resp.Clients)) - if !reflect.DeepEqual(tt.want, resp.Clients) { - t.Errorf("case %d: invalid response body, want=%#v, got=%#v", i, tt.want, resp.Clients) + if diff := pretty.Compare(tt.want, resp.Clients); diff != "" { + t.Errorf("case %d: invalid response body: %s", i, diff) } } } + +type byClientId []*schema.Client + +func (b byClientId) Len() int { return len(b) } +func (b byClientId) Less(i, j int) bool { return b[i].Id < b[j].Id } +func (b byClientId) Swap(i, j int) { b[i], b[j] = b[j], b[i] } diff --git a/server/config.go b/server/config.go index 83d88a02..d204065f 100644 --- a/server/config.go +++ b/server/config.go @@ -12,9 +12,9 @@ import ( "time" "github.com/coreos/go-oidc/key" + "github.com/coreos/go-oidc/oidc" "github.com/coreos/pkg/health" - "github.com/coreos/dex/client" "github.com/coreos/dex/connector" "github.com/coreos/dex/db" "github.com/coreos/dex/email" @@ -113,10 +113,14 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error { return fmt.Errorf("unable to read clients from file %s: %v", cfg.ClientsFile, err) } defer cf.Close() - ciRepo, err := client.NewClientIdentityRepoFromReader(cf) - if err != nil { + var clients []oidc.ClientIdentity + if err := json.NewDecoder(cf).Decode(&clients); err != nil { return fmt.Errorf("unable to read client identities from file %s: %v", cfg.ClientsFile, err) } + ciRepo, err := db.NewClientIdentityRepoFromClients(dbMap, clients) + if err != nil { + return fmt.Errorf("failed to create client identity repo: %v", err) + } f, err := os.Open(cfg.ConnectorsFile) if err != nil { diff --git a/server/http_test.go b/server/http_test.go index 0d2d5516..94052b06 100644 --- a/server/http_test.go +++ b/server/http_test.go @@ -1,6 +1,7 @@ package server import ( + "encoding/base64" "encoding/json" "errors" "fmt" @@ -77,19 +78,25 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) { srv := &Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())), - ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{ - oidc.ClientIdentity{ - Credentials: oidc.ClientCredentials{ - ID: "XXX", - Secret: "secrete", - }, - Metadata: oidc.ClientMetadata{ - RedirectURIs: []url.URL{ - url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"}, + ClientIdentityRepo: func() client.ClientIdentityRepo { + repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ + oidc.ClientIdentity{ + Credentials: oidc.ClientCredentials{ + ID: "XXX", + Secret: base64.URLEncoding.EncodeToString([]byte("secrete")), + }, + Metadata: oidc.ClientMetadata{ + RedirectURIs: []url.URL{ + url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"}, + }, }, }, - }, - }), + }) + if err != nil { + t.Fatalf("Failed to create client identity repo: %v", err) + } + return repo + }(), } tests := []struct { @@ -200,20 +207,26 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) { srv := &Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())), - ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{ - oidc.ClientIdentity{ - Credentials: oidc.ClientCredentials{ - ID: "XXX", - Secret: "secrete", - }, - Metadata: oidc.ClientMetadata{ - RedirectURIs: []url.URL{ - url.URL{Scheme: "http", Host: "foo.example.com", Path: "/callback"}, - url.URL{Scheme: "http", Host: "bar.example.com", Path: "/callback"}, + ClientIdentityRepo: func() client.ClientIdentityRepo { + repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ + oidc.ClientIdentity{ + Credentials: oidc.ClientCredentials{ + ID: "XXX", + Secret: base64.URLEncoding.EncodeToString([]byte("secrete")), + }, + Metadata: oidc.ClientMetadata{ + RedirectURIs: []url.URL{ + url.URL{Scheme: "http", Host: "foo.example.com", Path: "/callback"}, + url.URL{Scheme: "http", Host: "bar.example.com", Path: "/callback"}, + }, }, }, - }, - }), + }) + if err != nil { + t.Fatalf("Failed to create client identity repo: %v", err) + } + return repo + }(), } tests := []struct { diff --git a/server/server_test.go b/server/server_test.go index d3051339..0aec6468 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -21,6 +21,8 @@ import ( "github.com/kylelemons/godebug/pretty" ) +var clientTestSecret = base64.URLEncoding.EncodeToString([]byte("secrete")) + type StaticKeyManager struct { key.PrivateKeyManager expiresAt time.Time @@ -180,7 +182,7 @@ func TestServerLogin(t *testing.T) { ci := oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "XXX", - Secret: "secrete", + Secret: clientTestSecret, }, Metadata: oidc.ClientMetadata{ RedirectURIs: []url.URL{ @@ -192,7 +194,13 @@ func TestServerLogin(t *testing.T) { }, }, } - ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) + ciRepo := func() client.ClientIdentityRepo { + repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci}) + if err != nil { + t.Fatalf("Failed to create client identity repo: %v", err) + } + return repo + }() km := &StaticKeyManager{ signer: &StaticSigner{sig: []byte("beer"), err: nil}, @@ -236,13 +244,20 @@ func TestServerLogin(t *testing.T) { } func TestServerLoginUnrecognizedSessionKey(t *testing.T) { - ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ - oidc.ClientIdentity{ - Credentials: oidc.ClientCredentials{ - ID: "XXX", Secret: "secrete", + ciRepo := func() client.ClientIdentityRepo { + repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ + oidc.ClientIdentity{ + Credentials: oidc.ClientCredentials{ + ID: "XXX", Secret: clientTestSecret, + }, }, - }, - }) + }) + if err != nil { + t.Fatalf("Failed to create client identity repo: %v", err) + } + return repo + }() + km := &StaticKeyManager{ signer: &StaticSigner{sig: nil, err: errors.New("fail")}, } @@ -269,7 +284,7 @@ func TestServerLoginDisabledUser(t *testing.T) { ci := oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "XXX", - Secret: "secrete", + Secret: clientTestSecret, }, Metadata: oidc.ClientMetadata{ RedirectURIs: []url.URL{ @@ -281,7 +296,13 @@ func TestServerLoginDisabledUser(t *testing.T) { }, }, } - ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) + ciRepo := func() client.ClientIdentityRepo { + repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci}) + if err != nil { + t.Fatalf("Failed to create client identity repo: %v", err) + } + return repo + }() km := &StaticKeyManager{ signer: &StaticSigner{sig: []byte("beer"), err: nil}, @@ -337,10 +358,16 @@ func TestServerCodeToken(t *testing.T) { ci := oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "XXX", - Secret: "secrete", + Secret: clientTestSecret, }, } - ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) + ciRepo := func() client.ClientIdentityRepo { + repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci}) + if err != nil { + t.Fatalf("Failed to create client identity repo: %v", err) + } + return repo + }() km := &StaticKeyManager{ signer: &StaticSigner{sig: []byte("beer"), err: nil}, } @@ -417,10 +444,16 @@ func TestServerTokenUnrecognizedKey(t *testing.T) { ci := oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "XXX", - Secret: "secrete", + Secret: clientTestSecret, }, } - ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) + ciRepo := func() client.ClientIdentityRepo { + repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci}) + if err != nil { + t.Fatalf("Failed to create client identity repo: %v", err) + } + return repo + }() km := &StaticKeyManager{ signer: &StaticSigner{sig: []byte("beer"), err: nil}, } @@ -460,7 +493,7 @@ func TestServerTokenFail(t *testing.T) { keyFixture := "goodkey" ccFixture := oidc.ClientCredentials{ ID: "XXX", - Secret: "secrete", + Secret: clientTestSecret, } signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} @@ -536,9 +569,13 @@ func TestServerTokenFail(t *testing.T) { km := &StaticKeyManager{ signer: tt.signer, } - ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ + ciRepo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ oidc.ClientIdentity{Credentials: ccFixture}, }) + if err != nil { + t.Errorf("case %d: failed to create client identity repo: %v", i, err) + continue + } _, err = sm.AttachUser(sessionID, "testid-1") if err != nil { @@ -589,11 +626,11 @@ func TestServerRefreshToken(t *testing.T) { credXXX := oidc.ClientCredentials{ ID: "XXX", - Secret: "secret", + Secret: clientTestSecret, } credYYY := oidc.ClientCredentials{ ID: "YYY", - Secret: "secret", + Secret: clientTestSecret, } signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} @@ -694,10 +731,14 @@ func TestServerRefreshToken(t *testing.T) { signer: tt.signer, } - ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ + ciRepo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ oidc.ClientIdentity{Credentials: credXXX}, oidc.ClientIdentity{Credentials: credYYY}, }) + if err != nil { + t.Errorf("case %d: failed to create client identity repo: %v", i, err) + continue + } userRepo, err := makeNewUserRepo() if err != nil { @@ -743,10 +784,13 @@ func TestServerRefreshToken(t *testing.T) { signer: signerFixture, } - ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ + ciRepo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ oidc.ClientIdentity{Credentials: credXXX}, oidc.ClientIdentity{Credentials: credYYY}, }) + if err != nil { + t.Fatalf("failed to create client identity repo: %v", err) + } userRepo, err := makeNewUserRepo() if err != nil { diff --git a/server/testutil.go b/server/testutil.go index e32ff1aa..84788988 100644 --- a/server/testutil.go +++ b/server/testutil.go @@ -1,6 +1,7 @@ package server import ( + "encoding/base64" "fmt" "net/url" "time" @@ -24,9 +25,8 @@ const ( ) var ( - testIssuerURL = url.URL{Scheme: "http", Host: "server.example.com"} - testClientID = "XXX" - testClientSecret = "secrete" + testIssuerURL = url.URL{Scheme: "http", Host: "server.example.com"} + testClientID = "XXX" testRedirectURL = url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"} @@ -133,11 +133,11 @@ func makeTestFixtures() (*testFixtures, error) { return nil, err } - clientIdentityRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ + clientIdentityRepo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "XXX", - Secret: testClientSecret, + Secret: base64.URLEncoding.EncodeToString([]byte("secrete")), }, Metadata: oidc.ClientMetadata{ RedirectURIs: []url.URL{ @@ -146,6 +146,9 @@ func makeTestFixtures() (*testFixtures, error) { }, }, }) + if err != nil { + return nil, err + } km := key.NewPrivateKeyManager() err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey}, time.Now().Add(time.Minute))) diff --git a/user/api/api_test.go b/user/api/api_test.go index ed121387..0cf1e019 100644 --- a/user/api/api_test.go +++ b/user/api/api_test.go @@ -1,6 +1,7 @@ package api import ( + "encoding/base64" "net/url" "testing" "time" @@ -148,7 +149,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) { ci := oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "XXX", - Secret: "secrete", + Secret: base64.URLEncoding.EncodeToString([]byte("secrete")), }, Metadata: oidc.ClientMetadata{ RedirectURIs: []url.URL{ @@ -156,7 +157,13 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) { }, }, } - cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) + cir := func() client.ClientIdentityRepo { + repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci}) + if err != nil { + panic("Failed to create client identity repo: " + err.Error()) + } + return repo + }() emailer := &testEmailer{} api := NewUsersAPI(mgr, cir, emailer, "local")