forked from mystiq/dex
client: add transaction support
This commit is contained in:
parent
02bf115026
commit
3da98fcb8e
15 changed files with 48 additions and 46 deletions
|
@ -148,7 +148,7 @@ func (a *AdminAPI) CreateClient(req adminschema.ClientCreateRequest) (adminschem
|
|||
|
||||
cli.Credentials.ID = id
|
||||
|
||||
creds, err := a.clientRepo.New(cli)
|
||||
creds, err := a.clientRepo.New(nil, cli)
|
||||
if err != nil {
|
||||
return adminschema.ClientCreateResponse{}, mapError(err)
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"net/url"
|
||||
"reflect"
|
||||
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
)
|
||||
|
||||
|
@ -24,26 +25,26 @@ type Client struct {
|
|||
}
|
||||
|
||||
type ClientRepo interface {
|
||||
Get(clientID string) (Client, error)
|
||||
Get(tx repo.Transaction, clientID string) (Client, error)
|
||||
|
||||
// Metadata returns one matching ClientMetadata if the given client
|
||||
// exists, otherwise nil. The returned error will be non-nil only
|
||||
// if the repo was unable to determine client existence.
|
||||
Metadata(clientID string) (*oidc.ClientMetadata, error)
|
||||
Metadata(tx repo.Transaction, clientID string) (*oidc.ClientMetadata, error)
|
||||
|
||||
// Authenticate asserts that a client with the given ID exists and
|
||||
// that the provided secret matches. If either of these assertions
|
||||
// fail, (false, nil) will be returned. Only if the repo is unable
|
||||
// to make these assertions will a non-nil error be returned.
|
||||
Authenticate(creds oidc.ClientCredentials) (bool, error)
|
||||
Authenticate(tx repo.Transaction, creds oidc.ClientCredentials) (bool, error)
|
||||
|
||||
// All returns all registered Clients
|
||||
All() ([]Client, error)
|
||||
All(tx repo.Transaction) ([]Client, error)
|
||||
|
||||
// New registers a Client with the repo.
|
||||
// An unused ID must be provided. A corresponding secret will be returned
|
||||
// in a ClientCredentials struct along with the provided ID.
|
||||
New(client Client) (*oidc.ClientCredentials, error)
|
||||
New(tx repo.Transaction, client Client) (*oidc.ClientCredentials, error)
|
||||
|
||||
SetDexAdmin(clientID string, isAdmin bool) error
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ func (d *dbDriver) NewClient(meta oidc.ClientMetadata) (*oidc.ClientCredentials,
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return d.ciRepo.New(client.Client{
|
||||
return d.ciRepo.New(nil, client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: clientID,
|
||||
},
|
||||
|
|
21
db/client.go
21
db/client.go
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/coreos/dex/client"
|
||||
pcrypto "github.com/coreos/dex/pkg/crypto"
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/repo"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -140,8 +141,8 @@ type clientRepo struct {
|
|||
secretGenerator SecretGenerator
|
||||
}
|
||||
|
||||
func (r *clientRepo) Get(clientID string) (client.Client, error) {
|
||||
m, err := r.executor(nil).Get(clientModel{}, clientID)
|
||||
func (r *clientRepo) Get(tx repo.Transaction, clientID string) (client.Client, error) {
|
||||
m, err := r.executor(tx).Get(clientModel{}, clientID)
|
||||
if err == sql.ErrNoRows || m == nil {
|
||||
return client.Client{}, client.ErrorNotFound
|
||||
}
|
||||
|
@ -163,8 +164,8 @@ func (r *clientRepo) Get(clientID string) (client.Client, error) {
|
|||
return *ci, nil
|
||||
}
|
||||
|
||||
func (r *clientRepo) Metadata(clientID string) (*oidc.ClientMetadata, error) {
|
||||
c, err := r.Get(clientID)
|
||||
func (r *clientRepo) Metadata(tx repo.Transaction, clientID string) (*oidc.ClientMetadata, error) {
|
||||
c, err := r.Get(tx, clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -215,8 +216,8 @@ func (r *clientRepo) SetDexAdmin(clientID string, isAdmin bool) error {
|
|||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (r *clientRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) {
|
||||
m, err := r.executor(nil).Get(clientModel{}, creds.ID)
|
||||
func (r *clientRepo) Authenticate(tx repo.Transaction, creds oidc.ClientCredentials) (bool, error) {
|
||||
m, err := r.executor(tx).Get(clientModel{}, creds.ID)
|
||||
if m == nil || err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -266,7 +267,7 @@ func DefaultSecretGenerator() ([]byte, error) {
|
|||
return pcrypto.RandBytes(maxSecretLength)
|
||||
}
|
||||
|
||||
func (r *clientRepo) New(cli client.Client) (*oidc.ClientCredentials, error) {
|
||||
func (r *clientRepo) New(tx repo.Transaction, cli client.Client) (*oidc.ClientCredentials, error) {
|
||||
secret, err := r.secretGenerator()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -279,7 +280,7 @@ func (r *clientRepo) New(cli client.Client) (*oidc.ClientCredentials, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err := r.executor(nil).Insert(cim); err != nil {
|
||||
if err := r.executor(tx).Insert(cim); err != nil {
|
||||
if isAlreadyExistsErr(err) {
|
||||
err = errors.New("client ID already exists")
|
||||
}
|
||||
|
@ -294,10 +295,10 @@ func (r *clientRepo) New(cli client.Client) (*oidc.ClientCredentials, error) {
|
|||
return &cc, nil
|
||||
}
|
||||
|
||||
func (r *clientRepo) All() ([]client.Client, error) {
|
||||
func (r *clientRepo) All(tx repo.Transaction) ([]client.Client, error) {
|
||||
qt := r.quote(clientTableName)
|
||||
q := fmt.Sprintf("SELECT * FROM %s", qt)
|
||||
objs, err := r.executor(nil).Select(&clientModel{}, q)
|
||||
objs, err := r.executor(tx).Select(&clientModel{}, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -191,7 +191,7 @@ func TestDBClientRepoMetadata(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
_, err := r.New(client.Client{
|
||||
_, err := r.New(nil, client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "foo",
|
||||
},
|
||||
|
@ -201,7 +201,7 @@ func TestDBClientRepoMetadata(t *testing.T) {
|
|||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
got, err := r.Metadata("foo")
|
||||
got, err := r.Metadata(nil, "foo")
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
@ -214,7 +214,7 @@ func TestDBClientRepoMetadata(t *testing.T) {
|
|||
func TestDBClientRepoMetadataNoExist(t *testing.T) {
|
||||
r := db.NewClientRepo(connect(t))
|
||||
|
||||
got, err := r.Metadata("noexist")
|
||||
got, err := r.Metadata(nil, "noexist")
|
||||
if err != client.ErrorNotFound {
|
||||
t.Errorf("want==%q, got==%q", client.ErrorNotFound, err)
|
||||
}
|
||||
|
@ -232,7 +232,7 @@ func TestDBClientRepoNewDuplicate(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
if _, err := r.New(client.Client{
|
||||
if _, err := r.New(nil, client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "foo",
|
||||
},
|
||||
|
@ -247,7 +247,7 @@ func TestDBClientRepoNewDuplicate(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
if _, err := r.New(client.Client{
|
||||
if _, err := r.New(nil, client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "foo",
|
||||
},
|
||||
|
@ -261,7 +261,7 @@ func TestDBClientRepoNewAdmin(t *testing.T) {
|
|||
|
||||
for _, admin := range []bool{true, false} {
|
||||
r := db.NewClientRepo(connect(t))
|
||||
if _, err := r.New(client.Client{
|
||||
if _, err := r.New(nil, client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "foo",
|
||||
},
|
||||
|
@ -283,7 +283,7 @@ func TestDBClientRepoNewAdmin(t *testing.T) {
|
|||
t.Errorf("want=%v, gotAdmin=%v", admin, gotAdmin)
|
||||
}
|
||||
|
||||
cli, err := r.Get("foo")
|
||||
cli, err := r.Get(nil, "foo")
|
||||
if err != nil {
|
||||
t.Fatalf("expected non-nil error")
|
||||
}
|
||||
|
@ -302,7 +302,7 @@ func TestDBClientRepoAuthenticate(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
cc, err := r.New(client.Client{
|
||||
cc, err := r.New(nil, client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "baz",
|
||||
},
|
||||
|
@ -316,7 +316,7 @@ func TestDBClientRepoAuthenticate(t *testing.T) {
|
|||
t.Fatalf("Returned ClientCredentials has incorrect ID: want=baz got=%s", cc.ID)
|
||||
}
|
||||
|
||||
ok, err := r.Authenticate(*cc)
|
||||
ok, err := r.Authenticate(nil, *cc)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
} else if !ok {
|
||||
|
@ -337,7 +337,7 @@ func TestDBClientRepoAuthenticate(t *testing.T) {
|
|||
oidc.ClientCredentials{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)},
|
||||
}
|
||||
for i, c := range creds {
|
||||
ok, err := r.Authenticate(c)
|
||||
ok, err := r.Authenticate(nil, c)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
} else if ok {
|
||||
|
@ -355,7 +355,7 @@ func TestDBClientAll(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
_, err := r.New(client.Client{
|
||||
_, err := r.New(nil, client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "foo",
|
||||
},
|
||||
|
@ -365,7 +365,7 @@ func TestDBClientAll(t *testing.T) {
|
|||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
got, err := r.All()
|
||||
got, err := r.All(nil)
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
@ -383,7 +383,7 @@ func TestDBClientAll(t *testing.T) {
|
|||
url.URL{Scheme: "http", Host: "foo.com", Path: "/cb"},
|
||||
},
|
||||
}
|
||||
_, err = r.New(client.Client{
|
||||
_, err = r.New(nil, client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "bar",
|
||||
},
|
||||
|
@ -393,7 +393,7 @@ func TestDBClientAll(t *testing.T) {
|
|||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
got, err = r.All()
|
||||
got, err = r.All(nil)
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
|
|
@ -402,7 +402,7 @@ func TestCreateClient(t *testing.T) {
|
|||
t.Errorf("case %d: Compare(want, got) = %v", i, diff)
|
||||
}
|
||||
|
||||
repoClient, err := f.cr.Get(resp.Client.Id)
|
||||
repoClient, err := f.cr.Get(nil, resp.Client.Id)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: Unexpected error getting client: %v", i, err)
|
||||
}
|
||||
|
|
|
@ -73,7 +73,7 @@ func TestClientCreate(t *testing.T) {
|
|||
t.Error("Expected non-empty Client Secret")
|
||||
}
|
||||
|
||||
meta, err := srv.ClientRepo.Metadata(newClient.Id)
|
||||
meta, err := srv.ClientRepo.Metadata(nil, newClient.Id)
|
||||
if err != nil {
|
||||
t.Errorf("Error looking up client metadata: %v", err)
|
||||
} else if meta == nil {
|
||||
|
|
|
@ -83,7 +83,7 @@ func (c *clientTokenMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request
|
|||
return
|
||||
}
|
||||
|
||||
md, err := c.ciRepo.Metadata(clientID)
|
||||
md, err := c.ciRepo.Metadata(nil, clientID)
|
||||
if md == nil || err != nil {
|
||||
log.Errorf("Failed to find clientID: %s, error=%v", clientID, err)
|
||||
respondError()
|
||||
|
|
|
@ -45,7 +45,7 @@ func (s *Server) handleClientRegistrationRequest(r *http.Request) (*oidc.ClientR
|
|||
return nil, newAPIError(oauth2.ErrorServerError, "unable to save client metadata")
|
||||
}
|
||||
|
||||
creds, err := s.ClientRepo.New(client.Client{
|
||||
creds, err := s.ClientRepo.New(nil, client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: id,
|
||||
},
|
||||
|
|
|
@ -143,7 +143,7 @@ func TestClientRegistration(t *testing.T) {
|
|||
return fmt.Errorf("no client id in registration response")
|
||||
}
|
||||
|
||||
metadata, err := fixtures.clientRepo.Metadata(r.ClientID)
|
||||
metadata, err := fixtures.clientRepo.Metadata(nil, r.ClientID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to lookup client id after creation")
|
||||
}
|
||||
|
|
|
@ -41,7 +41,7 @@ func (c *clientResource) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func (c *clientResource) list(w http.ResponseWriter, r *http.Request) {
|
||||
cs, err := c.repo.All()
|
||||
cs, err := c.repo.All(nil)
|
||||
if err != nil {
|
||||
writeAPIError(w, http.StatusInternalServerError, newAPIError(errorServerError, "error listing clients"))
|
||||
return
|
||||
|
@ -97,7 +97,7 @@ func (c *clientResource) create(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
ci.Credentials.ID = clientID
|
||||
creds, err := c.repo.New(ci)
|
||||
creds, err := c.repo.New(nil, ci)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("Failed creating client: %v", err)
|
||||
|
|
|
@ -57,7 +57,7 @@ func handleVerifyEmailResendFunc(
|
|||
return
|
||||
}
|
||||
|
||||
cm, err := clientRepo.Metadata(clientID)
|
||||
cm, err := clientRepo.Metadata(nil, clientID)
|
||||
if err == client.ErrorNotFound {
|
||||
log.Errorf("No such client: %v", err)
|
||||
writeAPIError(w, http.StatusBadRequest,
|
||||
|
|
|
@ -128,7 +128,7 @@ func (h *SendResetPasswordEmailHandler) validateRedirectURL(clientID string, red
|
|||
return url.URL{}, false
|
||||
}
|
||||
|
||||
cm, err := h.cr.Metadata(clientID)
|
||||
cm, err := h.cr.Metadata(nil, clientID)
|
||||
if err != nil || cm == nil {
|
||||
log.Errorf("Error getting ClientMetadata: %v", err)
|
||||
return url.URL{}, false
|
||||
|
|
|
@ -278,7 +278,7 @@ func (s *Server) NewClientTokenAuthHandler(handler http.Handler) http.Handler {
|
|||
}
|
||||
|
||||
func (s *Server) ClientMetadata(clientID string) (*oidc.ClientMetadata, error) {
|
||||
return s.ClientRepo.Metadata(clientID)
|
||||
return s.ClientRepo.Metadata(nil, clientID)
|
||||
}
|
||||
|
||||
func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
|
||||
|
@ -365,7 +365,7 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
|
|||
}
|
||||
|
||||
func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) {
|
||||
ok, err := s.ClientRepo.Authenticate(creds)
|
||||
ok, err := s.ClientRepo.Authenticate(nil, creds)
|
||||
if err != nil {
|
||||
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
|
||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||
|
@ -397,7 +397,7 @@ func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, erro
|
|||
}
|
||||
|
||||
func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) {
|
||||
ok, err := s.ClientRepo.Authenticate(creds)
|
||||
ok, err := s.ClientRepo.Authenticate(nil, creds)
|
||||
if err != nil {
|
||||
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
|
||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
||||
|
@ -466,7 +466,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
|
|||
}
|
||||
|
||||
func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) {
|
||||
ok, err := s.ClientRepo.Authenticate(creds)
|
||||
ok, err := s.ClientRepo.Authenticate(nil, creds)
|
||||
if err != nil {
|
||||
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
|
||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||
|
|
|
@ -157,7 +157,7 @@ func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (s
|
|||
return schema.UserCreateResponse{}, mapError(err)
|
||||
}
|
||||
|
||||
metadata, err := u.clientRepo.Metadata(creds.ClientID)
|
||||
metadata, err := u.clientRepo.Metadata(nil, creds.ClientID)
|
||||
if err != nil {
|
||||
return schema.UserCreateResponse{}, mapError(err)
|
||||
}
|
||||
|
@ -202,7 +202,7 @@ func (u *UsersAPI) ResendEmailInvitation(creds Creds, userID string, redirURL ur
|
|||
return schema.ResendEmailInvitationResponse{}, ErrorUnauthorized
|
||||
}
|
||||
|
||||
metadata, err := u.clientRepo.Metadata(creds.ClientID)
|
||||
metadata, err := u.clientRepo.Metadata(nil, creds.ClientID)
|
||||
if err != nil {
|
||||
return schema.ResendEmailInvitationResponse{}, mapError(err)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue