This repository has been archived on 2022-08-17. You can view files and clone it, but cannot push or open issues or pull requests.
dex/vendor/github.com/go-gorp/gorp/gorp.go
2016-03-09 13:04:05 -08:00

558 lines
14 KiB
Go

// Copyright 2012 James Cooper. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
// Package gorp provides a simple way to marshal Go structs to and from
// SQL databases. It uses the database/sql package, and should work with any
// compliant database/sql driver.
//
// Source code and project home:
// https://github.com/go-gorp/gorp
//
package gorp
import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"regexp"
"strings"
"time"
)
// OracleString is a nullable string whose empty value is treated as NULL:
// both an invalid value and an empty string are written to the database as
// NULL.
// TODO: move to dialect/oracle?, rename to String?
type OracleString struct {
	sql.NullString
}

// Scan implements the sql.Scanner interface. A nil database value produces
// an invalid (NULL) OracleString with an empty String; any other value is
// handed to the embedded sql.NullString.
func (s *OracleString) Scan(value interface{}) error {
	if value != nil {
		s.Valid = true
		return s.NullString.Scan(value)
	}
	s.String, s.Valid = "", false
	return nil
}

// Value implements the driver.Valuer interface. It yields nil (NULL) when
// the string is invalid or empty, and the string itself otherwise.
func (s OracleString) Value() (driver.Value, error) {
	if s.Valid && s.String != "" {
		return s.String, nil
	}
	return nil, nil
}
// SqlTyper is implemented by types that can report their database type.
// Usually implementing database/sql/driver.Valuer is enough, but a type whose
// Value returns nil for its empty value should also implement SqlTyper so
// that its column type can still be detected during table creation.
type SqlTyper interface {
	SqlType() driver.Valuer
}

// dummyField is a scan target for a column that exists in the database table
// but has no matching field in the struct; anything scanned into it is
// discarded.
type dummyField struct{}

// Scan implements the sql.Scanner interface by ignoring the value.
func (*dummyField) Scan(_ interface{}) error {
	return nil
}

var (
	// zeroVal is the zero reflect.Value (IsValid reports false for it).
	zeroVal reflect.Value

	// versFieldConst is the literal "[gorp_ver_field]".
	// NOTE(review): its role (presumably a version-column marker) is defined
	// by users elsewhere in the package — confirm there.
	versFieldConst = "[gorp_ver_field]"
)
// TypeConverter maps values of one Go type to another when they are written
// to, or read from, the database.
//
// Typical uses: storing bool fields as "y"/"n" strings, or serializing a
// struct member as a JSON blob.
type TypeConverter interface {
	// ToDb converts val into the value sent to the database. It is called
	// before INSERT/UPDATE operations.
	ToDb(val interface{}) (interface{}, error)

	// FromDb returns a CustomScanner to hold the values that SELECT queries
	// return for target. The CustomScanner's Binder should convert the
	// scanned db value into the desired Go type.
	//
	// When the bool result is false, no custom scanner is used for the field.
	FromDb(target interface{}) (CustomScanner, bool)
}

// executor is the Exec method shared by sql.DB and sql.Tx. It lets internal
// helpers that expand named parameters run a statement without caring
// which of the two they were given.
type executor interface {
	Exec(query string, args ...interface{}) (sql.Result, error)
}

// SqlExecutor is the set of gorp operations available to Pre/Post hooks. It
// hides whether the operation that triggered the hook is running inside a
// transaction.
//
// See the corresponding DbMap methods for details on each operation.
type SqlExecutor interface {
	Get(i interface{}, keys ...interface{}) (interface{}, error)
	Insert(list ...interface{}) error
	Update(list ...interface{}) (int64, error)
	Delete(list ...interface{}) (int64, error)
	Exec(query string, args ...interface{}) (sql.Result, error)
	Select(i interface{}, query string, args ...interface{}) ([]interface{}, error)
	SelectInt(query string, args ...interface{}) (int64, error)
	SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error)
	SelectFloat(query string, args ...interface{}) (float64, error)
	SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error)
	SelectStr(query string, args ...interface{}) (string, error)
	SelectNullStr(query string, args ...interface{}) (sql.NullString, error)
	SelectOne(holder interface{}, query string, args ...interface{}) error
	query(query string, args ...interface{}) (*sql.Rows, error)
	queryRow(query string, args ...interface{}) *sql.Row
}

