dex/db/user.go
2016-03-02 14:47:00 -08:00

482 lines
10 KiB
Go

package db
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
"time"
"github.com/go-gorp/gorp"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo"
"github.com/coreos/dex/user"
)
// Names of the database tables this file registers in init.
const (
// userTableName is the table that stores userModel rows.
//
// This table is named authd_user for historical reasons; namely, that the
// original name of the project was authd, and there are existing tables out
// there that we don't want to have to rename in production.
userTableName = "authd_user"
// remoteIdentityMappingTableName is the table that stores
// remoteIdentityMappingModel rows.
remoteIdentityMappingTableName = "remote_identity_mapping"
)
func init() {
register(table{
name: userTableName,
model: userModel{},
autoinc: false,
pkey: []string{"id"},
unique: []string{"email"},
})
register(table{
name: remoteIdentityMappingTableName,
model: remoteIdentityMappingModel{},
autoinc: false,
pkey: []string{"connector_id", "remote_id"},
})
}
// NewUserRepo returns a user.UserRepo that is backed by the given gorp DbMap.
func NewUserRepo(dbm *gorp.DbMap) user.UserRepo {
	backend := &db{dbm}
	return &userRepo{db: backend}
}
// NewUserRepoFromUsers creates a user repo on top of dbm and seeds it with the
// given users and their remote identities.
//
// Every write is issued outside of a transaction (a nil repo.Transaction is
// passed), so users stored before a failure are not rolled back. The first
// error encountered - building the model, inserting the user, or attaching a
// remote identity - aborts the seeding and is returned with a nil repo.
func NewUserRepoFromUsers(dbm *gorp.DbMap, us []user.UserWithRemoteIdentities) (user.UserRepo, error) {
	r := NewUserRepo(dbm).(*userRepo)
	for _, u := range us {
		um, err := newUserModel(&u.User)
		if err != nil {
			return nil, err
		}
		// The insert error must be checked here: previously it was
		// overwritten (or dropped entirely for users without remote
		// identities), hiding failed inserts from the caller.
		if err := r.executor(nil).Insert(um); err != nil {
			return nil, err
		}
		for _, ri := range u.RemoteIdentities {
			if err := r.AddRemoteIdentity(nil, u.User.ID, ri); err != nil {
				return nil, err
			}
		}
	}
	return r, nil
}
// userRepo is the database-backed repository returned by NewUserRepo as a
// user.UserRepo. It embeds *db, whose quote and executor methods are used by
// the repository methods to build and run queries.
type userRepo struct {
*db
}
// Get looks up the user with the given ID. It is the exported entry point for
// the unexported get helper and returns that helper's result unchanged.
func (r *userRepo) Get(tx repo.Transaction, userID string) (user.User, error) {
	usr, err := r.get(tx, userID)
	return usr, err
}
// Create stores a new user.
//
// It returns user.ErrorInvalidID when usr.ID is empty, user.ErrorDuplicateID
// when a user with that ID already exists, user.ErrorInvalidEmail when the
// email fails user.ValidEmail, and user.ErrorDuplicateEmail when another user
// already has that email. Any other lookup or insert error is returned as is.
func (r *userRepo) Create(tx repo.Transaction, usr user.User) (err error) {
	if usr.ID == "" {
		return user.ErrorInvalidID
	}

	// The ID must be unused: the only acceptable lookup outcome is "not found".
	_, err = r.get(tx, usr.ID)
	switch {
	case err == nil:
		return user.ErrorDuplicateID
	case err != user.ErrorNotFound:
		return err
	}

	if !user.ValidEmail(usr.Email) {
		return user.ErrorInvalidEmail
	}

	// The email must be unused as well.
	_, err = r.getByEmail(tx, usr.Email)
	switch {
	case err == nil:
		return user.ErrorDuplicateEmail
	case err != user.ErrorNotFound:
		return err
	}

	return r.insert(tx, usr)
}
// Disable sets the disabled column of the user with the given ID to the value
// of disable.
//
// It returns user.ErrorInvalidID for an empty userID and user.ErrorNotFound
// when the UPDATE statement affected no rows.
func (r *userRepo) Disable(tx repo.Transaction, userID string, disable bool) error {
	if userID == "" {
		return user.ErrorInvalidID
	}

	qt := r.quote(userTableName)
	ex := r.executor(tx)
	query := fmt.Sprintf("UPDATE %s SET disabled = $1 WHERE id = $2;", qt)

	res, err := ex.Exec(query, disable, userID)
	if err != nil {
		return err
	}

	affected, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if affected == 0 {
		// No row matched the ID, so there is no such user.
		return user.ErrorNotFound
	}
	return nil
}
// GetByEmail looks up the user with the given email address. It is the
// exported entry point for the unexported getByEmail helper and returns that
// helper's result unchanged.
func (r *userRepo) GetByEmail(tx repo.Transaction, email string) (user.User, error) {
	usr, err := r.getByEmail(tx, email)
	return usr, err
}
// Update overwrites the stored record of an existing user.
//
// It returns user.ErrorInvalidID when usr.ID is empty, user.ErrorInvalidEmail
// when the email fails user.ValidEmail, the lookup error when no user with
// that ID can be fetched, and user.ErrorDuplicateEmail when the email belongs
// to a different user.
func (r *userRepo) Update(tx repo.Transaction, usr user.User) error {
	if usr.ID == "" {
		return user.ErrorInvalidID
	}
	if !user.ValidEmail(usr.Email) {
		return user.ErrorInvalidEmail
	}

	// The user has to exist before it can be updated.
	if _, err := r.get(tx, usr.ID); err != nil {
		return err
	}

	// The email may be unused or already owned by this user, but it must
	// not belong to anybody else.
	owner, err := r.getByEmail(tx, usr.Email)
	switch {
	case err == user.ErrorNotFound:
		// Nobody has this email yet.
	case err != nil:
		return err
	case owner.ID != usr.ID:
		return user.ErrorDuplicateEmail
	}

	return r.update(tx, usr)
}
// GetByRemoteIdentity returns the user that the given remote identity is
// mapped to.
//
// The remote identity is first resolved to a user ID, then that user is
// loaded. An error from either step is returned together with a zero
// user.User.
func (r *userRepo) GetByRemoteIdentity(tx repo.Transaction, ri user.RemoteIdentity) (user.User, error) {
	userID, err := r.getUserIDForRemoteIdentity(tx, ri)
	if err != nil {
		return user.User{}, err
	}

	usr, err := r.get(tx, userID)
	if err != nil {
		return user.User{}, err
	}
	return usr, nil
}
// AddRemoteIdentity maps the remote identity ri to the user with the given ID.
//
// The user must exist; otherwise the lookup error is returned. When the remote
// identity is already mapped to a non-empty user ID,
// user.ErrorDuplicateRemoteIdentity is returned.
func (r *userRepo) AddRemoteIdentity(tx repo.Transaction, userID string, ri user.RemoteIdentity) error {
	if _, err := r.get(tx, userID); err != nil {
		return err
	}

	owner, err := r.getUserIDForRemoteIdentity(tx, ri)
	switch {
	case err == user.ErrorNotFound:
		// The identity is free, so claim it for this user.
		return r.insertRemoteIdentity(tx, userID, ri)
	case err == nil && owner != "":
		return user.ErrorDuplicateRemoteIdentity
	default:
		// Either a genuine lookup failure, or a mapping that exists with
		// an empty user ID; in the latter case err is nil and nothing is
		// inserted.
		// NOTE(review): confirm that silently succeeding is intended for
		// a mapping with an empty user ID.
		return err
	}
}
// RemoveRemoteIdentity deletes the mapping between the remote identity rid and
// the user with the given ID.
//
// It returns user.ErrorInvalidID when userID, rid.ID or rid.ConnectorID is
// empty, and user.ErrorNotFound when the identity is mapped to a different
// user or when the delete removed no rows.
func (r *userRepo) RemoveRemoteIdentity(tx repo.Transaction, userID string, rid user.RemoteIdentity) error {
	switch "" {
	case userID, rid.ID, rid.ConnectorID:
		return user.ErrorInvalidID
	}

	owner, err := r.getUserIDForRemoteIdentity(tx, rid)
	if err != nil {
		return err
	}
	// Only the user that owns the mapping may have it removed.
	if owner != userID {
		return user.ErrorNotFound
	}

	mapping, err := newRemoteIdentityMappingModel(userID, rid)
	if err != nil {
		return err
	}

	ex := r.executor(tx)
	removed, err := ex.Delete(mapping)
	switch {
	case err != nil:
		return err
	case removed == 0:
		return user.ErrorNotFound
	}
	return nil
}
// GetRemoteIdentities returns every remote identity mapped to the user with
// the given ID.
//
// It returns user.ErrorInvalidID for an empty userID. When no mapping rows
// match, it returns a nil slice and a nil error. Query errors are returned
// unchanged.
func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]user.RemoteIdentity, error) {
	ex := r.executor(tx)
	if userID == "" {
		return nil, user.ErrorInvalidID
	}

	qt := r.quote(remoteIdentityMappingTableName)
	rims, err := ex.Select(&remoteIdentityMappingModel{}, fmt.Sprintf("SELECT * FROM %s WHERE user_id = $1", qt), userID)
	if err != nil {
		// The previous special case for sql.ErrNoRows returned the same
		// (nil, err) as the general case, so one return covers both. An
		// empty result is handled by the length check below.
		return nil, err
	}
	if len(rims) == 0 {
		return nil, nil
	}

	ris := make([]user.RemoteIdentity, 0, len(rims))
	for _, m := range rims {
		rim, ok := m.(*remoteIdentityMappingModel)
		if !ok {
			log.Errorf("expected remoteIdentityMappingModel but found %v", reflect.TypeOf(m))
			return nil, errors.New("unrecognized model")
		}
		ris = append(ris, user.RemoteIdentity{
			ID:          rim.RemoteID,
			ConnectorID: rim.ConnectorID,
		})
	}
	return ris, nil
}
// GetAdminCount returns the number of rows in the user table whose admin
// column is true, together with any error from the count query.
func (r *userRepo) GetAdminCount(tx repo.Transaction) (int, error) {
	qt := r.quote(userTableName)
	ex := r.executor(tx)
	query := fmt.Sprintf("SELECT count(*) FROM %s WHERE admin=true;", qt)

	count, err := ex.SelectInt(query)
	return int(count), err
}
// List returns one page of users ordered by email.
//
// When nextPageToken is non-empty, the filter, page size and offset decoded
// from it replace the filter and maxResults arguments. The returned string is
// the token for the following page, or "" when there are no more results.
// user.ErrorNotFound is returned when the query yields no rows at all.
func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults int, nextPageToken string) ([]user.User, string, error) {
	var offset int
	if nextPageToken != "" {
		var err error
		filter, maxResults, offset, err = user.DecodeNextPageToken(nextPageToken)
		if err != nil {
			return nil, "", err
		}
	}

	ex := r.executor(tx)
	qt := r.quote(userTableName)

	// Fetch one row beyond the page size: its presence tells us that another
	// page exists and therefore that a next-page token must be produced.
	query := fmt.Sprintf("SELECT * FROM %s ORDER BY email LIMIT $1 OFFSET $2", qt)
	rows, err := ex.Select(&userModel{}, query, maxResults+1, offset)
	if err != nil {
		return nil, "", err
	}
	if len(rows) == 0 {
		return nil, "", user.ErrorNotFound
	}

	hasMore := len(rows) > maxResults
	pageSize := len(rows)
	if hasMore {
		pageSize = maxResults
	}

	page := make([]user.User, pageSize)
	for i := range page {
		um, ok := rows[i].(*userModel)
		if !ok {
			log.Errorf("expected userModel but found %v", reflect.TypeOf(rows[i]))
			return nil, "", errors.New("unrecognized model")
		}
		usr, err := um.user()
		if err != nil {
			return nil, "", err
		}
		page[i] = usr
	}

	if !hasMore {
		return page, "", nil
	}
	tok, err := user.EncodeNextPageToken(filter, maxResults, offset+maxResults)
	if err != nil {
		return nil, "", err
	}
	return page, tok, nil
}
// insert converts usr to its row model and inserts it through the executor
// for tx. No validation is performed here.
func (r *userRepo) insert(tx repo.Transaction, usr user.User) error {
	ex := r.executor(tx)

	row, err := newUserModel(&usr)
	if err != nil {
		return err
	}
	if err := ex.Insert(row); err != nil {
		return err
	}
	return nil
}
// update converts usr to its row model and updates the stored row through the
// executor for tx. The count returned by the executor's Update is discarded.
func (r *userRepo) update(tx repo.Transaction, usr user.User) error {
	ex := r.executor(tx)

	row, err := newUserModel(&usr)
	if err != nil {
		return err
	}
	if _, err := ex.Update(row); err != nil {
		return err
	}
	return nil
}
// get loads the user row whose primary key is userID and converts it to a
// user.User. It returns user.ErrorNotFound when the executor reports no row.
func (r *userRepo) get(tx repo.Transaction, userID string) (user.User, error) {
	ex := r.executor(tx)

	row, err := ex.Get(userModel{}, userID)
	switch {
	case err != nil:
		return user.User{}, err
	case row == nil:
		// A nil result without an error means there is no such row.
		return user.User{}, user.ErrorNotFound
	}

	if um, ok := row.(*userModel); ok {
		return um.user()
	}
	log.Errorf("expected userModel but found %v", reflect.TypeOf(row))
	return user.User{}, errors.New("unrecognized model")
}
// getUserIDForRemoteIdentity returns the user ID stored in the mapping row
// keyed by (ri.ConnectorID, ri.ID). It returns user.ErrorNotFound when the
// executor reports no such row.
func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.RemoteIdentity) (string, error) {
	ex := r.executor(tx)

	row, err := ex.Get(remoteIdentityMappingModel{}, ri.ConnectorID, ri.ID)
	switch {
	case err != nil:
		return "", err
	case row == nil:
		// A nil result without an error means there is no such mapping.
		return "", user.ErrorNotFound
	}

	if mapping, ok := row.(*remoteIdentityMappingModel); ok {
		return mapping.UserID, nil
	}
	log.Errorf("expected remoteIdentityMappingModel but found %v", reflect.TypeOf(row))
	return "", errors.New("unrecognized model")
}
// getByEmail loads the user whose email column equals the lower-cased form of
// email. It returns user.ErrorNotFound when the query reports sql.ErrNoRows.
func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, error) {
	qt := r.quote(userTableName)
	ex := r.executor(tx)
	query := fmt.Sprintf("select * from %s where email = $1", qt)

	var row userModel
	// newUserModel lower-cases emails before they are stored, so the lookup
	// key is lower-cased too.
	err := ex.SelectOne(&row, query, strings.ToLower(email))
	switch {
	case err == sql.ErrNoRows:
		return user.User{}, user.ErrorNotFound
	case err != nil:
		return user.User{}, err
	}
	return row.user()
}
// insertRemoteIdentity builds the mapping row that ties ri to userID and
// inserts it through the executor for tx. No existence or duplicate checks are
// performed here.
func (r *userRepo) insertRemoteIdentity(tx repo.Transaction, userID string, ri user.RemoteIdentity) error {
	ex := r.executor(tx)

	mapping, err := newRemoteIdentityMappingModel(userID, ri)
	if err != nil {
		return err
	}
	return ex.Insert(mapping)
}
// userModel is the row representation of a user.User. The db struct tags name
// the column each field is mapped to.
type userModel struct {
ID string `db:"id"`
// Email is stored lower-cased by newUserModel.
Email string `db:"email"`
EmailVerified bool `db:"email_verified"`
DisplayName string `db:"display_name"`
Disabled bool `db:"disabled"`
Admin bool `db:"admin"`
// CreatedAt holds Unix seconds; 0 stands for an unset (zero) creation time.
CreatedAt int64 `db:"created_at"`
}
// user converts the row model into a user.User. A CreatedAt of 0 leaves the
// resulting CreatedAt at its zero value; any other value is interpreted as
// Unix seconds and converted to UTC. The returned error is always nil.
func (u *userModel) user() (user.User, error) {
	var out user.User
	out.ID = u.ID
	out.DisplayName = u.DisplayName
	out.Email = u.Email
	out.EmailVerified = u.EmailVerified
	out.Admin = u.Admin
	out.Disabled = u.Disabled

	if u.CreatedAt != 0 {
		out.CreatedAt = time.Unix(u.CreatedAt, 0).UTC()
	}
	return out, nil
}
// newUserModel converts a user.User into its row model. The email is
// lower-cased, and a zero CreatedAt is stored as 0 instead of its Unix
// timestamp. The returned error is always nil.
func newUserModel(u *user.User) (*userModel, error) {
	row := &userModel{}
	row.ID = u.ID
	row.DisplayName = u.DisplayName
	row.Email = strings.ToLower(u.Email)
	row.EmailVerified = u.EmailVerified
	row.Admin = u.Admin
	row.Disabled = u.Disabled

	if !u.CreatedAt.IsZero() {
		row.CreatedAt = u.CreatedAt.Unix()
	}
	return row, nil
}
// newRemoteIdentityMappingModel builds the mapping row that ties the remote
// identity ri to the user with the given ID. The returned error is always nil.
func newRemoteIdentityMappingModel(userID string, ri user.RemoteIdentity) (*remoteIdentityMappingModel, error) {
	mapping := new(remoteIdentityMappingModel)
	mapping.ConnectorID = ri.ConnectorID
	mapping.UserID = userID
	mapping.RemoteID = ri.ID
	return mapping, nil
}
// remoteIdentityMappingModel is the row representation of the link between a
// remote identity and a user. The table is registered with the primary key
// (connector_id, remote_id).
type remoteIdentityMappingModel struct {
// ConnectorID is filled from user.RemoteIdentity.ConnectorID.
ConnectorID string `db:"connector_id"`
// UserID is the ID of the user the remote identity is mapped to.
UserID string `db:"user_id"`
// RemoteID is filled from user.RemoteIdentity.ID.
RemoteID string `db:"remote_id"`
}