forked from mystiq/dex
Merge pull request #809 from rithujohn191/set-error-flag
storage: Surface "already exists" errors.
This commit is contained in:
commit
c76832eaea
7 changed files with 119 additions and 18 deletions
|
@ -53,8 +53,10 @@ func (d dexAPI) CreateClient(ctx context.Context, req *api.CreateClientReq) (*ap
|
|||
LogoURL: req.Client.LogoUrl,
|
||||
}
|
||||
if err := d.s.CreateClient(c); err != nil {
|
||||
if err == storage.ErrAlreadyExists {
|
||||
return &api.CreateClientResp{AlreadyExists: true}, nil
|
||||
}
|
||||
d.logger.Errorf("api: failed to create client: %v", err)
|
||||
// TODO(ericchiang): Surface "already exists" errors.
|
||||
return nil, fmt.Errorf("create client: %v", err)
|
||||
}
|
||||
|
||||
|
@ -109,6 +111,9 @@ func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq)
|
|||
UserID: req.Password.UserId,
|
||||
}
|
||||
if err := d.s.CreatePassword(p); err != nil {
|
||||
if err == storage.ErrAlreadyExists {
|
||||
return &api.CreatePasswordResp{AlreadyExists: true}, nil
|
||||
}
|
||||
d.logger.Errorf("api: failed to create password: %v", err)
|
||||
return nil, fmt.Errorf("create password: %v", err)
|
||||
}
|
||||
|
|
|
@ -37,10 +37,18 @@ func TestPassword(t *testing.T) {
|
|||
Password: &p,
|
||||
}
|
||||
|
||||
if _, err := serv.CreatePassword(ctx, &createReq); err != nil {
|
||||
if resp, err := serv.CreatePassword(ctx, &createReq); err != nil || resp.AlreadyExists {
|
||||
if resp.AlreadyExists {
|
||||
t.Fatalf("Unable to create password since %s already exists", createReq.Password.Email)
|
||||
}
|
||||
t.Fatalf("Unable to create password: %v", err)
|
||||
}
|
||||
|
||||
// Attempt to create a password that already exists.
|
||||
if resp, _ := serv.CreatePassword(ctx, &createReq); !resp.AlreadyExists {
|
||||
t.Fatalf("Created password %s twice", createReq.Password.Email)
|
||||
}
|
||||
|
||||
updateReq := api.UpdatePasswordReq{
|
||||
Email: "test@example.com",
|
||||
NewUsername: "test1",
|
||||
|
|
|
@ -70,6 +70,15 @@ func mustBeErrNotFound(t *testing.T, kind string, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
func mustBeErrAlreadyExists(t *testing.T, kind string, err error) {
|
||||
switch {
|
||||
case err == nil:
|
||||
t.Errorf("attempting to create an existing %s should return an error", kind)
|
||||
case err != storage.ErrAlreadyExists:
|
||||
t.Errorf("creating an existing %s expected storage.ErrAlreadyExists, got %v", kind, err)
|
||||
}
|
||||
}
|
||||
|
||||
func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
|
||||
a := storage.AuthRequest{
|
||||
ID: storage.NewID(),
|
||||
|
@ -98,6 +107,11 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
|
|||
if err := s.CreateAuthRequest(a); err != nil {
|
||||
t.Fatalf("failed creating auth request: %v", err)
|
||||
}
|
||||
|
||||
// Attempt to create same AuthRequest twice.
|
||||
err := s.CreateAuthRequest(a)
|
||||
mustBeErrAlreadyExists(t, "auth request", err)
|
||||
|
||||
if err := s.UpdateAuthRequest(a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
|
||||
old.Claims = identity
|
||||
old.ConnectorID = "connID"
|
||||
|
@ -138,6 +152,10 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) {
|
|||
t.Fatalf("failed creating auth code: %v", err)
|
||||
}
|
||||
|
||||
// Attempt to create same AuthCode twice.
|
||||
err := s.CreateAuthCode(a)
|
||||
mustBeErrAlreadyExists(t, "auth code", err)
|
||||
|
||||
got, err := s.GetAuthCode(a.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get auth req: %v", err)
|
||||
|
@ -174,6 +192,10 @@ func testClientCRUD(t *testing.T, s storage.Storage) {
|
|||
t.Fatalf("create client: %v", err)
|
||||
}
|
||||
|
||||
// Attempt to create same Client twice.
|
||||
err = s.CreateClient(c)
|
||||
mustBeErrAlreadyExists(t, "client", err)
|
||||
|
||||
getAndCompare := func(id string, want storage.Client) {
|
||||
gc, err := s.GetClient(id)
|
||||
if err != nil {
|
||||
|
@ -230,6 +252,10 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
|
|||
t.Fatalf("create refresh token: %v", err)
|
||||
}
|
||||
|
||||
// Attempt to create same Refresh Token twice.
|
||||
err := s.CreateRefresh(refresh)
|
||||
mustBeErrAlreadyExists(t, "refresh token", err)
|
||||
|
||||
getAndCompare := func(id string, want storage.RefreshToken) {
|
||||
gr, err := s.GetRefresh(id)
|
||||
if err != nil {
|
||||
|
@ -261,9 +287,8 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
|
|||
t.Fatalf("failed to delete refresh request: %v", err)
|
||||
}
|
||||
|
||||
if _, err := s.GetRefresh(id); err != storage.ErrNotFound {
|
||||
t.Errorf("after deleting refresh expected storage.ErrNotFound, got %v", err)
|
||||
}
|
||||
_, err = s.GetRefresh(id)
|
||||
mustBeErrNotFound(t, "refresh token", err)
|
||||
}
|
||||
|
||||
type byEmail []storage.Password
|
||||
|
@ -289,6 +314,10 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
|
|||
t.Fatalf("create password token: %v", err)
|
||||
}
|
||||
|
||||
// Attempt to create same Password twice.
|
||||
err = s.CreatePassword(password)
|
||||
mustBeErrAlreadyExists(t, "password", err)
|
||||
|
||||
getAndCompare := func(id string, want storage.Password) {
|
||||
gr, err := s.GetPassword(id)
|
||||
if err != nil {
|
||||
|
@ -335,9 +364,8 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
|
|||
t.Fatalf("failed to delete password: %v", err)
|
||||
}
|
||||
|
||||
if _, err := s.GetPassword(password.Email); err != storage.ErrNotFound {
|
||||
t.Errorf("after deleting password expected storage.ErrNotFound, got %v", err)
|
||||
}
|
||||
_, err = s.GetPassword(password.Email)
|
||||
mustBeErrNotFound(t, "password", err)
|
||||
|
||||
}
|
||||
|
||||
|
@ -354,6 +382,10 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
|
|||
t.Fatalf("create offline session: %v", err)
|
||||
}
|
||||
|
||||
// Attempt to create same OfflineSession twice.
|
||||
err := s.CreateOfflineSessions(session)
|
||||
mustBeErrAlreadyExists(t, "offline session", err)
|
||||
|
||||
getAndCompare := func(userID string, connID string, want storage.OfflineSessions) {
|
||||
gr, err := s.GetOfflineSessions(userID, connID)
|
||||
if err != nil {
|
||||
|
@ -389,9 +421,8 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
|
|||
t.Fatalf("failed to delete offline session: %v", err)
|
||||
}
|
||||
|
||||
if _, err := s.GetOfflineSessions(session.UserID, session.ConnID); err != storage.ErrNotFound {
|
||||
t.Errorf("after deleting offline session expected storage.ErrNotFound, got %v", err)
|
||||
}
|
||||
_, err = s.GetOfflineSessions(session.UserID, session.ConnID)
|
||||
mustBeErrNotFound(t, "offline session", err)
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -8,6 +8,13 @@ import (
|
|||
|
||||
"github.com/Sirupsen/logrus"
|
||||
"github.com/coreos/dex/storage"
|
||||
"github.com/lib/pq"
|
||||
sqlite3 "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
const (
|
||||
// postgres error codes
|
||||
pgErrUniqueViolation = "23505" // unique_violation
|
||||
)
|
||||
|
||||
// SQLite3 options for creating an SQL db.
|
||||
|
@ -35,7 +42,16 @@ func (s *SQLite3) open(logger logrus.FieldLogger) (*conn, error) {
|
|||
// doesn't support this, so limit the number of connections to 1.
|
||||
db.SetMaxOpenConns(1)
|
||||
}
|
||||
c := &conn{db, flavorSQLite3, logger}
|
||||
|
||||
errCheck := func(err error) bool {
|
||||
sqlErr, ok := err.(sqlite3.Error)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return sqlErr.ExtendedCode == sqlite3.ErrConstraintPrimaryKey
|
||||
}
|
||||
|
||||
c := &conn{db, flavorSQLite3, logger, errCheck}
|
||||
if _, err := c.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("failed to perform migrations: %v", err)
|
||||
}
|
||||
|
@ -114,7 +130,16 @@ func (p *Postgres) open(logger logrus.FieldLogger) (*conn, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c := &conn{db, flavorPostgres, logger}
|
||||
|
||||
errCheck := func(err error) bool {
|
||||
sqlErr, ok := err.(*pq.Error)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return sqlErr.Code == pgErrUniqueViolation
|
||||
}
|
||||
|
||||
c := &conn{db, flavorPostgres, logger, errCheck}
|
||||
if _, err := c.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("failed to perform migrations: %v", err)
|
||||
}
|
||||
|
|
|
@ -125,6 +125,9 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
|
|||
a.Expiry,
|
||||
)
|
||||
if err != nil {
|
||||
if c.alreadyExistsCheck(err) {
|
||||
return storage.ErrAlreadyExists
|
||||
}
|
||||
return fmt.Errorf("insert auth request: %v", err)
|
||||
}
|
||||
return nil
|
||||
|
@ -212,7 +215,14 @@ func (c *conn) CreateAuthCode(a storage.AuthCode) error {
|
|||
a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups),
|
||||
a.ConnectorID, a.ConnectorData, a.Expiry,
|
||||
)
|
||||
return err
|
||||
|
||||
if err != nil {
|
||||
if c.alreadyExistsCheck(err) {
|
||||
return storage.ErrAlreadyExists
|
||||
}
|
||||
return fmt.Errorf("insert auth code: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
|
||||
|
@ -256,6 +266,9 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
|
|||
r.Token, r.CreatedAt, r.LastUsed,
|
||||
)
|
||||
if err != nil {
|
||||
if c.alreadyExistsCheck(err) {
|
||||
return storage.ErrAlreadyExists
|
||||
}
|
||||
return fmt.Errorf("insert refresh_token: %v", err)
|
||||
}
|
||||
return nil
|
||||
|
@ -477,6 +490,9 @@ func (c *conn) CreateClient(cli storage.Client) error {
|
|||
cli.Public, cli.Name, cli.LogoURL,
|
||||
)
|
||||
if err != nil {
|
||||
if c.alreadyExistsCheck(err) {
|
||||
return storage.ErrAlreadyExists
|
||||
}
|
||||
return fmt.Errorf("insert client: %v", err)
|
||||
}
|
||||
return nil
|
||||
|
@ -544,6 +560,9 @@ func (c *conn) CreatePassword(p storage.Password) error {
|
|||
p.Email, p.Hash, p.Username, p.UserID,
|
||||
)
|
||||
if err != nil {
|
||||
if c.alreadyExistsCheck(err) {
|
||||
return storage.ErrAlreadyExists
|
||||
}
|
||||
return fmt.Errorf("insert password: %v", err)
|
||||
}
|
||||
return nil
|
||||
|
@ -636,6 +655,9 @@ func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
|
|||
s.UserID, s.ConnID, encoder(s.Refresh),
|
||||
)
|
||||
if err != nil {
|
||||
if c.alreadyExistsCheck(err) {
|
||||
return storage.ErrAlreadyExists
|
||||
}
|
||||
return fmt.Errorf("insert offline session: %v", err)
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/Sirupsen/logrus"
|
||||
sqlite3 "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
|
@ -21,7 +22,15 @@ func TestMigrate(t *testing.T) {
|
|||
Level: logrus.DebugLevel,
|
||||
}
|
||||
|
||||
c := &conn{db, flavorSQLite3, logger}
|
||||
errCheck := func(err error) bool {
|
||||
sqlErr, ok := err.(sqlite3.Error)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique
|
||||
}
|
||||
|
||||
c := &conn{db, flavorSQLite3, logger, errCheck}
|
||||
for _, want := range []int{len(migrations), 0} {
|
||||
got, err := c.migrate()
|
||||
if err != nil {
|
||||
|
|
|
@ -131,9 +131,10 @@ func (c *conn) translateArgs(args []interface{}) []interface{} {
|
|||
|
||||
// conn is the main database connection.
|
||||
type conn struct {
|
||||
db *sql.DB
|
||||
flavor flavor
|
||||
logger logrus.FieldLogger
|
||||
db *sql.DB
|
||||
flavor flavor
|
||||
logger logrus.FieldLogger
|
||||
alreadyExistsCheck func(err error) bool
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
|
|
Loading…
Reference in a new issue