diff --git a/db/client.go b/db/client.go index eee4e755..6754368a 100644 --- a/db/client.go +++ b/db/client.go @@ -10,8 +10,6 @@ import ( "github.com/coreos/go-oidc/oidc" "github.com/go-gorp/gorp" - "github.com/lib/pq" - "github.com/mattn/go-sqlite3" "golang.org/x/crypto/bcrypt" "github.com/coreos/dex/client" @@ -217,6 +215,25 @@ func (r *clientIdentityRepo) Authenticate(creds oidc.ClientCredentials) (bool, e return ok, nil } +var alreadyExistsCheckers []func(err error) bool + +func registerAlreadyExistsChecker(f func(err error) bool) { + alreadyExistsCheckers = append(alreadyExistsCheckers, f) +} + +// isAlreadyExistsErr detects database error codes for failing a unique constraint. +// +// Because database drivers are optionally compiled, use registerAlreadyExistsChecker to +// register driver specific implementations. +func isAlreadyExistsErr(err error) bool { + for _, checker := range alreadyExistsCheckers { + if checker(err) { + return true + } + } + return false +} + func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.ClientCredentials, error) { secret, err := pcrypto.RandBytes(maxSecretLength) if err != nil { @@ -229,17 +246,9 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli } if err := r.executor(nil).Insert(cim); err != nil { - switch sqlErr := err.(type) { - case *pq.Error: - if sqlErr.Code == pgErrorCodeUniqueViolation { - err = errors.New("client ID already exists") - } - case *sqlite3.Error: - if sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique { - err = errors.New("client ID already exists") - } + if isAlreadyExistsErr(err) { + err = errors.New("client ID already exists") } - return nil, err } diff --git a/db/conn.go b/db/conn.go index e6256cc5..d74f7d18 100644 --- a/db/conn.go +++ b/db/conn.go @@ -9,10 +9,6 @@ import ( "github.com/go-gorp/gorp" "github.com/coreos/dex/repo" - - // Import database drivers - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" ) type table struct { diff --git a/db/conn_postgres.go b/db/conn_postgres.go new file mode 100644 index 00000000..dd52dae3 --- /dev/null +++ b/db/conn_postgres.go @@ -0,0 +1,15 @@ +package db + +// Register the postgres driver. + +import "github.com/lib/pq" + +func init() { + registerAlreadyExistsChecker(func(err error) bool { + sqlErr, ok := err.(*pq.Error) + if !ok { + return false + } + return sqlErr.Code == pgErrorCodeUniqueViolation + }) +} diff --git a/db/conn_sqlite3.go b/db/conn_sqlite3.go new file mode 100644 index 00000000..5c2d332b --- /dev/null +++ b/db/conn_sqlite3.go @@ -0,0 +1,17 @@ +// +build cgo + +package db + +// Register the sqlite3 driver. + +import "github.com/mattn/go-sqlite3" + +func init() { + registerAlreadyExistsChecker(func(err error) bool { + sqlErr, ok := err.(*sqlite3.Error) + if !ok { + return false + } + return sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique + }) +}