forked from mystiq/dex
Merge pull request #809 from rithujohn191/set-error-flag
storage: Surface "already exists" errors.
This commit is contained in:
commit
c76832eaea
7 changed files with 119 additions and 18 deletions
|
@ -53,8 +53,10 @@ func (d dexAPI) CreateClient(ctx context.Context, req *api.CreateClientReq) (*ap
|
||||||
LogoURL: req.Client.LogoUrl,
|
LogoURL: req.Client.LogoUrl,
|
||||||
}
|
}
|
||||||
if err := d.s.CreateClient(c); err != nil {
|
if err := d.s.CreateClient(c); err != nil {
|
||||||
|
if err == storage.ErrAlreadyExists {
|
||||||
|
return &api.CreateClientResp{AlreadyExists: true}, nil
|
||||||
|
}
|
||||||
d.logger.Errorf("api: failed to create client: %v", err)
|
d.logger.Errorf("api: failed to create client: %v", err)
|
||||||
// TODO(ericchiang): Surface "already exists" errors.
|
|
||||||
return nil, fmt.Errorf("create client: %v", err)
|
return nil, fmt.Errorf("create client: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,6 +111,9 @@ func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq)
|
||||||
UserID: req.Password.UserId,
|
UserID: req.Password.UserId,
|
||||||
}
|
}
|
||||||
if err := d.s.CreatePassword(p); err != nil {
|
if err := d.s.CreatePassword(p); err != nil {
|
||||||
|
if err == storage.ErrAlreadyExists {
|
||||||
|
return &api.CreatePasswordResp{AlreadyExists: true}, nil
|
||||||
|
}
|
||||||
d.logger.Errorf("api: failed to create password: %v", err)
|
d.logger.Errorf("api: failed to create password: %v", err)
|
||||||
return nil, fmt.Errorf("create password: %v", err)
|
return nil, fmt.Errorf("create password: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,10 +37,18 @@ func TestPassword(t *testing.T) {
|
||||||
Password: &p,
|
Password: &p,
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := serv.CreatePassword(ctx, &createReq); err != nil {
|
if resp, err := serv.CreatePassword(ctx, &createReq); err != nil || resp.AlreadyExists {
|
||||||
|
if resp.AlreadyExists {
|
||||||
|
t.Fatalf("Unable to create password since %s already exists", createReq.Password.Email)
|
||||||
|
}
|
||||||
t.Fatalf("Unable to create password: %v", err)
|
t.Fatalf("Unable to create password: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Attempt to create a password that already exists.
|
||||||
|
if resp, _ := serv.CreatePassword(ctx, &createReq); !resp.AlreadyExists {
|
||||||
|
t.Fatalf("Created password %s twice", createReq.Password.Email)
|
||||||
|
}
|
||||||
|
|
||||||
updateReq := api.UpdatePasswordReq{
|
updateReq := api.UpdatePasswordReq{
|
||||||
Email: "test@example.com",
|
Email: "test@example.com",
|
||||||
NewUsername: "test1",
|
NewUsername: "test1",
|
||||||
|
|
|
@ -70,6 +70,15 @@ func mustBeErrNotFound(t *testing.T, kind string, err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mustBeErrAlreadyExists(t *testing.T, kind string, err error) {
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
t.Errorf("attempting to create an existing %s should return an error", kind)
|
||||||
|
case err != storage.ErrAlreadyExists:
|
||||||
|
t.Errorf("creating an existing %s expected storage.ErrAlreadyExists, got %v", kind, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
|
func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
|
||||||
a := storage.AuthRequest{
|
a := storage.AuthRequest{
|
||||||
ID: storage.NewID(),
|
ID: storage.NewID(),
|
||||||
|
@ -98,6 +107,11 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
|
||||||
if err := s.CreateAuthRequest(a); err != nil {
|
if err := s.CreateAuthRequest(a); err != nil {
|
||||||
t.Fatalf("failed creating auth request: %v", err)
|
t.Fatalf("failed creating auth request: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Attempt to create same AuthRequest twice.
|
||||||
|
err := s.CreateAuthRequest(a)
|
||||||
|
mustBeErrAlreadyExists(t, "auth request", err)
|
||||||
|
|
||||||
if err := s.UpdateAuthRequest(a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
|
if err := s.UpdateAuthRequest(a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
|
||||||
old.Claims = identity
|
old.Claims = identity
|
||||||
old.ConnectorID = "connID"
|
old.ConnectorID = "connID"
|
||||||
|
@ -138,6 +152,10 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) {
|
||||||
t.Fatalf("failed creating auth code: %v", err)
|
t.Fatalf("failed creating auth code: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Attempt to create same AuthCode twice.
|
||||||
|
err := s.CreateAuthCode(a)
|
||||||
|
mustBeErrAlreadyExists(t, "auth code", err)
|
||||||
|
|
||||||
got, err := s.GetAuthCode(a.ID)
|
got, err := s.GetAuthCode(a.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to get auth req: %v", err)
|
t.Fatalf("failed to get auth req: %v", err)
|
||||||
|
@ -174,6 +192,10 @@ func testClientCRUD(t *testing.T, s storage.Storage) {
|
||||||
t.Fatalf("create client: %v", err)
|
t.Fatalf("create client: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Attempt to create same Client twice.
|
||||||
|
err = s.CreateClient(c)
|
||||||
|
mustBeErrAlreadyExists(t, "client", err)
|
||||||
|
|
||||||
getAndCompare := func(id string, want storage.Client) {
|
getAndCompare := func(id string, want storage.Client) {
|
||||||
gc, err := s.GetClient(id)
|
gc, err := s.GetClient(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -230,6 +252,10 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
|
||||||
t.Fatalf("create refresh token: %v", err)
|
t.Fatalf("create refresh token: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Attempt to create same Refresh Token twice.
|
||||||
|
err := s.CreateRefresh(refresh)
|
||||||
|
mustBeErrAlreadyExists(t, "refresh token", err)
|
||||||
|
|
||||||
getAndCompare := func(id string, want storage.RefreshToken) {
|
getAndCompare := func(id string, want storage.RefreshToken) {
|
||||||
gr, err := s.GetRefresh(id)
|
gr, err := s.GetRefresh(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -261,9 +287,8 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
|
||||||
t.Fatalf("failed to delete refresh request: %v", err)
|
t.Fatalf("failed to delete refresh request: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := s.GetRefresh(id); err != storage.ErrNotFound {
|
_, err = s.GetRefresh(id)
|
||||||
t.Errorf("after deleting refresh expected storage.ErrNotFound, got %v", err)
|
mustBeErrNotFound(t, "refresh token", err)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type byEmail []storage.Password
|
type byEmail []storage.Password
|
||||||
|
@ -289,6 +314,10 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
|
||||||
t.Fatalf("create password token: %v", err)
|
t.Fatalf("create password token: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Attempt to create same Password twice.
|
||||||
|
err = s.CreatePassword(password)
|
||||||
|
mustBeErrAlreadyExists(t, "password", err)
|
||||||
|
|
||||||
getAndCompare := func(id string, want storage.Password) {
|
getAndCompare := func(id string, want storage.Password) {
|
||||||
gr, err := s.GetPassword(id)
|
gr, err := s.GetPassword(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -335,9 +364,8 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
|
||||||
t.Fatalf("failed to delete password: %v", err)
|
t.Fatalf("failed to delete password: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := s.GetPassword(password.Email); err != storage.ErrNotFound {
|
_, err = s.GetPassword(password.Email)
|
||||||
t.Errorf("after deleting password expected storage.ErrNotFound, got %v", err)
|
mustBeErrNotFound(t, "password", err)
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -354,6 +382,10 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
|
||||||
t.Fatalf("create offline session: %v", err)
|
t.Fatalf("create offline session: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Attempt to create same OfflineSession twice.
|
||||||
|
err := s.CreateOfflineSessions(session)
|
||||||
|
mustBeErrAlreadyExists(t, "offline session", err)
|
||||||
|
|
||||||
getAndCompare := func(userID string, connID string, want storage.OfflineSessions) {
|
getAndCompare := func(userID string, connID string, want storage.OfflineSessions) {
|
||||||
gr, err := s.GetOfflineSessions(userID, connID)
|
gr, err := s.GetOfflineSessions(userID, connID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -389,9 +421,8 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
|
||||||
t.Fatalf("failed to delete offline session: %v", err)
|
t.Fatalf("failed to delete offline session: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := s.GetOfflineSessions(session.UserID, session.ConnID); err != storage.ErrNotFound {
|
_, err = s.GetOfflineSessions(session.UserID, session.ConnID)
|
||||||
t.Errorf("after deleting offline session expected storage.ErrNotFound, got %v", err)
|
mustBeErrNotFound(t, "offline session", err)
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,13 @@ import (
|
||||||
|
|
||||||
"github.com/Sirupsen/logrus"
|
"github.com/Sirupsen/logrus"
|
||||||
"github.com/coreos/dex/storage"
|
"github.com/coreos/dex/storage"
|
||||||
|
"github.com/lib/pq"
|
||||||
|
sqlite3 "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// postgres error codes
|
||||||
|
pgErrUniqueViolation = "23505" // unique_violation
|
||||||
)
|
)
|
||||||
|
|
||||||
// SQLite3 options for creating an SQL db.
|
// SQLite3 options for creating an SQL db.
|
||||||
|
@ -35,7 +42,16 @@ func (s *SQLite3) open(logger logrus.FieldLogger) (*conn, error) {
|
||||||
// doesn't support this, so limit the number of connections to 1.
|
// doesn't support this, so limit the number of connections to 1.
|
||||||
db.SetMaxOpenConns(1)
|
db.SetMaxOpenConns(1)
|
||||||
}
|
}
|
||||||
c := &conn{db, flavorSQLite3, logger}
|
|
||||||
|
errCheck := func(err error) bool {
|
||||||
|
sqlErr, ok := err.(sqlite3.Error)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return sqlErr.ExtendedCode == sqlite3.ErrConstraintPrimaryKey
|
||||||
|
}
|
||||||
|
|
||||||
|
c := &conn{db, flavorSQLite3, logger, errCheck}
|
||||||
if _, err := c.migrate(); err != nil {
|
if _, err := c.migrate(); err != nil {
|
||||||
return nil, fmt.Errorf("failed to perform migrations: %v", err)
|
return nil, fmt.Errorf("failed to perform migrations: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -114,7 +130,16 @@ func (p *Postgres) open(logger logrus.FieldLogger) (*conn, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c := &conn{db, flavorPostgres, logger}
|
|
||||||
|
errCheck := func(err error) bool {
|
||||||
|
sqlErr, ok := err.(*pq.Error)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return sqlErr.Code == pgErrUniqueViolation
|
||||||
|
}
|
||||||
|
|
||||||
|
c := &conn{db, flavorPostgres, logger, errCheck}
|
||||||
if _, err := c.migrate(); err != nil {
|
if _, err := c.migrate(); err != nil {
|
||||||
return nil, fmt.Errorf("failed to perform migrations: %v", err)
|
return nil, fmt.Errorf("failed to perform migrations: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -125,6 +125,9 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
|
||||||
a.Expiry,
|
a.Expiry,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if c.alreadyExistsCheck(err) {
|
||||||
|
return storage.ErrAlreadyExists
|
||||||
|
}
|
||||||
return fmt.Errorf("insert auth request: %v", err)
|
return fmt.Errorf("insert auth request: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -212,7 +215,14 @@ func (c *conn) CreateAuthCode(a storage.AuthCode) error {
|
||||||
a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups),
|
a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups),
|
||||||
a.ConnectorID, a.ConnectorData, a.Expiry,
|
a.ConnectorID, a.ConnectorData, a.Expiry,
|
||||||
)
|
)
|
||||||
return err
|
|
||||||
|
if err != nil {
|
||||||
|
if c.alreadyExistsCheck(err) {
|
||||||
|
return storage.ErrAlreadyExists
|
||||||
|
}
|
||||||
|
return fmt.Errorf("insert auth code: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
|
func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
|
||||||
|
@ -256,6 +266,9 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
|
||||||
r.Token, r.CreatedAt, r.LastUsed,
|
r.Token, r.CreatedAt, r.LastUsed,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if c.alreadyExistsCheck(err) {
|
||||||
|
return storage.ErrAlreadyExists
|
||||||
|
}
|
||||||
return fmt.Errorf("insert refresh_token: %v", err)
|
return fmt.Errorf("insert refresh_token: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -477,6 +490,9 @@ func (c *conn) CreateClient(cli storage.Client) error {
|
||||||
cli.Public, cli.Name, cli.LogoURL,
|
cli.Public, cli.Name, cli.LogoURL,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if c.alreadyExistsCheck(err) {
|
||||||
|
return storage.ErrAlreadyExists
|
||||||
|
}
|
||||||
return fmt.Errorf("insert client: %v", err)
|
return fmt.Errorf("insert client: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -544,6 +560,9 @@ func (c *conn) CreatePassword(p storage.Password) error {
|
||||||
p.Email, p.Hash, p.Username, p.UserID,
|
p.Email, p.Hash, p.Username, p.UserID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if c.alreadyExistsCheck(err) {
|
||||||
|
return storage.ErrAlreadyExists
|
||||||
|
}
|
||||||
return fmt.Errorf("insert password: %v", err)
|
return fmt.Errorf("insert password: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -636,6 +655,9 @@ func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
|
||||||
s.UserID, s.ConnID, encoder(s.Refresh),
|
s.UserID, s.ConnID, encoder(s.Refresh),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if c.alreadyExistsCheck(err) {
|
||||||
|
return storage.ErrAlreadyExists
|
||||||
|
}
|
||||||
return fmt.Errorf("insert offline session: %v", err)
|
return fmt.Errorf("insert offline session: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/Sirupsen/logrus"
|
"github.com/Sirupsen/logrus"
|
||||||
|
sqlite3 "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMigrate(t *testing.T) {
|
func TestMigrate(t *testing.T) {
|
||||||
|
@ -21,7 +22,15 @@ func TestMigrate(t *testing.T) {
|
||||||
Level: logrus.DebugLevel,
|
Level: logrus.DebugLevel,
|
||||||
}
|
}
|
||||||
|
|
||||||
c := &conn{db, flavorSQLite3, logger}
|
errCheck := func(err error) bool {
|
||||||
|
sqlErr, ok := err.(sqlite3.Error)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique
|
||||||
|
}
|
||||||
|
|
||||||
|
c := &conn{db, flavorSQLite3, logger, errCheck}
|
||||||
for _, want := range []int{len(migrations), 0} {
|
for _, want := range []int{len(migrations), 0} {
|
||||||
got, err := c.migrate()
|
got, err := c.migrate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -134,6 +134,7 @@ type conn struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
flavor flavor
|
flavor flavor
|
||||||
logger logrus.FieldLogger
|
logger logrus.FieldLogger
|
||||||
|
alreadyExistsCheck func(err error) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) Close() error {
|
func (c *conn) Close() error {
|
||||||
|
|
Loading…
Reference in a new issue