604 lines
18 KiB
Go
604 lines
18 KiB
Go
// Copyright 2012 James Cooper. All rights reserved.
|
|
// Use of this source code is governed by a MIT-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
// Package gorp provides a simple way to marshal Go structs to and from
|
|
// SQL databases. It uses the database/sql package, and should work with any
|
|
// compliant database/sql driver.
|
|
//
|
|
// Source code and project home:
|
|
// https://github.com/go-gorp/gorp
|
|
//
|
|
package gorp
|
|
|
|
import (
|
|
"bytes"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// DbMap is the root gorp mapping object. Create one of these for each
|
|
// database schema you wish to map. Each DbMap contains a list of
|
|
// mapped tables.
|
|
//
|
|
// Example:
|
|
//
|
|
// dialect := gorp.MySQLDialect{"InnoDB", "UTF8"}
|
|
// dbmap := &gorp.DbMap{Db: db, Dialect: dialect}
|
|
//
|
|
type DbMap struct {
|
|
// Db handle to use with this map
|
|
Db *sql.DB
|
|
|
|
// Dialect implementation to use with this map
|
|
Dialect Dialect
|
|
|
|
TypeConverter TypeConverter
|
|
|
|
tables []*TableMap
|
|
logger GorpLogger
|
|
logPrefix string
|
|
}
|
|
|
|
func (m *DbMap) CreateIndex() error {
|
|
|
|
var err error
|
|
dialect := reflect.TypeOf(m.Dialect)
|
|
for _, table := range m.tables {
|
|
for _, index := range table.indexes {
|
|
|
|
s := bytes.Buffer{}
|
|
s.WriteString("create")
|
|
if index.Unique {
|
|
s.WriteString(" unique")
|
|
}
|
|
s.WriteString(" index")
|
|
s.WriteString(fmt.Sprintf(" %s on %s", index.IndexName, table.TableName))
|
|
if dname := dialect.Name(); dname == "PostgresDialect" && index.IndexType != "" {
|
|
s.WriteString(fmt.Sprintf(" %s %s", m.Dialect.CreateIndexSuffix(), index.IndexType))
|
|
}
|
|
s.WriteString(" (")
|
|
x := 0
|
|
for _, col := range index.columns {
|
|
if x > 0 {
|
|
s.WriteString(", ")
|
|
}
|
|
s.WriteString(m.Dialect.QuoteField(col))
|
|
}
|
|
s.WriteString(")")
|
|
|
|
if dname := dialect.Name(); dname == "MySQLDialect" && index.IndexType != "" {
|
|
s.WriteString(fmt.Sprintf(" %s %s", m.Dialect.CreateIndexSuffix(), index.IndexType))
|
|
}
|
|
s.WriteString(";")
|
|
_, err = m.Exec(s.String())
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (t *TableMap) DropIndex(name string) error {
|
|
|
|
var err error
|
|
dialect := reflect.TypeOf(t.dbmap.Dialect)
|
|
for _, idx := range t.indexes {
|
|
if idx.IndexName == name {
|
|
s := bytes.Buffer{}
|
|
s.WriteString(fmt.Sprintf("DROP INDEX %s", idx.IndexName))
|
|
|
|
if dname := dialect.Name(); dname == "MySQLDialect" {
|
|
s.WriteString(fmt.Sprintf(" %s %s", t.dbmap.Dialect.DropIndexSuffix(), t.TableName))
|
|
}
|
|
s.WriteString(";")
|
|
_, e := t.dbmap.Exec(s.String())
|
|
if e != nil {
|
|
err = e
|
|
}
|
|
break
|
|
}
|
|
}
|
|
t.ResetSql()
|
|
return err
|
|
}
|
|
|
|
// AddTable registers the given interface type with gorp. The table name
|
|
// will be given the name of the TypeOf(i). You must call this function,
|
|
// or AddTableWithName, for any struct type you wish to persist with
|
|
// the given DbMap.
|
|
//
|
|
// This operation is idempotent. If i's type is already mapped, the
|
|
// existing *TableMap is returned
|
|
func (m *DbMap) AddTable(i interface{}) *TableMap {
|
|
return m.AddTableWithName(i, "")
|
|
}
|
|
|
|
// AddTableWithName has the same behavior as AddTable, but sets
|
|
// table.TableName to name.
|
|
func (m *DbMap) AddTableWithName(i interface{}, name string) *TableMap {
|
|
return m.AddTableWithNameAndSchema(i, "", name)
|
|
}
|
|
|
|
// AddTableWithNameAndSchema has the same behavior as AddTable, but sets
|
|
// table.TableName to name.
|
|
func (m *DbMap) AddTableWithNameAndSchema(i interface{}, schema string, name string) *TableMap {
|
|
t := reflect.TypeOf(i)
|
|
if name == "" {
|
|
name = t.Name()
|
|
}
|
|
|
|
// check if we have a table for this type already
|
|
// if so, update the name and return the existing pointer
|
|
for i := range m.tables {
|
|
table := m.tables[i]
|
|
if table.gotype == t {
|
|
table.TableName = name
|
|
return table
|
|
}
|
|
}
|
|
|
|
tmap := &TableMap{gotype: t, TableName: name, SchemaName: schema, dbmap: m}
|
|
var primaryKey []*ColumnMap
|
|
tmap.Columns, primaryKey = m.readStructColumns(t)
|
|
m.tables = append(m.tables, tmap)
|
|
if len(primaryKey) > 0 {
|
|
tmap.keys = append(tmap.keys, primaryKey...)
|
|
}
|
|
|
|
return tmap
|
|
}
|
|
|
|
func (m *DbMap) readStructColumns(t reflect.Type) (cols []*ColumnMap, primaryKey []*ColumnMap) {
|
|
primaryKey = make([]*ColumnMap, 0)
|
|
n := t.NumField()
|
|
for i := 0; i < n; i++ {
|
|
f := t.Field(i)
|
|
if f.Anonymous && f.Type.Kind() == reflect.Struct {
|
|
// Recursively add nested fields in embedded structs.
|
|
subcols, subpk := m.readStructColumns(f.Type)
|
|
// Don't append nested fields that have the same field
|
|
// name as an already-mapped field.
|
|
for _, subcol := range subcols {
|
|
shouldAppend := true
|
|
for _, col := range cols {
|
|
if !subcol.Transient && subcol.fieldName == col.fieldName {
|
|
shouldAppend = false
|
|
break
|
|
}
|
|
}
|
|
if shouldAppend {
|
|
cols = append(cols, subcol)
|
|
}
|
|
}
|
|
if subpk != nil {
|
|
primaryKey = append(primaryKey, subpk...)
|
|
}
|
|
} else {
|
|
// Tag = Name { ',' Option }
|
|
// Option = OptionKey [ ':' OptionValue ]
|
|
cArguments := strings.Split(f.Tag.Get("db"), ",")
|
|
columnName := cArguments[0]
|
|
var maxSize int
|
|
var defaultValue string
|
|
var isAuto bool
|
|
var isPK bool
|
|
for _, argString := range cArguments[1:] {
|
|
argString = strings.TrimSpace(argString)
|
|
arg := strings.SplitN(argString, ":", 2)
|
|
|
|
// check mandatory/unexpected option values
|
|
switch arg[0] {
|
|
case "size", "default":
|
|
// options requiring value
|
|
if len(arg) == 1 {
|
|
panic(fmt.Sprintf("missing option value for option %v on field %v", arg[0], f.Name))
|
|
}
|
|
default:
|
|
// options where value is invalid (currently all other options)
|
|
if len(arg) == 2 {
|
|
panic(fmt.Sprintf("unexpected option value for option %v on field %v", arg[0], f.Name))
|
|
}
|
|
}
|
|
|
|
switch arg[0] {
|
|
case "size":
|
|
maxSize, _ = strconv.Atoi(arg[1])
|
|
case "default":
|
|
defaultValue = arg[1]
|
|
case "primarykey":
|
|
isPK = true
|
|
case "autoincrement":
|
|
isAuto = true
|
|
default:
|
|
panic(fmt.Sprintf("Unrecognized tag option for field %v: %v", f.Name, arg))
|
|
}
|
|
}
|
|
if columnName == "" {
|
|
columnName = f.Name
|
|
}
|
|
|
|
gotype := f.Type
|
|
value := reflect.New(gotype).Interface()
|
|
if m.TypeConverter != nil {
|
|
// Make a new pointer to a value of type gotype and
|
|
// pass it to the TypeConverter's FromDb method to see
|
|
// if a different type should be used for the column
|
|
// type during table creation.
|
|
scanner, useHolder := m.TypeConverter.FromDb(value)
|
|
if useHolder {
|
|
value = scanner.Holder
|
|
gotype = reflect.TypeOf(value)
|
|
}
|
|
}
|
|
if typer, ok := value.(SqlTyper); ok {
|
|
gotype = reflect.TypeOf(typer.SqlType())
|
|
} else if valuer, ok := value.(driver.Valuer); ok {
|
|
// Only check for driver.Valuer if SqlTyper wasn't
|
|
// found.
|
|
v, err := valuer.Value()
|
|
if err == nil && v != nil {
|
|
gotype = reflect.TypeOf(v)
|
|
}
|
|
}
|
|
cm := &ColumnMap{
|
|
ColumnName: columnName,
|
|
DefaultValue: defaultValue,
|
|
Transient: columnName == "-",
|
|
fieldName: f.Name,
|
|
gotype: gotype,
|
|
isPK: isPK,
|
|
isAutoIncr: isAuto,
|
|
MaxSize: maxSize,
|
|
}
|
|
if isPK {
|
|
primaryKey = append(primaryKey, cm)
|
|
}
|
|
// Check for nested fields of the same field name and
|
|
// override them.
|
|
shouldAppend := true
|
|
for index, col := range cols {
|
|
if !col.Transient && col.fieldName == cm.fieldName {
|
|
cols[index] = cm
|
|
shouldAppend = false
|
|
break
|
|
}
|
|
}
|
|
if shouldAppend {
|
|
cols = append(cols, cm)
|
|
}
|
|
}
|
|
|
|
}
|
|
return
|
|
}
|
|
|
|
// CreateTables iterates through TableMaps registered to this DbMap and
|
|
// executes "create table" statements against the database for each.
|
|
//
|
|
// This is particularly useful in unit tests where you want to create
|
|
// and destroy the schema automatically.
|
|
func (m *DbMap) CreateTables() error {
|
|
return m.createTables(false)
|
|
}
|
|
|
|
// CreateTablesIfNotExists is similar to CreateTables, but starts
|
|
// each statement with "create table if not exists" so that existing
|
|
// tables do not raise errors
|
|
func (m *DbMap) CreateTablesIfNotExists() error {
|
|
return m.createTables(true)
|
|
}
|
|
|
|
func (m *DbMap) createTables(ifNotExists bool) error {
|
|
var err error
|
|
for i := range m.tables {
|
|
table := m.tables[i]
|
|
sql := table.SqlForCreate(ifNotExists)
|
|
_, err = m.Exec(sql)
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
// DropTable drops an individual table.
|
|
// Returns an error when the table does not exist.
|
|
func (m *DbMap) DropTable(table interface{}) error {
|
|
t := reflect.TypeOf(table)
|
|
return m.dropTable(t, false)
|
|
}
|
|
|
|
// DropTableIfExists drops an individual table when the table exists.
|
|
func (m *DbMap) DropTableIfExists(table interface{}) error {
|
|
t := reflect.TypeOf(table)
|
|
return m.dropTable(t, true)
|
|
}
|
|
|
|
// DropTables iterates through TableMaps registered to this DbMap and
|
|
// executes "drop table" statements against the database for each.
|
|
func (m *DbMap) DropTables() error {
|
|
return m.dropTables(false)
|
|
}
|
|
|
|
// DropTablesIfExists is the same as DropTables, but uses the "if exists" clause to
|
|
// avoid errors for tables that do not exist.
|
|
func (m *DbMap) DropTablesIfExists() error {
|
|
return m.dropTables(true)
|
|
}
|
|
|
|
// Goes through all the registered tables, dropping them one by one.
|
|
// If an error is encountered, then it is returned and the rest of
|
|
// the tables are not dropped.
|
|
func (m *DbMap) dropTables(addIfExists bool) (err error) {
|
|
for _, table := range m.tables {
|
|
err = m.dropTableImpl(table, addIfExists)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
// Implementation of dropping a single table.
|
|
func (m *DbMap) dropTable(t reflect.Type, addIfExists bool) error {
|
|
table := tableOrNil(m, t)
|
|
if table == nil {
|
|
return fmt.Errorf("table %s was not registered", table.TableName)
|
|
}
|
|
|
|
return m.dropTableImpl(table, addIfExists)
|
|
}
|
|
|
|
func (m *DbMap) dropTableImpl(table *TableMap, ifExists bool) (err error) {
|
|
tableDrop := "drop table"
|
|
if ifExists {
|
|
tableDrop = m.Dialect.IfTableExists(tableDrop, table.SchemaName, table.TableName)
|
|
}
|
|
_, err = m.Exec(fmt.Sprintf("%s %s;", tableDrop, m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName)))
|
|
return err
|
|
}
|
|
|
|
// TruncateTables iterates through TableMaps registered to this DbMap and
|
|
// executes "truncate table" statements against the database for each, or in the case of
|
|
// sqlite, a "delete from" with no "where" clause, which uses the truncate optimization
|
|
// (http://www.sqlite.org/lang_delete.html)
|
|
func (m *DbMap) TruncateTables() error {
|
|
var err error
|
|
for i := range m.tables {
|
|
table := m.tables[i]
|
|
_, e := m.Exec(fmt.Sprintf("%s %s;", m.Dialect.TruncateClause(), m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName)))
|
|
if e != nil {
|
|
err = e
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
// Insert runs a SQL INSERT statement for each element in list. List
|
|
// items must be pointers.
|
|
//
|
|
// Any interface whose TableMap has an auto-increment primary key will
|
|
// have its last insert id bound to the PK field on the struct.
|
|
//
|
|
// The hook functions PreInsert() and/or PostInsert() will be executed
|
|
// before/after the INSERT statement if the interface defines them.
|
|
//
|
|
// Panics if any interface in the list has not been registered with AddTable
|
|
func (m *DbMap) Insert(list ...interface{}) error {
|
|
return insert(m, m, list...)
|
|
}
|
|
|
|
// Update runs a SQL UPDATE statement for each element in list. List
|
|
// items must be pointers.
|
|
//
|
|
// The hook functions PreUpdate() and/or PostUpdate() will be executed
|
|
// before/after the UPDATE statement if the interface defines them.
|
|
//
|
|
// Returns the number of rows updated.
|
|
//
|
|
// Returns an error if SetKeys has not been called on the TableMap
|
|
// Panics if any interface in the list has not been registered with AddTable
|
|
func (m *DbMap) Update(list ...interface{}) (int64, error) {
|
|
return update(m, m, list...)
|
|
}
|
|
|
|
// Delete runs a SQL DELETE statement for each element in list. List
|
|
// items must be pointers.
|
|
//
|
|
// The hook functions PreDelete() and/or PostDelete() will be executed
|
|
// before/after the DELETE statement if the interface defines them.
|
|
//
|
|
// Returns the number of rows deleted.
|
|
//
|
|
// Returns an error if SetKeys has not been called on the TableMap
|
|
// Panics if any interface in the list has not been registered with AddTable
|
|
func (m *DbMap) Delete(list ...interface{}) (int64, error) {
|
|
return delete(m, m, list...)
|
|
}
|
|
|
|
// Get runs a SQL SELECT to fetch a single row from the table based on the
|
|
// primary key(s)
|
|
//
|
|
// i should be an empty value for the struct to load. keys should be
|
|
// the primary key value(s) for the row to load. If multiple keys
|
|
// exist on the table, the order should match the column order
|
|
// specified in SetKeys() when the table mapping was defined.
|
|
//
|
|
// The hook function PostGet() will be executed after the SELECT
|
|
// statement if the interface defines them.
|
|
//
|
|
// Returns a pointer to a struct that matches or nil if no row is found.
|
|
//
|
|
// Returns an error if SetKeys has not been called on the TableMap
|
|
// Panics if any interface in the list has not been registered with AddTable
|
|
func (m *DbMap) Get(i interface{}, keys ...interface{}) (interface{}, error) {
|
|
return get(m, m, i, keys...)
|
|
}
|
|
|
|
// Select runs an arbitrary SQL query, binding the columns in the result
|
|
// to fields on the struct specified by i. args represent the bind
|
|
// parameters for the SQL statement.
|
|
//
|
|
// Column names on the SELECT statement should be aliased to the field names
|
|
// on the struct i. Returns an error if one or more columns in the result
|
|
// do not match. It is OK if fields on i are not part of the SQL
|
|
// statement.
|
|
//
|
|
// The hook function PostGet() will be executed after the SELECT
|
|
// statement if the interface defines them.
|
|
//
|
|
// Values are returned in one of two ways:
|
|
// 1. If i is a struct or a pointer to a struct, returns a slice of pointers to
|
|
// matching rows of type i.
|
|
// 2. If i is a pointer to a slice, the results will be appended to that slice
|
|
// and nil returned.
|
|
//
|
|
// i does NOT need to be registered with AddTable()
|
|
func (m *DbMap) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
|
|
return hookedselect(m, m, i, query, args...)
|
|
}
|
|
|
|
// Exec runs an arbitrary SQL statement. args represent the bind parameters.
|
|
// This is equivalent to running: Exec() using database/sql
|
|
func (m *DbMap) Exec(query string, args ...interface{}) (sql.Result, error) {
|
|
if m.logger != nil {
|
|
now := time.Now()
|
|
defer m.trace(now, query, args...)
|
|
}
|
|
return exec(m, query, args...)
|
|
}
|
|
|
|
// SelectInt is a convenience wrapper around the gorp.SelectInt function
|
|
func (m *DbMap) SelectInt(query string, args ...interface{}) (int64, error) {
|
|
return SelectInt(m, query, args...)
|
|
}
|
|
|
|
// SelectNullInt is a convenience wrapper around the gorp.SelectNullInt function
|
|
func (m *DbMap) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) {
|
|
return SelectNullInt(m, query, args...)
|
|
}
|
|
|
|
// SelectFloat is a convenience wrapper around the gorp.SelectFloat function
|
|
func (m *DbMap) SelectFloat(query string, args ...interface{}) (float64, error) {
|
|
return SelectFloat(m, query, args...)
|
|
}
|
|
|
|
// SelectNullFloat is a convenience wrapper around the gorp.SelectNullFloat function
|
|
func (m *DbMap) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) {
|
|
return SelectNullFloat(m, query, args...)
|
|
}
|
|
|
|
// SelectStr is a convenience wrapper around the gorp.SelectStr function
|
|
func (m *DbMap) SelectStr(query string, args ...interface{}) (string, error) {
|
|
return SelectStr(m, query, args...)
|
|
}
|
|
|
|
// SelectNullStr is a convenience wrapper around the gorp.SelectNullStr function
|
|
func (m *DbMap) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) {
|
|
return SelectNullStr(m, query, args...)
|
|
}
|
|
|
|
// SelectOne is a convenience wrapper around the gorp.SelectOne function
|
|
func (m *DbMap) SelectOne(holder interface{}, query string, args ...interface{}) error {
|
|
return SelectOne(m, m, holder, query, args...)
|
|
}
|
|
|
|
// Begin starts a gorp Transaction
|
|
func (m *DbMap) Begin() (*Transaction, error) {
|
|
if m.logger != nil {
|
|
now := time.Now()
|
|
defer m.trace(now, "begin;")
|
|
}
|
|
tx, err := m.Db.Begin()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &Transaction{m, tx, false}, nil
|
|
}
|
|
|
|
// TableFor returns the *TableMap corresponding to the given Go Type
|
|
// If no table is mapped to that type an error is returned.
|
|
// If checkPK is true and the mapped table has no registered PKs, an error is returned.
|
|
func (m *DbMap) TableFor(t reflect.Type, checkPK bool) (*TableMap, error) {
|
|
table := tableOrNil(m, t)
|
|
if table == nil {
|
|
return nil, fmt.Errorf("no table found for type: %v", t.Name())
|
|
}
|
|
|
|
if checkPK && len(table.keys) < 1 {
|
|
e := fmt.Sprintf("gorp: no keys defined for table: %s",
|
|
table.TableName)
|
|
return nil, errors.New(e)
|
|
}
|
|
|
|
return table, nil
|
|
}
|
|
|
|
// Prepare creates a prepared statement for later queries or executions.
|
|
// Multiple queries or executions may be run concurrently from the returned statement.
|
|
// This is equivalent to running: Prepare() using database/sql
|
|
func (m *DbMap) Prepare(query string) (*sql.Stmt, error) {
|
|
if m.logger != nil {
|
|
now := time.Now()
|
|
defer m.trace(now, query, nil)
|
|
}
|
|
return m.Db.Prepare(query)
|
|
}
|
|
|
|
func tableOrNil(m *DbMap, t reflect.Type) *TableMap {
|
|
for i := range m.tables {
|
|
table := m.tables[i]
|
|
if table.gotype == t {
|
|
return table
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *DbMap) tableForPointer(ptr interface{}, checkPK bool) (*TableMap, reflect.Value, error) {
|
|
ptrv := reflect.ValueOf(ptr)
|
|
if ptrv.Kind() != reflect.Ptr {
|
|
e := fmt.Sprintf("gorp: passed non-pointer: %v (kind=%v)", ptr,
|
|
ptrv.Kind())
|
|
return nil, reflect.Value{}, errors.New(e)
|
|
}
|
|
elem := ptrv.Elem()
|
|
etype := reflect.TypeOf(elem.Interface())
|
|
t, err := m.TableFor(etype, checkPK)
|
|
if err != nil {
|
|
return nil, reflect.Value{}, err
|
|
}
|
|
|
|
return t, elem, nil
|
|
}
|
|
|
|
func (m *DbMap) queryRow(query string, args ...interface{}) *sql.Row {
|
|
if m.logger != nil {
|
|
now := time.Now()
|
|
defer m.trace(now, query, args...)
|
|
}
|
|
return m.Db.QueryRow(query, args...)
|
|
}
|
|
|
|
func (m *DbMap) query(query string, args ...interface{}) (*sql.Rows, error) {
|
|
if m.logger != nil {
|
|
now := time.Now()
|
|
defer m.trace(now, query, args...)
|
|
}
|
|
return m.Db.Query(query, args...)
|
|
}
|
|
|
|
func (m *DbMap) trace(started time.Time, query string, args ...interface{}) {
|
|
if m.logger != nil {
|
|
var margs = argsString(args...)
|
|
m.logger.Printf("%s%s [%s] (%v)", m.logPrefix, query, margs, (time.Now().Sub(started)))
|
|
}
|
|
}
|