// Compile-time assertions that *DbMap and *Transaction satisfy SqlExecutor.
var (
	_ SqlExecutor = (*DbMap)(nil)
	_ SqlExecutor = (*Transaction)(nil)
)
// argsString renders args as a space-separated list of "position:value"
// pairs, with 1-based positions (e.g. `1:"a" 2:42`), presumably for query
// trace output.
//
// An argument implementing driver.Valuer is shown as the value it produces,
// unless Value returns an error, in which case the argument itself is shown.
// Values of type string are quoted with %q; all others are formatted with %v.
func argsString(args ...interface{}) string {
	parts := make([]string, 0, len(args))
	for i, a := range args {
		v := a
		if valuer, ok := v.(driver.Valuer); ok {
			if dv, err := valuer.Value(); err == nil {
				v = dv
			}
		}
		var rendered string
		if str, ok := v.(string); ok {
			rendered = fmt.Sprintf("%q", str)
		} else {
			rendered = fmt.Sprintf("%v", v)
		}
		parts = append(parts, fmt.Sprintf("%d:%s", i+1, rendered))
	}
	// Joining once avoids the quadratic cost of growing a string with +=.
	return strings.Join(parts, " ")
}
// exec runs query through the Exec method of the database handle (for a
// *DbMap) or transaction (for a *Transaction) behind e. When exactly one
// argument is given, it may be a map or struct of named parameters;
// maybeExpandNamedQuery then rewrites the query to use the dialect's
// positional bindvars before execution.
//
// Any other SqlExecutor implementation yields an error rather than a
// nil-interface panic.
func exec(e SqlExecutor, query string, args ...interface{}) (sql.Result, error) {
	var (
		dbMap *DbMap
		ex    executor
	)
	switch m := e.(type) {
	case *DbMap:
		ex, dbMap = m.Db, m
	case *Transaction:
		ex, dbMap = m.tx, m.dbmap
	default:
		return nil, fmt.Errorf("gorp: exec: unsupported SqlExecutor type %T", e)
	}
	if len(args) == 1 {
		query, args = maybeExpandNamedQuery(dbMap, query, args)
	}
	return ex.Exec(query, args...)
}
// maybeExpandNamedQuery checks whether args[0] can serve as the input of a
// named query; only args[0] is inspected (exec calls this only when
// len(args) == 1). A map whose key kind is string, or a struct, either of
// which may be passed through a pointer, qualifies. In that case the query is
// rewritten by expandNamedQuery to use dialect-dependent bindvars, and the
// matching slice of parameters is extracted from the map or struct.
// Otherwise query and args are returned unchanged.
//
// time.Time is itself a driver.Value, and a driver.Valuer converts itself to
// one, so such arguments are passed through to the driver untouched.
func maybeExpandNamedQuery(m *DbMap, query string, args []interface{}) (string, []interface{}) {
	arg := args[0]
	argval := reflect.ValueOf(arg)
	if argval.Kind() == reflect.Ptr {
		argval = argval.Elem()
	}

	if argval.Kind() == reflect.Map && argval.Type().Key().Kind() == reflect.String {
		keyType := argval.Type().Key()
		return expandNamedQuery(m, query, func(key string) reflect.Value {
			// Convert so maps keyed by a named string type (e.g.
			// `type Param string`) work: MapIndex panics when the key's type
			// is not assignable to the map's key type.
			return argval.MapIndex(reflect.ValueOf(key).Convert(keyType))
		})
	}
	if argval.Kind() != reflect.Struct {
		return query, args
	}
	// Inspect the dereferenced value so a *time.Time is also forwarded to the
	// driver. Expanding it as a struct would drop it from args whenever the
	// query has no matching :name placeholders.
	if _, ok := argval.Interface().(time.Time); ok {
		return query, args
	}
	if _, ok := arg.(driver.Valuer); ok {
		// driver.Valuer will be converted to driver.Value.
		return query, args
	}
	return expandNamedQuery(m, query, argval.FieldByName)
}
// keyRegexp matches named-query placeholders: a colon followed by one or more
// word characters ([0-9A-Za-z_]), e.g. ":name".
var keyRegexp = regexp.MustCompile(`:[[:word:]]+`)

