db: protect the sqlite3 import with a cgo tag

This commit is contained in:
Eric Chiang 2016-03-02 12:02:55 -08:00
parent 350571acf6
commit 93b89ad0e9
4 changed files with 53 additions and 16 deletions

View file

@ -10,8 +10,6 @@ import (
"github.com/coreos/go-oidc/oidc"
"github.com/go-gorp/gorp"
"github.com/lib/pq"
"github.com/mattn/go-sqlite3"
"golang.org/x/crypto/bcrypt"
"github.com/coreos/dex/client"
@ -217,6 +215,25 @@ func (r *clientIdentityRepo) Authenticate(creds oidc.ClientCredentials) (bool, e
return ok, nil
}
var alreadyExistsCheckers []func(err error) bool
func registerAlreadyExistsChecker(f func(err error) bool) {
alreadyExistsCheckers = append(alreadyExistsCheckers, f)
}
// isAlreadyExistsErr detects database error codes for failing a unique constraint.
//
// Because database drivers are optionally compiled, use registerAlreadyExistsChecker to
// register driver specific implementations.
func isAlreadyExistsErr(err error) bool {
for _, checker := range alreadyExistsCheckers {
if checker(err) {
return true
}
}
return false
}
func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.ClientCredentials, error) {
secret, err := pcrypto.RandBytes(maxSecretLength)
if err != nil {
@ -229,17 +246,9 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli
}
if err := r.executor(nil).Insert(cim); err != nil {
switch sqlErr := err.(type) {
case *pq.Error:
if sqlErr.Code == pgErrorCodeUniqueViolation {
if isAlreadyExistsErr(err) {
err = errors.New("client ID already exists")
}
case *sqlite3.Error:
if sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique {
err = errors.New("client ID already exists")
}
}
return nil, err
}

View file

@ -9,10 +9,6 @@ import (
"github.com/go-gorp/gorp"
"github.com/coreos/dex/repo"
// Import database drivers
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
type table struct {

15
db/conn_postgres.go Normal file
View file

@ -0,0 +1,15 @@
package db
// Register the postgres driver.
import "github.com/lib/pq"
func init() {
registerAlreadyExistsChecker(func(err error) bool {
sqlErr, ok := err.(*pq.Error)
if !ok {
return false
}
return sqlErr.Code == pgErrorCodeUniqueViolation
})
}

17
db/conn_sqlite3.go Normal file
View file

@ -0,0 +1,17 @@
// +build cgo
package db
// Register the sqlite3 driver.
import "github.com/mattn/go-sqlite3"
func init() {
registerAlreadyExistsChecker(func(err error) bool {
sqlErr, ok := err.(*sqlite3.Error)
if !ok {
return false
}
return sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique
})
}