diff --git a/server/api.go b/server/api.go index cbb030c7..25655d68 100644 --- a/server/api.go +++ b/server/api.go @@ -53,8 +53,10 @@ func (d dexAPI) CreateClient(ctx context.Context, req *api.CreateClientReq) (*ap LogoURL: req.Client.LogoUrl, } if err := d.s.CreateClient(c); err != nil { + if err == storage.ErrAlreadyExists { + return &api.CreateClientResp{AlreadyExists: true}, nil + } d.logger.Errorf("api: failed to create client: %v", err) - // TODO(ericchiang): Surface "already exists" errors. return nil, fmt.Errorf("create client: %v", err) } @@ -109,6 +111,9 @@ func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq) UserID: req.Password.UserId, } if err := d.s.CreatePassword(p); err != nil { + if err == storage.ErrAlreadyExists { + return &api.CreatePasswordResp{AlreadyExists: true}, nil + } d.logger.Errorf("api: failed to create password: %v", err) return nil, fmt.Errorf("create password: %v", err) } diff --git a/server/api_test.go b/server/api_test.go index 4ee285f2..0d0381d9 100644 --- a/server/api_test.go +++ b/server/api_test.go @@ -37,10 +37,18 @@ func TestPassword(t *testing.T) { Password: &p, } - if _, err := serv.CreatePassword(ctx, &createReq); err != nil { + if resp, err := serv.CreatePassword(ctx, &createReq); err != nil || resp.AlreadyExists { + if resp.AlreadyExists { + t.Fatalf("Unable to create password since %s already exists", createReq.Password.Email) + } t.Fatalf("Unable to create password: %v", err) } + // Attempt to create a password that already exists. + if resp, _ := serv.CreatePassword(ctx, &createReq); !resp.AlreadyExists { + t.Fatalf("Created password %s twice", createReq.Password.Email) + } + updateReq := api.UpdatePasswordReq{ Email: "test@example.com", NewUsername: "test1", diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 01c62865..cd2efebb 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -70,6 +70,15 @@ func mustBeErrNotFound(t *testing.T, kind string, err error) { } } +func mustBeErrAlreadyExists(t *testing.T, kind string, err error) { + switch { + case err == nil: + t.Errorf("attempting to create an existing %s should return an error", kind) + case err != storage.ErrAlreadyExists: + t.Errorf("creating an existing %s expected storage.ErrAlreadyExists, got %v", kind, err) + } +} + func testAuthRequestCRUD(t *testing.T, s storage.Storage) { a := storage.AuthRequest{ ID: storage.NewID(), @@ -98,6 +107,11 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) { if err := s.CreateAuthRequest(a); err != nil { t.Fatalf("failed creating auth request: %v", err) } + + // Attempt to create same AuthRequest twice. + err := s.CreateAuthRequest(a) + mustBeErrAlreadyExists(t, "auth request", err) + if err := s.UpdateAuthRequest(a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) { old.Claims = identity old.ConnectorID = "connID" @@ -138,6 +152,10 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) { t.Fatalf("failed creating auth code: %v", err) } + // Attempt to create same AuthCode twice. + err := s.CreateAuthCode(a) + mustBeErrAlreadyExists(t, "auth code", err) + got, err := s.GetAuthCode(a.ID) if err != nil { t.Fatalf("failed to get auth req: %v", err) @@ -174,6 +192,10 @@ func testClientCRUD(t *testing.T, s storage.Storage) { t.Fatalf("create client: %v", err) } + // Attempt to create same Client twice. + err = s.CreateClient(c) + mustBeErrAlreadyExists(t, "client", err) + getAndCompare := func(id string, want storage.Client) { gc, err := s.GetClient(id) if err != nil { @@ -230,6 +252,10 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { t.Fatalf("create refresh token: %v", err) } + // Attempt to create same Refresh Token twice. + err := s.CreateRefresh(refresh) + mustBeErrAlreadyExists(t, "refresh token", err) + getAndCompare := func(id string, want storage.RefreshToken) { gr, err := s.GetRefresh(id) if err != nil { @@ -261,9 +287,8 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { t.Fatalf("failed to delete refresh request: %v", err) } - if _, err := s.GetRefresh(id); err != storage.ErrNotFound { - t.Errorf("after deleting refresh expected storage.ErrNotFound, got %v", err) - } + _, err = s.GetRefresh(id) + mustBeErrNotFound(t, "refresh token", err) } type byEmail []storage.Password @@ -289,6 +314,10 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) { t.Fatalf("create password token: %v", err) } + // Attempt to create same Password twice. + err = s.CreatePassword(password) + mustBeErrAlreadyExists(t, "password", err) + getAndCompare := func(id string, want storage.Password) { gr, err := s.GetPassword(id) if err != nil { @@ -335,9 +364,8 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) { t.Fatalf("failed to delete password: %v", err) } - if _, err := s.GetPassword(password.Email); err != storage.ErrNotFound { - t.Errorf("after deleting password expected storage.ErrNotFound, got %v", err) - } + _, err = s.GetPassword(password.Email) + mustBeErrNotFound(t, "password", err) } @@ -354,6 +382,10 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { t.Fatalf("create offline session: %v", err) } + // Attempt to create same OfflineSession twice. + err := s.CreateOfflineSessions(session) + mustBeErrAlreadyExists(t, "offline session", err) + getAndCompare := func(userID string, connID string, want storage.OfflineSessions) { gr, err := s.GetOfflineSessions(userID, connID) if err != nil { @@ -389,9 +421,8 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { t.Fatalf("failed to delete offline session: %v", err) } - if _, err := s.GetOfflineSessions(session.UserID, session.ConnID); err != storage.ErrNotFound { - t.Errorf("after deleting offline session expected storage.ErrNotFound, got %v", err) - } + _, err = s.GetOfflineSessions(session.UserID, session.ConnID) + mustBeErrNotFound(t, "offline session", err) } diff --git a/storage/sql/config.go b/storage/sql/config.go index d77e5481..56ee0c1a 100644 --- a/storage/sql/config.go +++ b/storage/sql/config.go @@ -8,6 +8,13 @@ import ( "github.com/Sirupsen/logrus" "github.com/coreos/dex/storage" + "github.com/lib/pq" + sqlite3 "github.com/mattn/go-sqlite3" +) + +const ( + // postgres error codes + pgErrUniqueViolation = "23505" // unique_violation ) // SQLite3 options for creating an SQL db. @@ -35,7 +42,16 @@ func (s *SQLite3) open(logger logrus.FieldLogger) (*conn, error) { // doesn't support this, so limit the number of connections to 1. db.SetMaxOpenConns(1) } - c := &conn{db, flavorSQLite3, logger} + + errCheck := func(err error) bool { + sqlErr, ok := err.(sqlite3.Error) + if !ok { + return false + } + return sqlErr.ExtendedCode == sqlite3.ErrConstraintPrimaryKey + } + + c := &conn{db, flavorSQLite3, logger, errCheck} if _, err := c.migrate(); err != nil { return nil, fmt.Errorf("failed to perform migrations: %v", err) } @@ -114,7 +130,16 @@ func (p *Postgres) open(logger logrus.FieldLogger) (*conn, error) { if err != nil { return nil, err } - c := &conn{db, flavorPostgres, logger} + + errCheck := func(err error) bool { + sqlErr, ok := err.(*pq.Error) + if !ok { + return false + } + return sqlErr.Code == pgErrUniqueViolation + } + + c := &conn{db, flavorPostgres, logger, errCheck} if _, err := c.migrate(); err != nil { return nil, fmt.Errorf("failed to perform migrations: %v", err) } diff --git a/storage/sql/crud.go b/storage/sql/crud.go index ef1a8fbd..8c00dfd7 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -125,6 +125,9 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { a.Expiry, ) if err != nil { + if c.alreadyExistsCheck(err) { + return storage.ErrAlreadyExists + } return fmt.Errorf("insert auth request: %v", err) } return nil @@ -212,7 +215,14 @@ func (c *conn) CreateAuthCode(a storage.AuthCode) error { a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups), a.ConnectorID, a.ConnectorData, a.Expiry, ) - return err + + if err != nil { + if c.alreadyExistsCheck(err) { + return storage.ErrAlreadyExists + } + return fmt.Errorf("insert auth code: %v", err) + } + return nil } func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) { @@ -256,6 +266,9 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error { r.Token, r.CreatedAt, r.LastUsed, ) if err != nil { + if c.alreadyExistsCheck(err) { + return storage.ErrAlreadyExists + } return fmt.Errorf("insert refresh_token: %v", err) } return nil @@ -477,6 +490,9 @@ func (c *conn) CreateClient(cli storage.Client) error { cli.Public, cli.Name, cli.LogoURL, ) if err != nil { + if c.alreadyExistsCheck(err) { + return storage.ErrAlreadyExists + } return fmt.Errorf("insert client: %v", err) } return nil @@ -544,6 +560,9 @@ func (c *conn) CreatePassword(p storage.Password) error { p.Email, p.Hash, p.Username, p.UserID, ) if err != nil { + if c.alreadyExistsCheck(err) { + return storage.ErrAlreadyExists + } return fmt.Errorf("insert password: %v", err) } return nil @@ -636,6 +655,9 @@ func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error { s.UserID, s.ConnID, encoder(s.Refresh), ) if err != nil { + if c.alreadyExistsCheck(err) { + return storage.ErrAlreadyExists + } return fmt.Errorf("insert offline session: %v", err) } return nil diff --git a/storage/sql/migrate_test.go b/storage/sql/migrate_test.go index d46839f1..0e9e0179 100644 --- a/storage/sql/migrate_test.go +++ b/storage/sql/migrate_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/Sirupsen/logrus" + sqlite3 "github.com/mattn/go-sqlite3" ) func TestMigrate(t *testing.T) { @@ -21,7 +22,15 @@ func TestMigrate(t *testing.T) { Level: logrus.DebugLevel, } - c := &conn{db, flavorSQLite3, logger} + errCheck := func(err error) bool { + sqlErr, ok := err.(sqlite3.Error) + if !ok { + return false + } + return sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique + } + + c := &conn{db, flavorSQLite3, logger, errCheck} for _, want := range []int{len(migrations), 0} { got, err := c.migrate() if err != nil { diff --git a/storage/sql/sql.go b/storage/sql/sql.go index 0d3026e4..2c6b74bc 100644 --- a/storage/sql/sql.go +++ b/storage/sql/sql.go @@ -131,9 +131,10 @@ func (c *conn) translateArgs(args []interface{}) []interface{} { // conn is the main database connection. type conn struct { - db *sql.DB - flavor flavor - logger logrus.FieldLogger + db *sql.DB + flavor flavor + logger logrus.FieldLogger + alreadyExistsCheck func(err error) bool } func (c *conn) Close() error {