// expandNamedQuery rewrites each ":key" placeholder in query that keyGetter
// resolves into the dialect's positional bindvar. It returns the rewritten
// query and the resolved values, in bindvar order, ready for positional
// insertion. It is intended for a single arg of kind Struct or Map[string].
//
// keyGetter receives the key without its leading colon and returns the zero
// reflect.Value for unknown keys, as Value.FieldByName and Value.MapIndex do.
// Unknown keys, and unexported struct fields (which cannot be read through
// reflection), are left in the query verbatim.
func expandNamedQuery(m *DbMap, query string, keyGetter func(key string) reflect.Value) (string, []interface{}) {
	var (
		n    int
		args []interface{}
	)
	expanded := keyRegexp.ReplaceAllStringFunc(query, func(key string) string {
		val := keyGetter(key[1:])
		// CanInterface is false for unexported struct fields; Interface()
		// would panic on them.
		if !val.IsValid() || !val.CanInterface() {
			return key
		}
		args = append(args, val.Interface())
		bindVar := m.Dialect.BindVar(n)
		n++
		return bindVar
	})
	// Read args only after ReplaceAllStringFunc has finished. In a combined
	// `return f(), args` statement the evaluation order of args relative to
	// the call is unspecified.
	return expanded, args
}
// columnToFieldIndex maps each name in cols to the index path (suitable for
// reflect.Value.FieldByIndex) of the field of struct type t that should
// receive that column. Fields are located with reflect.Type.FieldByNameFunc,
// and names are compared case-insensitively.
//
// A field's name for matching is taken from one of three sources. If t is a
// mapped table and has a ColumnMap for the field, its ColumnName is used.
// Otherwise the first element of the field's `db` tag is used, falling back
// to the Go field name when the tag is empty. Fields tagged `db:"-"` never
// match.
//
// Entries for unmatched columns stay nil. If any column is unmatched, the
// partial index slice is returned along with a *NoFieldInTypeError listing
// the lower-cased names of the missing columns.
func columnToFieldIndex(m *DbMap, t reflect.Type, cols []string) ([][]int, error) {
	indexes := make([][]int, len(cols))

	// A mapped table may rename fields through its column maps.
	table := tableOrNil(m, t)
	isMapped := table != nil

	var missing []string
	for i, col := range cols {
		want := strings.ToLower(col)
		sf, ok := t.FieldByNameFunc(func(name string) bool {
			candidate, _ := t.FieldByName(name)
			dbName := strings.Split(candidate.Tag.Get("db"), ",")[0]
			switch dbName {
			case "-":
				return false
			case "":
				dbName = candidate.Name
			}
			if isMapped {
				if cm := colMapOrNil(table, dbName); cm != nil {
					dbName = cm.ColumnName
				}
			}
			return want == strings.ToLower(dbName)
		})
		if !ok || sf.Index == nil {
			missing = append(missing, want)
			continue
		}
		indexes[i] = sf.Index
	}

	if len(missing) == 0 {
		return indexes, nil
	}
	return indexes, &NoFieldInTypeError{
		TypeName:        t.Name(),
		MissingColNames: missing,
	}
}
// fieldByName returns a pointer to the field of struct value val named
// fieldName, or nil if there is none.
//
// It first tries an exact match, which also finds promoted fields of embedded
// structs. Failing that, it searches val's own top-level fields with a
// case-insensitive comparison (strings.EqualFold). Per the original note,
// only the Postgres driver seems to need this, when columns are aliased in
// the SQL.
//
// Like reflect.Value.FieldByName, it panics if val is not a struct.
func fieldByName(val reflect.Value, fieldName string) *reflect.Value {
	// IsValid is the supported way to detect a missing field; comparing
	// reflect.Values with == / != is not meaningful (go vet flags it).
	if f := val.FieldByName(fieldName); f.IsValid() {
		return &f
	}

	t := val.Type()
	for i, n := 0, val.NumField(); i < n; i++ {
		if strings.EqualFold(t.Field(i).Name, fieldName) {
			f := val.Field(i)
			return &f
		}
	}
	return nil
}
// toSliceType returns the element type of i when i is a pointer to a slice
// ("*[]*Element" or "*[]Element"). For any other type, including an untyped
// nil, it returns (nil, nil).
//
// A non-nil error is returned when i is a slice passed by value, since the
// caller was presumably trying to pass a pointer-to-slice but failed.
func toSliceType(i interface{}) (reflect.Type, error) {
	t := reflect.TypeOf(i)
	if t == nil {
		// Untyped nil is not a slice; t.Kind() would panic on a nil Type.
		return nil, nil
	}
	switch t.Kind() {
	case reflect.Ptr:
		if elem := t.Elem(); elem.Kind() == reflect.Slice {
			return elem.Elem(), nil
		}
		return nil, nil
	case reflect.Slice:
		// Return a more helpful error message than "not a slice".
		return nil, fmt.Errorf("gorp: cannot SELECT into a non-pointer slice: %v", t)
	default:
		return nil, nil
	}
}
// toType returns the struct type underlying i after following any number of
// pointer indirections (T, *T and **T all yield T). It returns an error if
// the resulting type is not a struct, or if i is an untyped nil.
func toType(i interface{}) (reflect.Type, error) {
	t := reflect.TypeOf(i)
	// t is nil for an untyped nil; guard before calling Kind, which would panic.
	for t != nil && t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t == nil || t.Kind() != reflect.Struct {
		return nil, fmt.Errorf("gorp: cannot SELECT into this type: %v", reflect.TypeOf(i))
	}
	return t, nil
}
// get loads a single row of the table mapped to i's type. It runs the query
// built by table.bindGet() with keys as its arguments (presumably the
// primary-key values).
//
// i may be a struct or a pointer to one; only its type is used. On success,
// get returns a pointer to a newly allocated value of that type. If no row
// matches, it returns (nil, nil). A field is scanned through the CustomScanner
// returned by m.TypeConverter.FromDb when one is provided. The PostGet hook
// runs when the loaded value implements HasPostGet.
func get(m *DbMap, exec SqlExecutor, i interface{},
	keys ...interface{}) (interface{}, error) {
	t, err := toType(i)
	if err != nil {
		return nil, err
	}
	table, err := m.TableFor(t, true)
	if err != nil {
		return nil, err
	}
	plan := table.bindGet()

	// Allocate the result and aim one scan destination at each planned field.
	ptr := reflect.New(t)
	obj := ptr.Elem()
	conv := m.TypeConverter
	dest := make([]interface{}, len(plan.argFields))
	var customs []CustomScanner
	for idx, name := range plan.argFields {
		target := obj.FieldByName(name).Addr().Interface()
		if conv != nil {
			if sc, ok := conv.FromDb(target); ok {
				target = sc.Holder
				customs = append(customs, sc)
			}
		}
		dest[idx] = target
	}

	if err := exec.queryRow(plan.query, keys...).Scan(dest...); err != nil {
		if err == sql.ErrNoRows {
			return nil, nil
		}
		return nil, err
	}

	// Move values from the custom holders into the real fields.
	for _, sc := range customs {
		if err := sc.Bind(); err != nil {
			return nil, err
		}
	}

	result := ptr.Interface()
	if hook, ok := result.(HasPostGet); ok {
		if err := hook.PostGet(exec); err != nil {
			return nil, err
		}
	}
	return result, nil
}
// delete removes each element of list using the statement produced by
// table.bindDelete. Each element is resolved through m.tableForPointer. The
// PreDelete and PostDelete hooks run around each row when the element
// implements them. The total number of affected rows is returned.
//
// Processing stops at the first error, and -1 is returned with it. If a row
// with bi.existingVersion > 0 affects no rows, the result of lockError is
// returned instead.
func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) {
	var total int64
	for _, ptr := range list {
		table, elem, err := m.tableForPointer(ptr, true)
		if err != nil {
			return -1, err
		}
		hooked := elem.Addr().Interface()
		if pre, ok := hooked.(HasPreDelete); ok {
			if err := pre.PreDelete(exec); err != nil {
				return -1, err
			}
		}

		bi, err := table.bindDelete(elem)
		if err != nil {
			return -1, err
		}
		res, err := exec.Exec(bi.query, bi.args...)
		if err != nil {
			return -1, err
		}
		affected, err := res.RowsAffected()
		if err != nil {
			return -1, err
		}
		// No rows touched for a versioned row: let lockError decide what to
		// report (presumably a stale version / concurrent change).
		if affected == 0 && bi.existingVersion > 0 {
			return lockError(m, exec, table.TableName,
				bi.existingVersion, elem, bi.keys...)
		}
		total += affected

		if post, ok := hooked.(HasPostDelete); ok {
			if err := post.PostDelete(exec); err != nil {
				return -1, err
			}
		}
	}
	return total, nil
}
// update writes each element of list using the statement produced by
// table.bindUpdate. Each element is resolved through m.tableForPointer. The
// PreUpdate and PostUpdate hooks run around each row when the element
// implements them. The total number of affected rows is returned.
//
// Processing stops at the first error, and -1 is returned with it. If a row
// with bi.existingVersion > 0 affects no rows, the result of lockError is
// returned instead. After a successful write, the field named bi.versField
// (if any) is set to bi.existingVersion+1.
func update(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) {
	var total int64
	for _, ptr := range list {
		table, elem, err := m.tableForPointer(ptr, true)
		if err != nil {
			return -1, err
		}
		hooked := elem.Addr().Interface()
		if pre, ok := hooked.(HasPreUpdate); ok {
			if err := pre.PreUpdate(exec); err != nil {
				return -1, err
			}
		}

		bi, err := table.bindUpdate(elem)
		if err != nil {
			return -1, err
		}
		res, err := exec.Exec(bi.query, bi.args...)
		if err != nil {
			return -1, err
		}
		affected, err := res.RowsAffected()
		if err != nil {
			return -1, err
		}
		if affected == 0 && bi.existingVersion > 0 {
			return lockError(m, exec, table.TableName,
				bi.existingVersion, elem, bi.keys...)
		}
		// Bump the in-memory version, presumably to match what the UPDATE
		// just stored.
		if bi.versField != "" {
			elem.FieldByName(bi.versField).SetInt(bi.existingVersion + 1)
		}
		total += affected

		if post, ok := hooked.(HasPostUpdate); ok {
			if err := post.PostUpdate(exec); err != nil {
				return -1, err
			}
		}
	}
	return total, nil
}
// insert writes each element of list using the statement produced by
// table.bindInsert. Each element is resolved through m.tableForPointer. The
// PreInsert and PostInsert hooks run around each row when the element
// implements them. Processing stops at the first error.
//
// When the table has an auto-increment column (bi.autoIncrIdx > -1), the
// statement is run through whichever auto-increment interface m.Dialect
// implements, checked in this order:
//   - IntegerAutoIncrInserter: the returned int64 id is stored into the
//     field, which must be a signed or unsigned integer of any size;
//   - TargetedAutoIncrInserter: the dialect is given a pointer to the field;
//   - TargetQueryInserter: the column's GeneratedIdQuery and a pointer to the
//     field are given to the dialect; an empty GeneratedIdQuery is an error.
//
// A dialect implementing none of these is an error.
func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error {
	for _, ptr := range list {
		table, elem, err := m.tableForPointer(ptr, false)
		if err != nil {
			return err
		}
		eval := elem.Addr().Interface()
		if v, ok := eval.(HasPreInsert); ok {
			if err := v.PreInsert(exec); err != nil {
				return err
			}
		}
		bi, err := table.bindInsert(elem)
		if err != nil {
			return err
		}
		if bi.autoIncrIdx > -1 {
			f := elem.FieldByName(bi.autoIncrFieldName)
			switch inserter := m.Dialect.(type) {
			case IntegerAutoIncrInserter:
				id, err := inserter.InsertAutoIncr(exec, bi.query, bi.args...)
				if err != nil {
					return err
				}
				// Accept every integer width, including int8/uint8, which the
				// previous if-chain rejected as "non-Int".
				switch f.Kind() {
				case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
					f.SetInt(id)
				case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
					f.SetUint(uint64(id))
				default:
					return fmt.Errorf("gorp: cannot set autoincrement value on non-Int field. SQL=%s autoIncrIdx=%d autoIncrFieldName=%s", bi.query, bi.autoIncrIdx, bi.autoIncrFieldName)
				}
			case TargetedAutoIncrInserter:
				if err := inserter.InsertAutoIncrToTarget(exec, bi.query, f.Addr().Interface(), bi.args...); err != nil {
					return err
				}
			case TargetQueryInserter:
				idQuery := table.ColMap(bi.autoIncrFieldName).GeneratedIdQuery
				if idQuery == "" {
					return fmt.Errorf("gorp: cannot set %s value if its ColumnMap.GeneratedIdQuery is empty", bi.autoIncrFieldName)
				}
				if err := inserter.InsertQueryToTarget(exec, bi.query, idQuery, f.Addr().Interface(), bi.args...); err != nil {
					return err
				}
			default:
				return fmt.Errorf("gorp: cannot use autoincrement fields on dialects that do not implement an autoincrementing interface")
			}
		} else {
			if _, err := exec.Exec(bi.query, bi.args...); err != nil {
				return err
			}
		}
		if v, ok := eval.(HasPostInsert); ok {
			if err := v.PostInsert(exec); err != nil {
				return err
			}
		}
	}
	return nil
}