client: add transaction support

This commit is contained in:
Evan Cordell 2016-05-11 14:35:24 -07:00
parent 02bf115026
commit 3da98fcb8e
15 changed files with 48 additions and 46 deletions

View file

@ -148,7 +148,7 @@ func (a *AdminAPI) CreateClient(req adminschema.ClientCreateRequest) (adminschem
cli.Credentials.ID = id
creds, err := a.clientRepo.New(cli)
creds, err := a.clientRepo.New(nil, cli)
if err != nil {
return adminschema.ClientCreateResponse{}, mapError(err)
}

View file

@ -7,6 +7,7 @@ import (
"net/url"
"reflect"
"github.com/coreos/dex/repo"
"github.com/coreos/go-oidc/oidc"
)
@ -24,26 +25,26 @@ type Client struct {
}
type ClientRepo interface {
Get(clientID string) (Client, error)
Get(tx repo.Transaction, clientID string) (Client, error)
// Metadata returns one matching ClientMetadata if the given client
// exists, otherwise nil. The returned error will be non-nil only
// if the repo was unable to determine client existence.
Metadata(clientID string) (*oidc.ClientMetadata, error)
Metadata(tx repo.Transaction, clientID string) (*oidc.ClientMetadata, error)
// Authenticate asserts that a client with the given ID exists and
// that the provided secret matches. If either of these assertions
// fail, (false, nil) will be returned. Only if the repo is unable
// to make these assertions will a non-nil error be returned.
Authenticate(creds oidc.ClientCredentials) (bool, error)
Authenticate(tx repo.Transaction, creds oidc.ClientCredentials) (bool, error)
// All returns all registered Clients
All() ([]Client, error)
All(tx repo.Transaction) ([]Client, error)
// New registers a Client with the repo.
// An unused ID must be provided. A corresponding secret will be returned
// in a ClientCredentials struct along with the provided ID.
New(client Client) (*oidc.ClientCredentials, error)
New(tx repo.Transaction, client Client) (*oidc.ClientCredentials, error)
SetDexAdmin(clientID string, isAdmin bool) error

View file

@ -36,7 +36,7 @@ func (d *dbDriver) NewClient(meta oidc.ClientMetadata) (*oidc.ClientCredentials,
return nil, err
}
return d.ciRepo.New(client.Client{
return d.ciRepo.New(nil, client.Client{
Credentials: oidc.ClientCredentials{
ID: clientID,
},

View file

@ -15,6 +15,7 @@ import (
"github.com/coreos/dex/client"
pcrypto "github.com/coreos/dex/pkg/crypto"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo"
)
const (
@ -140,8 +141,8 @@ type clientRepo struct {
secretGenerator SecretGenerator
}
func (r *clientRepo) Get(clientID string) (client.Client, error) {
m, err := r.executor(nil).Get(clientModel{}, clientID)
func (r *clientRepo) Get(tx repo.Transaction, clientID string) (client.Client, error) {
m, err := r.executor(tx).Get(clientModel{}, clientID)
if err == sql.ErrNoRows || m == nil {
return client.Client{}, client.ErrorNotFound
}
@ -163,8 +164,8 @@ func (r *clientRepo) Get(clientID string) (client.Client, error) {
return *ci, nil
}
func (r *clientRepo) Metadata(clientID string) (*oidc.ClientMetadata, error) {
c, err := r.Get(clientID)
func (r *clientRepo) Metadata(tx repo.Transaction, clientID string) (*oidc.ClientMetadata, error) {
c, err := r.Get(tx, clientID)
if err != nil {
return nil, err
}
@ -215,8 +216,8 @@ func (r *clientRepo) SetDexAdmin(clientID string, isAdmin bool) error {
return tx.Commit()
}
func (r *clientRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) {
m, err := r.executor(nil).Get(clientModel{}, creds.ID)
func (r *clientRepo) Authenticate(tx repo.Transaction, creds oidc.ClientCredentials) (bool, error) {
m, err := r.executor(tx).Get(clientModel{}, creds.ID)
if m == nil || err != nil {
return false, err
}
@ -266,7 +267,7 @@ func DefaultSecretGenerator() ([]byte, error) {
return pcrypto.RandBytes(maxSecretLength)
}
func (r *clientRepo) New(cli client.Client) (*oidc.ClientCredentials, error) {
func (r *clientRepo) New(tx repo.Transaction, cli client.Client) (*oidc.ClientCredentials, error) {
secret, err := r.secretGenerator()
if err != nil {
return nil, err
@ -279,7 +280,7 @@ func (r *clientRepo) New(cli client.Client) (*oidc.ClientCredentials, error) {
return nil, err
}
if err := r.executor(nil).Insert(cim); err != nil {
if err := r.executor(tx).Insert(cim); err != nil {
if isAlreadyExistsErr(err) {
err = errors.New("client ID already exists")
}
@ -294,10 +295,10 @@ func (r *clientRepo) New(cli client.Client) (*oidc.ClientCredentials, error) {
return &cc, nil
}
func (r *clientRepo) All() ([]client.Client, error) {
func (r *clientRepo) All(tx repo.Transaction) ([]client.Client, error) {
qt := r.quote(clientTableName)
q := fmt.Sprintf("SELECT * FROM %s", qt)
objs, err := r.executor(nil).Select(&clientModel{}, q)
objs, err := r.executor(tx).Select(&clientModel{}, q)
if err != nil {
return nil, err
}

View file

@ -191,7 +191,7 @@ func TestDBClientRepoMetadata(t *testing.T) {
},
}
_, err := r.New(client.Client{
_, err := r.New(nil, client.Client{
Credentials: oidc.ClientCredentials{
ID: "foo",
},
@ -201,7 +201,7 @@ func TestDBClientRepoMetadata(t *testing.T) {
t.Fatalf(err.Error())
}
got, err := r.Metadata("foo")
got, err := r.Metadata(nil, "foo")
if err != nil {
t.Fatalf(err.Error())
}
@ -214,7 +214,7 @@ func TestDBClientRepoMetadata(t *testing.T) {
func TestDBClientRepoMetadataNoExist(t *testing.T) {
r := db.NewClientRepo(connect(t))
got, err := r.Metadata("noexist")
got, err := r.Metadata(nil, "noexist")
if err != client.ErrorNotFound {
t.Errorf("want==%q, got==%q", client.ErrorNotFound, err)
}
@ -232,7 +232,7 @@ func TestDBClientRepoNewDuplicate(t *testing.T) {
},
}
if _, err := r.New(client.Client{
if _, err := r.New(nil, client.Client{
Credentials: oidc.ClientCredentials{
ID: "foo",
},
@ -247,7 +247,7 @@ func TestDBClientRepoNewDuplicate(t *testing.T) {
},
}
if _, err := r.New(client.Client{
if _, err := r.New(nil, client.Client{
Credentials: oidc.ClientCredentials{
ID: "foo",
},
@ -261,7 +261,7 @@ func TestDBClientRepoNewAdmin(t *testing.T) {
for _, admin := range []bool{true, false} {
r := db.NewClientRepo(connect(t))
if _, err := r.New(client.Client{
if _, err := r.New(nil, client.Client{
Credentials: oidc.ClientCredentials{
ID: "foo",
},
@ -283,7 +283,7 @@ func TestDBClientRepoNewAdmin(t *testing.T) {
t.Errorf("want=%v, gotAdmin=%v", admin, gotAdmin)
}
cli, err := r.Get("foo")
cli, err := r.Get(nil, "foo")
if err != nil {
t.Fatalf("expected non-nil error")
}
@ -302,7 +302,7 @@ func TestDBClientRepoAuthenticate(t *testing.T) {
},
}
cc, err := r.New(client.Client{
cc, err := r.New(nil, client.Client{
Credentials: oidc.ClientCredentials{
ID: "baz",
},
@ -316,7 +316,7 @@ func TestDBClientRepoAuthenticate(t *testing.T) {
t.Fatalf("Returned ClientCredentials has incorrect ID: want=baz got=%s", cc.ID)
}
ok, err := r.Authenticate(*cc)
ok, err := r.Authenticate(nil, *cc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
} else if !ok {
@ -337,7 +337,7 @@ func TestDBClientRepoAuthenticate(t *testing.T) {
oidc.ClientCredentials{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)},
}
for i, c := range creds {
ok, err := r.Authenticate(c)
ok, err := r.Authenticate(nil, c)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
} else if ok {
@ -355,7 +355,7 @@ func TestDBClientAll(t *testing.T) {
},
}
_, err := r.New(client.Client{
_, err := r.New(nil, client.Client{
Credentials: oidc.ClientCredentials{
ID: "foo",
},
@ -365,7 +365,7 @@ func TestDBClientAll(t *testing.T) {
t.Fatalf(err.Error())
}
got, err := r.All()
got, err := r.All(nil)
if err != nil {
t.Fatalf(err.Error())
}
@ -383,7 +383,7 @@ func TestDBClientAll(t *testing.T) {
url.URL{Scheme: "http", Host: "foo.com", Path: "/cb"},
},
}
_, err = r.New(client.Client{
_, err = r.New(nil, client.Client{
Credentials: oidc.ClientCredentials{
ID: "bar",
},
@ -393,7 +393,7 @@ func TestDBClientAll(t *testing.T) {
t.Fatalf(err.Error())
}
got, err = r.All()
got, err = r.All(nil)
if err != nil {
t.Fatalf(err.Error())
}

View file

@ -402,7 +402,7 @@ func TestCreateClient(t *testing.T) {
t.Errorf("case %d: Compare(want, got) = %v", i, diff)
}
repoClient, err := f.cr.Get(resp.Client.Id)
repoClient, err := f.cr.Get(nil, resp.Client.Id)
if err != nil {
t.Errorf("case %d: Unexpected error getting client: %v", i, err)
}

View file

@ -73,7 +73,7 @@ func TestClientCreate(t *testing.T) {
t.Error("Expected non-empty Client Secret")
}
meta, err := srv.ClientRepo.Metadata(newClient.Id)
meta, err := srv.ClientRepo.Metadata(nil, newClient.Id)
if err != nil {
t.Errorf("Error looking up client metadata: %v", err)
} else if meta == nil {

View file

@ -83,7 +83,7 @@ func (c *clientTokenMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request
return
}
md, err := c.ciRepo.Metadata(clientID)
md, err := c.ciRepo.Metadata(nil, clientID)
if md == nil || err != nil {
log.Errorf("Failed to find clientID: %s, error=%v", clientID, err)
respondError()

View file

@ -45,7 +45,7 @@ func (s *Server) handleClientRegistrationRequest(r *http.Request) (*oidc.ClientR
return nil, newAPIError(oauth2.ErrorServerError, "unable to save client metadata")
}
creds, err := s.ClientRepo.New(client.Client{
creds, err := s.ClientRepo.New(nil, client.Client{
Credentials: oidc.ClientCredentials{
ID: id,
},

View file

@ -143,7 +143,7 @@ func TestClientRegistration(t *testing.T) {
return fmt.Errorf("no client id in registration response")
}
metadata, err := fixtures.clientRepo.Metadata(r.ClientID)
metadata, err := fixtures.clientRepo.Metadata(nil, r.ClientID)
if err != nil {
return fmt.Errorf("failed to lookup client id after creation")
}

View file

@ -41,7 +41,7 @@ func (c *clientResource) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
func (c *clientResource) list(w http.ResponseWriter, r *http.Request) {
cs, err := c.repo.All()
cs, err := c.repo.All(nil)
if err != nil {
writeAPIError(w, http.StatusInternalServerError, newAPIError(errorServerError, "error listing clients"))
return
@ -97,7 +97,7 @@ func (c *clientResource) create(w http.ResponseWriter, r *http.Request) {
}
ci.Credentials.ID = clientID
creds, err := c.repo.New(ci)
creds, err := c.repo.New(nil, ci)
if err != nil {
log.Errorf("Failed creating client: %v", err)

View file

@ -57,7 +57,7 @@ func handleVerifyEmailResendFunc(
return
}
cm, err := clientRepo.Metadata(clientID)
cm, err := clientRepo.Metadata(nil, clientID)
if err == client.ErrorNotFound {
log.Errorf("No such client: %v", err)
writeAPIError(w, http.StatusBadRequest,

View file

@ -128,7 +128,7 @@ func (h *SendResetPasswordEmailHandler) validateRedirectURL(clientID string, red
return url.URL{}, false
}
cm, err := h.cr.Metadata(clientID)
cm, err := h.cr.Metadata(nil, clientID)
if err != nil || cm == nil {
log.Errorf("Error getting ClientMetadata: %v", err)
return url.URL{}, false

View file

@ -278,7 +278,7 @@ func (s *Server) NewClientTokenAuthHandler(handler http.Handler) http.Handler {
}
func (s *Server) ClientMetadata(clientID string) (*oidc.ClientMetadata, error) {
return s.ClientRepo.Metadata(clientID)
return s.ClientRepo.Metadata(nil, clientID)
}
func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
@ -365,7 +365,7 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
}
func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) {
ok, err := s.ClientRepo.Authenticate(creds)
ok, err := s.ClientRepo.Authenticate(nil, creds)
if err != nil {
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
return nil, oauth2.NewError(oauth2.ErrorServerError)
@ -397,7 +397,7 @@ func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, erro
}
func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) {
ok, err := s.ClientRepo.Authenticate(creds)
ok, err := s.ClientRepo.Authenticate(nil, creds)
if err != nil {
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
@ -466,7 +466,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
}
func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) {
ok, err := s.ClientRepo.Authenticate(creds)
ok, err := s.ClientRepo.Authenticate(nil, creds)
if err != nil {
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
return nil, oauth2.NewError(oauth2.ErrorServerError)

View file

@ -157,7 +157,7 @@ func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (s
return schema.UserCreateResponse{}, mapError(err)
}
metadata, err := u.clientRepo.Metadata(creds.ClientID)
metadata, err := u.clientRepo.Metadata(nil, creds.ClientID)
if err != nil {
return schema.UserCreateResponse{}, mapError(err)
}
@ -202,7 +202,7 @@ func (u *UsersAPI) ResendEmailInvitation(creds Creds, userID string, redirURL ur
return schema.ResendEmailInvitationResponse{}, ErrorUnauthorized
}
metadata, err := u.clientRepo.Metadata(creds.ClientID)
metadata, err := u.clientRepo.Metadata(nil, creds.ClientID)
if err != nil {
return schema.ResendEmailInvitationResponse{}, mapError(err)
}