diff --git a/admin/api.go b/admin/api.go index 965e9ce2..14fece20 100644 --- a/admin/api.go +++ b/admin/api.go @@ -148,7 +148,7 @@ func (a *AdminAPI) CreateClient(req adminschema.ClientCreateRequest) (adminschem cli.Credentials.ID = id - creds, err := a.clientRepo.New(cli) + creds, err := a.clientRepo.New(nil, cli) if err != nil { return adminschema.ClientCreateResponse{}, mapError(err) } diff --git a/client/client.go b/client/client.go index f0b25b65..cddb54e8 100644 --- a/client/client.go +++ b/client/client.go @@ -7,6 +7,7 @@ import ( "net/url" "reflect" + "github.com/coreos/dex/repo" "github.com/coreos/go-oidc/oidc" ) @@ -24,26 +25,26 @@ type Client struct { } type ClientRepo interface { - Get(clientID string) (Client, error) + Get(tx repo.Transaction, clientID string) (Client, error) // Metadata returns one matching ClientMetadata if the given client // exists, otherwise nil. The returned error will be non-nil only // if the repo was unable to determine client existence. - Metadata(clientID string) (*oidc.ClientMetadata, error) + Metadata(tx repo.Transaction, clientID string) (*oidc.ClientMetadata, error) // Authenticate asserts that a client with the given ID exists and // that the provided secret matches. If either of these assertions // fail, (false, nil) will be returned. Only if the repo is unable // to make these assertions will a non-nil error be returned. - Authenticate(creds oidc.ClientCredentials) (bool, error) + Authenticate(tx repo.Transaction, creds oidc.ClientCredentials) (bool, error) // All returns all registered Clients - All() ([]Client, error) + All(tx repo.Transaction) ([]Client, error) // New registers a Client with the repo. // An unused ID must be provided. A corresponding secret will be returned // in a ClientCredentials struct along with the provided ID. - New(client Client) (*oidc.ClientCredentials, error) + New(tx repo.Transaction, client Client) (*oidc.ClientCredentials, error) SetDexAdmin(clientID string, isAdmin bool) error diff --git a/cmd/dexctl/driver_db.go b/cmd/dexctl/driver_db.go index eba8b4b2..92fbb973 100644 --- a/cmd/dexctl/driver_db.go +++ b/cmd/dexctl/driver_db.go @@ -36,7 +36,7 @@ func (d *dbDriver) NewClient(meta oidc.ClientMetadata) (*oidc.ClientCredentials, return nil, err } - return d.ciRepo.New(client.Client{ + return d.ciRepo.New(nil, client.Client{ Credentials: oidc.ClientCredentials{ ID: clientID, }, diff --git a/db/client.go b/db/client.go index 42dd5f8d..3b7c02a6 100644 --- a/db/client.go +++ b/db/client.go @@ -15,6 +15,7 @@ import ( "github.com/coreos/dex/client" pcrypto "github.com/coreos/dex/pkg/crypto" "github.com/coreos/dex/pkg/log" + "github.com/coreos/dex/repo" ) const ( @@ -140,8 +141,8 @@ type clientRepo struct { secretGenerator SecretGenerator } -func (r *clientRepo) Get(clientID string) (client.Client, error) { - m, err := r.executor(nil).Get(clientModel{}, clientID) +func (r *clientRepo) Get(tx repo.Transaction, clientID string) (client.Client, error) { + m, err := r.executor(tx).Get(clientModel{}, clientID) if err == sql.ErrNoRows || m == nil { return client.Client{}, client.ErrorNotFound } @@ -163,8 +164,8 @@ func (r *clientRepo) Get(clientID string) (client.Client, error) { return *ci, nil } -func (r *clientRepo) Metadata(clientID string) (*oidc.ClientMetadata, error) { - c, err := r.Get(clientID) +func (r *clientRepo) Metadata(tx repo.Transaction, clientID string) (*oidc.ClientMetadata, error) { + c, err := r.Get(tx, clientID) if err != nil { return nil, err } @@ -215,8 +216,8 @@ func (r *clientRepo) SetDexAdmin(clientID string, isAdmin bool) error { return tx.Commit() } -func (r *clientRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) { - m, err := r.executor(nil).Get(clientModel{}, creds.ID) +func (r *clientRepo) Authenticate(tx repo.Transaction, creds oidc.ClientCredentials) (bool, error) { + m, err := r.executor(tx).Get(clientModel{}, creds.ID) if m == nil || err != nil { return false, err } @@ -266,7 +267,7 @@ func DefaultSecretGenerator() ([]byte, error) { return pcrypto.RandBytes(maxSecretLength) } -func (r *clientRepo) New(cli client.Client) (*oidc.ClientCredentials, error) { +func (r *clientRepo) New(tx repo.Transaction, cli client.Client) (*oidc.ClientCredentials, error) { secret, err := r.secretGenerator() if err != nil { return nil, err @@ -279,7 +280,7 @@ func (r *clientRepo) New(cli client.Client) (*oidc.ClientCredentials, error) { return nil, err } - if err := r.executor(nil).Insert(cim); err != nil { + if err := r.executor(tx).Insert(cim); err != nil { if isAlreadyExistsErr(err) { err = errors.New("client ID already exists") } @@ -294,10 +295,10 @@ func (r *clientRepo) New(cli client.Client) (*oidc.ClientCredentials, error) { return &cc, nil } -func (r *clientRepo) All() ([]client.Client, error) { +func (r *clientRepo) All(tx repo.Transaction) ([]client.Client, error) { qt := r.quote(clientTableName) q := fmt.Sprintf("SELECT * FROM %s", qt) - objs, err := r.executor(nil).Select(&clientModel{}, q) + objs, err := r.executor(tx).Select(&clientModel{}, q) if err != nil { return nil, err } diff --git a/functional/db_test.go b/functional/db_test.go index ecbff9ae..e4b57d61 100644 --- a/functional/db_test.go +++ b/functional/db_test.go @@ -191,7 +191,7 @@ func TestDBClientRepoMetadata(t *testing.T) { }, } - _, err := r.New(client.Client{ + _, err := r.New(nil, client.Client{ Credentials: oidc.ClientCredentials{ ID: "foo", }, @@ -201,7 +201,7 @@ func TestDBClientRepoMetadata(t *testing.T) { t.Fatalf(err.Error()) } - got, err := r.Metadata("foo") + got, err := r.Metadata(nil, "foo") if err != nil { t.Fatalf(err.Error()) } @@ -214,7 +214,7 @@ func TestDBClientRepoMetadata(t *testing.T) { func TestDBClientRepoMetadataNoExist(t *testing.T) { r := db.NewClientRepo(connect(t)) - got, err := r.Metadata("noexist") + got, err := r.Metadata(nil, "noexist") if err != client.ErrorNotFound { t.Errorf("want==%q, got==%q", client.ErrorNotFound, err) } @@ -232,7 +232,7 @@ func TestDBClientRepoNewDuplicate(t *testing.T) { }, } - if _, err := r.New(client.Client{ + if _, err := r.New(nil, client.Client{ Credentials: oidc.ClientCredentials{ ID: "foo", }, @@ -247,7 +247,7 @@ func TestDBClientRepoNewDuplicate(t *testing.T) { }, } - if _, err := r.New(client.Client{ + if _, err := r.New(nil, client.Client{ Credentials: oidc.ClientCredentials{ ID: "foo", }, @@ -261,7 +261,7 @@ func TestDBClientRepoNewAdmin(t *testing.T) { for _, admin := range []bool{true, false} { r := db.NewClientRepo(connect(t)) - if _, err := r.New(client.Client{ + if _, err := r.New(nil, client.Client{ Credentials: oidc.ClientCredentials{ ID: "foo", }, @@ -283,7 +283,7 @@ func TestDBClientRepoNewAdmin(t *testing.T) { t.Errorf("want=%v, gotAdmin=%v", admin, gotAdmin) } - cli, err := r.Get("foo") + cli, err := r.Get(nil, "foo") if err != nil { t.Fatalf("expected non-nil error") } @@ -302,7 +302,7 @@ func TestDBClientRepoAuthenticate(t *testing.T) { }, } - cc, err := r.New(client.Client{ + cc, err := r.New(nil, client.Client{ Credentials: oidc.ClientCredentials{ ID: "baz", }, @@ -316,7 +316,7 @@ func TestDBClientRepoAuthenticate(t *testing.T) { t.Fatalf("Returned ClientCredentials has incorrect ID: want=baz got=%s", cc.ID) } - ok, err := r.Authenticate(*cc) + ok, err := r.Authenticate(nil, *cc) if err != nil { t.Fatalf("Unexpected error: %v", err) } else if !ok { @@ -337,7 +337,7 @@ func TestDBClientRepoAuthenticate(t *testing.T) { oidc.ClientCredentials{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)}, } for i, c := range creds { - ok, err := r.Authenticate(c) + ok, err := r.Authenticate(nil, c) if err != nil { t.Errorf("case %d: unexpected error: %v", i, err) } else if ok { @@ -355,7 +355,7 @@ func TestDBClientAll(t *testing.T) { }, } - _, err := r.New(client.Client{ + _, err := r.New(nil, client.Client{ Credentials: oidc.ClientCredentials{ ID: "foo", }, @@ -365,7 +365,7 @@ func TestDBClientAll(t *testing.T) { t.Fatalf(err.Error()) } - got, err := r.All() + got, err := r.All(nil) if err != nil { t.Fatalf(err.Error()) } @@ -383,7 +383,7 @@ func TestDBClientAll(t *testing.T) { url.URL{Scheme: "http", Host: "foo.com", Path: "/cb"}, }, } - _, err = r.New(client.Client{ + _, err = r.New(nil, client.Client{ Credentials: oidc.ClientCredentials{ ID: "bar", }, @@ -393,7 +393,7 @@ func TestDBClientAll(t *testing.T) { t.Fatalf(err.Error()) } - got, err = r.All() + got, err = r.All(nil) if err != nil { t.Fatalf(err.Error()) } diff --git a/integration/admin_api_test.go b/integration/admin_api_test.go index 1348a24b..56184703 100644 --- a/integration/admin_api_test.go +++ b/integration/admin_api_test.go @@ -402,7 +402,7 @@ func TestCreateClient(t *testing.T) { t.Errorf("case %d: Compare(want, got) = %v", i, diff) } - repoClient, err := f.cr.Get(resp.Client.Id) + repoClient, err := f.cr.Get(nil, resp.Client.Id) if err != nil { t.Errorf("case %d: Unexpected error getting client: %v", i, err) } diff --git a/integration/client_api_test.go b/integration/client_api_test.go index d19b18a4..4ca2550c 100644 --- a/integration/client_api_test.go +++ b/integration/client_api_test.go @@ -73,7 +73,7 @@ func TestClientCreate(t *testing.T) { t.Error("Expected non-empty Client Secret") } - meta, err := srv.ClientRepo.Metadata(newClient.Id) + meta, err := srv.ClientRepo.Metadata(nil, newClient.Id) if err != nil { t.Errorf("Error looking up client metadata: %v", err) } else if meta == nil { diff --git a/server/auth_middleware.go b/server/auth_middleware.go index 7c6fc789..201a7e4d 100644 --- a/server/auth_middleware.go +++ b/server/auth_middleware.go @@ -83,7 +83,7 @@ func (c *clientTokenMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request return } - md, err := c.ciRepo.Metadata(clientID) + md, err := c.ciRepo.Metadata(nil, clientID) if md == nil || err != nil { log.Errorf("Failed to find clientID: %s, error=%v", clientID, err) respondError() diff --git a/server/client_registration.go b/server/client_registration.go index ca0a23fb..e6fd8849 100644 --- a/server/client_registration.go +++ b/server/client_registration.go @@ -45,7 +45,7 @@ func (s *Server) handleClientRegistrationRequest(r *http.Request) (*oidc.ClientR return nil, newAPIError(oauth2.ErrorServerError, "unable to save client metadata") } - creds, err := s.ClientRepo.New(client.Client{ + creds, err := s.ClientRepo.New(nil, client.Client{ Credentials: oidc.ClientCredentials{ ID: id, }, diff --git a/server/client_registration_test.go b/server/client_registration_test.go index ef8a19e8..f8099ca7 100644 --- a/server/client_registration_test.go +++ b/server/client_registration_test.go @@ -143,7 +143,7 @@ func TestClientRegistration(t *testing.T) { return fmt.Errorf("no client id in registration response") } - metadata, err := fixtures.clientRepo.Metadata(r.ClientID) + metadata, err := fixtures.clientRepo.Metadata(nil, r.ClientID) if err != nil { return fmt.Errorf("failed to lookup client id after creation") } diff --git a/server/client_resource.go b/server/client_resource.go index ea348668..891aa36a 100644 --- a/server/client_resource.go +++ b/server/client_resource.go @@ -41,7 +41,7 @@ func (c *clientResource) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (c *clientResource) list(w http.ResponseWriter, r *http.Request) { - cs, err := c.repo.All() + cs, err := c.repo.All(nil) if err != nil { writeAPIError(w, http.StatusInternalServerError, newAPIError(errorServerError, "error listing clients")) return @@ -97,7 +97,7 @@ func (c *clientResource) create(w http.ResponseWriter, r *http.Request) { } ci.Credentials.ID = clientID - creds, err := c.repo.New(ci) + creds, err := c.repo.New(nil, ci) if err != nil { log.Errorf("Failed creating client: %v", err) diff --git a/server/email_verification.go b/server/email_verification.go index f38a9821..477f1b41 100644 --- a/server/email_verification.go +++ b/server/email_verification.go @@ -57,7 +57,7 @@ func handleVerifyEmailResendFunc( return } - cm, err := clientRepo.Metadata(clientID) + cm, err := clientRepo.Metadata(nil, clientID) if err == client.ErrorNotFound { log.Errorf("No such client: %v", err) writeAPIError(w, http.StatusBadRequest, diff --git a/server/password.go b/server/password.go index 077a7a06..ebd2677a 100644 --- a/server/password.go +++ b/server/password.go @@ -128,7 +128,7 @@ func (h *SendResetPasswordEmailHandler) validateRedirectURL(clientID string, red return url.URL{}, false } - cm, err := h.cr.Metadata(clientID) + cm, err := h.cr.Metadata(nil, clientID) if err != nil || cm == nil { log.Errorf("Error getting ClientMetadata: %v", err) return url.URL{}, false diff --git a/server/server.go b/server/server.go index 1308c26b..0252fe6f 100644 --- a/server/server.go +++ b/server/server.go @@ -278,7 +278,7 @@ func (s *Server) NewClientTokenAuthHandler(handler http.Handler) http.Handler { } func (s *Server) ClientMetadata(clientID string) (*oidc.ClientMetadata, error) { - return s.ClientRepo.Metadata(clientID) + return s.ClientRepo.Metadata(nil, clientID) } func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) { @@ -365,7 +365,7 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) { } func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) { - ok, err := s.ClientRepo.Authenticate(creds) + ok, err := s.ClientRepo.Authenticate(nil, creds) if err != nil { log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) return nil, oauth2.NewError(oauth2.ErrorServerError) @@ -397,7 +397,7 @@ func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, erro } func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) { - ok, err := s.ClientRepo.Authenticate(creds) + ok, err := s.ClientRepo.Authenticate(nil, creds) if err != nil { log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) return nil, "", oauth2.NewError(oauth2.ErrorServerError) @@ -466,7 +466,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo } func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) { - ok, err := s.ClientRepo.Authenticate(creds) + ok, err := s.ClientRepo.Authenticate(nil, creds) if err != nil { log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) return nil, oauth2.NewError(oauth2.ErrorServerError) diff --git a/user/api/api.go b/user/api/api.go index 5d246b6d..1a7df719 100644 --- a/user/api/api.go +++ b/user/api/api.go @@ -157,7 +157,7 @@ func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (s return schema.UserCreateResponse{}, mapError(err) } - metadata, err := u.clientRepo.Metadata(creds.ClientID) + metadata, err := u.clientRepo.Metadata(nil, creds.ClientID) if err != nil { return schema.UserCreateResponse{}, mapError(err) } @@ -202,7 +202,7 @@ func (u *UsersAPI) ResendEmailInvitation(creds Creds, userID string, redirURL ur return schema.ResendEmailInvitationResponse{}, ErrorUnauthorized } - metadata, err := u.clientRepo.Metadata(creds.ClientID) + metadata, err := u.clientRepo.Metadata(nil, creds.ClientID) if err != nil { return schema.ResendEmailInvitationResponse{}, mapError(err) }