forked from mystiq/dex
Merge pull request #304 from ericchiang/sqlite3
move to sqlite3 for --no-db mode and tests
This commit is contained in:
commit
f51125f555
83 changed files with 159246 additions and 1796 deletions
|
@ -22,6 +22,7 @@ install:
|
|||
script:
|
||||
- docker run -d -p 127.0.0.1:15432:5432 quay.io/coreos/postgres
|
||||
- LDAPCONTAINER=`docker run -e LDAP_TLS_PROTOCOL_MIN=3.0 -e LDAP_TLS_CIPHER_SUITE=NORMAL -d -p 127.0.0.1:1389:389 -p 127.0.0.1:1636:636 -h tlstest.local osixia/openldap`
|
||||
- ./build
|
||||
- ./test
|
||||
- docker cp ${LDAPCONTAINER}:container/service/:cfssl/assets/default-ca/default-ca.pem /tmp/openldap-ca.pem
|
||||
- docker cp ${LDAPCONTAINER}:container/service/slapd/assets/certs/ldap.key /tmp/ldap.key
|
||||
|
@ -29,6 +30,7 @@ script:
|
|||
- docker cp ${LDAPCONTAINER}:container/service/slapd/assets/certs/ldap.crt /tmp/ldap.crt
|
||||
- sudo sh -c 'echo "127.0.0.1 tlstest.local" >> /etc/hosts'
|
||||
- ./test-functional
|
||||
- DEX_TEST_DSN="sqlite3://:memory:" ./test-functional
|
||||
|
||||
deploy:
|
||||
provider: script
|
||||
|
|
7
Godeps/Godeps.json
generated
7
Godeps/Godeps.json
generated
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"ImportPath": "github.com/coreos/dex",
|
||||
"GoVersion": "go1.4.2",
|
||||
"GoVersion": "go1.5",
|
||||
"Packages": [
|
||||
"./..."
|
||||
],
|
||||
|
@ -91,6 +91,11 @@
|
|||
"ImportPath": "github.com/mailgun/mailgun-go",
|
||||
"Rev": "9578dc67692294bb7e2a6f4b15dd18c97af19440"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/mattn/go-sqlite3",
|
||||
"Comment": "v1.1.0-25-g2513631",
|
||||
"Rev": "2513631704612107a1c8b1803fb8e6b3dda2230e"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/mbanzon/simplehttp",
|
||||
"Rev": "04c542e7ac706a25820090f274ea6a4f39a63326"
|
||||
|
|
3
Godeps/_workspace/src/github.com/mattn/go-sqlite3/.gitignore
generated
vendored
Normal file
3
Godeps/_workspace/src/github.com/mattn/go-sqlite3/.gitignore
generated
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
*.db
|
||||
*.exe
|
||||
*.dll
|
9
Godeps/_workspace/src/github.com/mattn/go-sqlite3/.travis.yml
generated
vendored
Normal file
9
Godeps/_workspace/src/github.com/mattn/go-sqlite3/.travis.yml
generated
vendored
Normal file
|
@ -0,0 +1,9 @@
|
|||
language: go
|
||||
go:
|
||||
- tip
|
||||
before_install:
|
||||
- go get github.com/axw/gocov/gocov
|
||||
- go get github.com/mattn/goveralls
|
||||
- go get golang.org/x/tools/cmd/cover
|
||||
script:
|
||||
- $HOME/gopath/bin/goveralls -repotoken 3qJVUE0iQwqnCbmNcDsjYu1nh4J4KIFXx
|
21
Godeps/_workspace/src/github.com/mattn/go-sqlite3/LICENSE
generated
vendored
Normal file
21
Godeps/_workspace/src/github.com/mattn/go-sqlite3/LICENSE
generated
vendored
Normal file
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014 Yasuhiro Matsumoto
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
74
Godeps/_workspace/src/github.com/mattn/go-sqlite3/README.md
generated
vendored
Normal file
74
Godeps/_workspace/src/github.com/mattn/go-sqlite3/README.md
generated
vendored
Normal file
|
@ -0,0 +1,74 @@
|
|||
go-sqlite3
|
||||
==========
|
||||
|
||||
[![Build Status](https://travis-ci.org/mattn/go-sqlite3.png?branch=master)](https://travis-ci.org/mattn/go-sqlite3)
|
||||
[![Coverage Status](https://coveralls.io/repos/mattn/go-sqlite3/badge.png?branch=master)](https://coveralls.io/r/mattn/go-sqlite3?branch=master)
|
||||
|
||||
Description
|
||||
-----------
|
||||
|
||||
sqlite3 driver conforming to the built-in database/sql interface
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
This package can be installed with the go get command:
|
||||
|
||||
go get github.com/mattn/go-sqlite3
|
||||
|
||||
_go-sqlite3_ is *cgo* package.
|
||||
If you want to build your app using go-sqlite3, you need gcc.
|
||||
However, if you install _go-sqlite3_ with `go install github.com/mattn/go-sqlite`, you don't need gcc to build your app anymore.
|
||||
|
||||
Documentation
|
||||
-------------
|
||||
|
||||
API documentation can be found here: http://godoc.org/github.com/mattn/go-sqlite3
|
||||
|
||||
Examples can be found under the `./_example` directory
|
||||
|
||||
FAQ
|
||||
---
|
||||
|
||||
* Want to build go-sqlite3 with libsqlite3 on my linux.
|
||||
|
||||
Use `go build --tags "libsqlite3 linux"`
|
||||
|
||||
* Want to build go-sqlite3 with icu extension.
|
||||
|
||||
Use `go build --tags "icu"`
|
||||
|
||||
* Can't build go-sqlite3 on windows 64bit.
|
||||
|
||||
> Probably, you are using go 1.0, go1.0 has a problem when it comes to compiling/linking on windows 64bit.
|
||||
> See: https://github.com/mattn/go-sqlite3/issues/27
|
||||
|
||||
* Getting insert error while query is opened.
|
||||
|
||||
> You can pass some arguments into the connection string, for example, a URI.
|
||||
> See: https://github.com/mattn/go-sqlite3/issues/39
|
||||
|
||||
* Do you want cross compiling? mingw on Linux or Mac?
|
||||
|
||||
> See: https://github.com/mattn/go-sqlite3/issues/106
|
||||
> See also: http://www.limitlessfx.com/cross-compile-golang-app-for-windows-from-linux.html
|
||||
|
||||
* Want to get time.Time with current locale
|
||||
|
||||
Use `loc=auto` in SQLite3 filename schema like `file:foo.db?loc=auto`.
|
||||
|
||||
License
|
||||
-------
|
||||
|
||||
MIT: http://mattn.mit-license.org/2012
|
||||
|
||||
sqlite3-binding.c, sqlite3-binding.h, sqlite3ext.h
|
||||
|
||||
The -binding suffix was added to avoid build failures under gccgo.
|
||||
|
||||
In this repository, those files are amalgamation code that copied from SQLite3. The license of those codes are depend on the license of SQLite3.
|
||||
|
||||
Author
|
||||
------
|
||||
|
||||
Yasuhiro Matsumoto (a.k.a mattn)
|
70
Godeps/_workspace/src/github.com/mattn/go-sqlite3/backup.go
generated
vendored
Normal file
70
Godeps/_workspace/src/github.com/mattn/go-sqlite3/backup.go
generated
vendored
Normal file
|
@ -0,0 +1,70 @@
|
|||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#include <sqlite3-binding.h>
|
||||
#include <stdlib.h>
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"runtime"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type SQLiteBackup struct {
|
||||
b *C.sqlite3_backup
|
||||
}
|
||||
|
||||
func (c *SQLiteConn) Backup(dest string, conn *SQLiteConn, src string) (*SQLiteBackup, error) {
|
||||
destptr := C.CString(dest)
|
||||
defer C.free(unsafe.Pointer(destptr))
|
||||
srcptr := C.CString(src)
|
||||
defer C.free(unsafe.Pointer(srcptr))
|
||||
|
||||
if b := C.sqlite3_backup_init(c.db, destptr, conn.db, srcptr); b != nil {
|
||||
bb := &SQLiteBackup{b: b}
|
||||
runtime.SetFinalizer(bb, (*SQLiteBackup).Finish)
|
||||
return bb, nil
|
||||
}
|
||||
return nil, c.lastError()
|
||||
}
|
||||
|
||||
// Backs up for one step. Calls the underlying `sqlite3_backup_step` function.
|
||||
// This function returns a boolean indicating if the backup is done and
|
||||
// an error signalling any other error. Done is returned if the underlying C
|
||||
// function returns SQLITE_DONE (Code 101)
|
||||
func (b *SQLiteBackup) Step(p int) (bool, error) {
|
||||
ret := C.sqlite3_backup_step(b.b, C.int(p))
|
||||
if ret == C.SQLITE_DONE {
|
||||
return true, nil
|
||||
} else if ret != 0 && ret != C.SQLITE_LOCKED && ret != C.SQLITE_BUSY {
|
||||
return false, Error{Code: ErrNo(ret)}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (b *SQLiteBackup) Remaining() int {
|
||||
return int(C.sqlite3_backup_remaining(b.b))
|
||||
}
|
||||
|
||||
func (b *SQLiteBackup) PageCount() int {
|
||||
return int(C.sqlite3_backup_pagecount(b.b))
|
||||
}
|
||||
|
||||
func (b *SQLiteBackup) Finish() error {
|
||||
return b.Close()
|
||||
}
|
||||
|
||||
func (b *SQLiteBackup) Close() error {
|
||||
ret := C.sqlite3_backup_finish(b.b)
|
||||
if ret != 0 {
|
||||
return Error{Code: ErrNo(ret)}
|
||||
}
|
||||
b.b = nil
|
||||
runtime.SetFinalizer(b, nil)
|
||||
return nil
|
||||
}
|
290
Godeps/_workspace/src/github.com/mattn/go-sqlite3/callback.go
generated
vendored
Normal file
290
Godeps/_workspace/src/github.com/mattn/go-sqlite3/callback.go
generated
vendored
Normal file
|
@ -0,0 +1,290 @@
|
|||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package sqlite3
|
||||
|
||||
// You can't export a Go function to C and have definitions in the C
|
||||
// preamble in the same file, so we have to have callbackTrampoline in
|
||||
// its own file. Because we need a separate file anyway, the support
|
||||
// code for SQLite custom functions is in here.
|
||||
|
||||
/*
|
||||
#include <sqlite3-binding.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
void _sqlite3_result_text(sqlite3_context* ctx, const char* s);
|
||||
void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l);
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
//export callbackTrampoline
|
||||
func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
|
||||
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
|
||||
fi := (*functionInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
|
||||
fi.Call(ctx, args)
|
||||
}
|
||||
|
||||
//export stepTrampoline
|
||||
func stepTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
|
||||
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
|
||||
ai := (*aggInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
|
||||
ai.Step(ctx, args)
|
||||
}
|
||||
|
||||
//export doneTrampoline
|
||||
func doneTrampoline(ctx *C.sqlite3_context) {
|
||||
ai := (*aggInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
|
||||
ai.Done(ctx)
|
||||
}
|
||||
|
||||
// This is only here so that tests can refer to it.
|
||||
type callbackArgRaw C.sqlite3_value
|
||||
|
||||
type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error)
|
||||
|
||||
type callbackArgCast struct {
|
||||
f callbackArgConverter
|
||||
typ reflect.Type
|
||||
}
|
||||
|
||||
func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) {
|
||||
val, err := c.f(v)
|
||||
if err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
if !val.Type().ConvertibleTo(c.typ) {
|
||||
return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ)
|
||||
}
|
||||
return val.Convert(c.typ), nil
|
||||
}
|
||||
|
||||
func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) {
|
||||
if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
|
||||
return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
|
||||
}
|
||||
return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil
|
||||
}
|
||||
|
||||
func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) {
|
||||
if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
|
||||
return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
|
||||
}
|
||||
i := int64(C.sqlite3_value_int64(v))
|
||||
val := false
|
||||
if i != 0 {
|
||||
val = true
|
||||
}
|
||||
return reflect.ValueOf(val), nil
|
||||
}
|
||||
|
||||
func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) {
|
||||
if C.sqlite3_value_type(v) != C.SQLITE_FLOAT {
|
||||
return reflect.Value{}, fmt.Errorf("argument must be a FLOAT")
|
||||
}
|
||||
return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil
|
||||
}
|
||||
|
||||
func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) {
|
||||
switch C.sqlite3_value_type(v) {
|
||||
case C.SQLITE_BLOB:
|
||||
l := C.sqlite3_value_bytes(v)
|
||||
p := C.sqlite3_value_blob(v)
|
||||
return reflect.ValueOf(C.GoBytes(p, l)), nil
|
||||
case C.SQLITE_TEXT:
|
||||
l := C.sqlite3_value_bytes(v)
|
||||
c := unsafe.Pointer(C.sqlite3_value_text(v))
|
||||
return reflect.ValueOf(C.GoBytes(c, l)), nil
|
||||
default:
|
||||
return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
|
||||
}
|
||||
}
|
||||
|
||||
func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) {
|
||||
switch C.sqlite3_value_type(v) {
|
||||
case C.SQLITE_BLOB:
|
||||
l := C.sqlite3_value_bytes(v)
|
||||
p := (*C.char)(C.sqlite3_value_blob(v))
|
||||
return reflect.ValueOf(C.GoStringN(p, l)), nil
|
||||
case C.SQLITE_TEXT:
|
||||
c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v)))
|
||||
return reflect.ValueOf(C.GoString(c)), nil
|
||||
default:
|
||||
return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
|
||||
}
|
||||
}
|
||||
|
||||
func callbackArgGeneric(v *C.sqlite3_value) (reflect.Value, error) {
|
||||
switch C.sqlite3_value_type(v) {
|
||||
case C.SQLITE_INTEGER:
|
||||
return callbackArgInt64(v)
|
||||
case C.SQLITE_FLOAT:
|
||||
return callbackArgFloat64(v)
|
||||
case C.SQLITE_TEXT:
|
||||
return callbackArgString(v)
|
||||
case C.SQLITE_BLOB:
|
||||
return callbackArgBytes(v)
|
||||
case C.SQLITE_NULL:
|
||||
// Interpret NULL as a nil byte slice.
|
||||
var ret []byte
|
||||
return reflect.ValueOf(ret), nil
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
func callbackArg(typ reflect.Type) (callbackArgConverter, error) {
|
||||
switch typ.Kind() {
|
||||
case reflect.Interface:
|
||||
if typ.NumMethod() != 0 {
|
||||
return nil, errors.New("the only supported interface type is interface{}")
|
||||
}
|
||||
return callbackArgGeneric, nil
|
||||
case reflect.Slice:
|
||||
if typ.Elem().Kind() != reflect.Uint8 {
|
||||
return nil, errors.New("the only supported slice type is []byte")
|
||||
}
|
||||
return callbackArgBytes, nil
|
||||
case reflect.String:
|
||||
return callbackArgString, nil
|
||||
case reflect.Bool:
|
||||
return callbackArgBool, nil
|
||||
case reflect.Int64:
|
||||
return callbackArgInt64, nil
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
|
||||
c := callbackArgCast{callbackArgInt64, typ}
|
||||
return c.Run, nil
|
||||
case reflect.Float64:
|
||||
return callbackArgFloat64, nil
|
||||
case reflect.Float32:
|
||||
c := callbackArgCast{callbackArgFloat64, typ}
|
||||
return c.Run, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("don't know how to convert to %s", typ)
|
||||
}
|
||||
}
|
||||
|
||||
func callbackConvertArgs(argv []*C.sqlite3_value, converters []callbackArgConverter, variadic callbackArgConverter) ([]reflect.Value, error) {
|
||||
var args []reflect.Value
|
||||
|
||||
if len(argv) < len(converters) {
|
||||
return nil, fmt.Errorf("function requires at least %d arguments", len(converters))
|
||||
}
|
||||
|
||||
for i, arg := range argv[:len(converters)] {
|
||||
v, err := converters[i](arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args = append(args, v)
|
||||
}
|
||||
|
||||
if variadic != nil {
|
||||
for _, arg := range argv[len(converters):] {
|
||||
v, err := variadic(arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args = append(args, v)
|
||||
}
|
||||
}
|
||||
return args, nil
|
||||
}
|
||||
|
||||
type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error
|
||||
|
||||
func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Int64:
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
|
||||
v = v.Convert(reflect.TypeOf(int64(0)))
|
||||
case reflect.Bool:
|
||||
b := v.Interface().(bool)
|
||||
if b {
|
||||
v = reflect.ValueOf(int64(1))
|
||||
} else {
|
||||
v = reflect.ValueOf(int64(0))
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("cannot convert %s to INTEGER", v.Type())
|
||||
}
|
||||
|
||||
C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Float64:
|
||||
case reflect.Float32:
|
||||
v = v.Convert(reflect.TypeOf(float64(0)))
|
||||
default:
|
||||
return fmt.Errorf("cannot convert %s to FLOAT", v.Type())
|
||||
}
|
||||
|
||||
C.sqlite3_result_double(ctx, C.double(v.Interface().(float64)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||
if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 {
|
||||
return fmt.Errorf("cannot convert %s to BLOB", v.Type())
|
||||
}
|
||||
i := v.Interface()
|
||||
if i == nil || len(i.([]byte)) == 0 {
|
||||
C.sqlite3_result_null(ctx)
|
||||
} else {
|
||||
bs := i.([]byte)
|
||||
C._sqlite3_result_blob(ctx, unsafe.Pointer(&bs[0]), C.int(len(bs)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||
if v.Type().Kind() != reflect.String {
|
||||
return fmt.Errorf("cannot convert %s to TEXT", v.Type())
|
||||
}
|
||||
C._sqlite3_result_text(ctx, C.CString(v.Interface().(string)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
|
||||
switch typ.Kind() {
|
||||
case reflect.Slice:
|
||||
if typ.Elem().Kind() != reflect.Uint8 {
|
||||
return nil, errors.New("the only supported slice type is []byte")
|
||||
}
|
||||
return callbackRetBlob, nil
|
||||
case reflect.String:
|
||||
return callbackRetText, nil
|
||||
case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
|
||||
return callbackRetInteger, nil
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return callbackRetFloat, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("don't know how to convert to %s", typ)
|
||||
}
|
||||
}
|
||||
|
||||
func callbackError(ctx *C.sqlite3_context, err error) {
|
||||
cstr := C.CString(err.Error())
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
C.sqlite3_result_error(ctx, cstr, -1)
|
||||
}
|
||||
|
||||
// Test support code. Tests are not allowed to import "C", so we can't
|
||||
// declare any functions that use C.sqlite3_value.
|
||||
func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter {
|
||||
return func(*C.sqlite3_value) (reflect.Value, error) {
|
||||
return v, err
|
||||
}
|
||||
}
|
147782
Godeps/_workspace/src/github.com/mattn/go-sqlite3/code/sqlite3-binding.c
generated
vendored
Normal file
147782
Godeps/_workspace/src/github.com/mattn/go-sqlite3/code/sqlite3-binding.c
generated
vendored
Normal file
File diff suppressed because it is too large
Load diff
7478
Godeps/_workspace/src/github.com/mattn/go-sqlite3/code/sqlite3-binding.h
generated
vendored
Normal file
7478
Godeps/_workspace/src/github.com/mattn/go-sqlite3/code/sqlite3-binding.h
generated
vendored
Normal file
File diff suppressed because it is too large
Load diff
487
Godeps/_workspace/src/github.com/mattn/go-sqlite3/code/sqlite3ext.h
generated
vendored
Normal file
487
Godeps/_workspace/src/github.com/mattn/go-sqlite3/code/sqlite3ext.h
generated
vendored
Normal file
|
@ -0,0 +1,487 @@
|
|||
/*
|
||||
** 2006 June 7
|
||||
**
|
||||
** The author disclaims copyright to this source code. In place of
|
||||
** a legal notice, here is a blessing:
|
||||
**
|
||||
** May you do good and not evil.
|
||||
** May you find forgiveness for yourself and forgive others.
|
||||
** May you share freely, never taking more than you give.
|
||||
**
|
||||
*************************************************************************
|
||||
** This header file defines the SQLite interface for use by
|
||||
** shared libraries that want to be imported as extensions into
|
||||
** an SQLite instance. Shared libraries that intend to be loaded
|
||||
** as extensions by SQLite should #include this file instead of
|
||||
** sqlite3.h.
|
||||
*/
|
||||
#ifndef _SQLITE3EXT_H_
|
||||
#define _SQLITE3EXT_H_
|
||||
#include "sqlite3-binding.h"
|
||||
|
||||
typedef struct sqlite3_api_routines sqlite3_api_routines;
|
||||
|
||||
/*
|
||||
** The following structure holds pointers to all of the SQLite API
|
||||
** routines.
|
||||
**
|
||||
** WARNING: In order to maintain backwards compatibility, add new
|
||||
** interfaces to the end of this structure only. If you insert new
|
||||
** interfaces in the middle of this structure, then older different
|
||||
** versions of SQLite will not be able to load each others' shared
|
||||
** libraries!
|
||||
*/
|
||||
struct sqlite3_api_routines {
|
||||
void * (*aggregate_context)(sqlite3_context*,int nBytes);
|
||||
int (*aggregate_count)(sqlite3_context*);
|
||||
int (*bind_blob)(sqlite3_stmt*,int,const void*,int n,void(*)(void*));
|
||||
int (*bind_double)(sqlite3_stmt*,int,double);
|
||||
int (*bind_int)(sqlite3_stmt*,int,int);
|
||||
int (*bind_int64)(sqlite3_stmt*,int,sqlite_int64);
|
||||
int (*bind_null)(sqlite3_stmt*,int);
|
||||
int (*bind_parameter_count)(sqlite3_stmt*);
|
||||
int (*bind_parameter_index)(sqlite3_stmt*,const char*zName);
|
||||
const char * (*bind_parameter_name)(sqlite3_stmt*,int);
|
||||
int (*bind_text)(sqlite3_stmt*,int,const char*,int n,void(*)(void*));
|
||||
int (*bind_text16)(sqlite3_stmt*,int,const void*,int,void(*)(void*));
|
||||
int (*bind_value)(sqlite3_stmt*,int,const sqlite3_value*);
|
||||
int (*busy_handler)(sqlite3*,int(*)(void*,int),void*);
|
||||
int (*busy_timeout)(sqlite3*,int ms);
|
||||
int (*changes)(sqlite3*);
|
||||
int (*close)(sqlite3*);
|
||||
int (*collation_needed)(sqlite3*,void*,void(*)(void*,sqlite3*,
|
||||
int eTextRep,const char*));
|
||||
int (*collation_needed16)(sqlite3*,void*,void(*)(void*,sqlite3*,
|
||||
int eTextRep,const void*));
|
||||
const void * (*column_blob)(sqlite3_stmt*,int iCol);
|
||||
int (*column_bytes)(sqlite3_stmt*,int iCol);
|
||||
int (*column_bytes16)(sqlite3_stmt*,int iCol);
|
||||
int (*column_count)(sqlite3_stmt*pStmt);
|
||||
const char * (*column_database_name)(sqlite3_stmt*,int);
|
||||
const void * (*column_database_name16)(sqlite3_stmt*,int);
|
||||
const char * (*column_decltype)(sqlite3_stmt*,int i);
|
||||
const void * (*column_decltype16)(sqlite3_stmt*,int);
|
||||
double (*column_double)(sqlite3_stmt*,int iCol);
|
||||
int (*column_int)(sqlite3_stmt*,int iCol);
|
||||
sqlite_int64 (*column_int64)(sqlite3_stmt*,int iCol);
|
||||
const char * (*column_name)(sqlite3_stmt*,int);
|
||||
const void * (*column_name16)(sqlite3_stmt*,int);
|
||||
const char * (*column_origin_name)(sqlite3_stmt*,int);
|
||||
const void * (*column_origin_name16)(sqlite3_stmt*,int);
|
||||
const char * (*column_table_name)(sqlite3_stmt*,int);
|
||||
const void * (*column_table_name16)(sqlite3_stmt*,int);
|
||||
const unsigned char * (*column_text)(sqlite3_stmt*,int iCol);
|
||||
const void * (*column_text16)(sqlite3_stmt*,int iCol);
|
||||
int (*column_type)(sqlite3_stmt*,int iCol);
|
||||
sqlite3_value* (*column_value)(sqlite3_stmt*,int iCol);
|
||||
void * (*commit_hook)(sqlite3*,int(*)(void*),void*);
|
||||
int (*complete)(const char*sql);
|
||||
int (*complete16)(const void*sql);
|
||||
int (*create_collation)(sqlite3*,const char*,int,void*,
|
||||
int(*)(void*,int,const void*,int,const void*));
|
||||
int (*create_collation16)(sqlite3*,const void*,int,void*,
|
||||
int(*)(void*,int,const void*,int,const void*));
|
||||
int (*create_function)(sqlite3*,const char*,int,int,void*,
|
||||
void (*xFunc)(sqlite3_context*,int,sqlite3_value**),
|
||||
void (*xStep)(sqlite3_context*,int,sqlite3_value**),
|
||||
void (*xFinal)(sqlite3_context*));
|
||||
int (*create_function16)(sqlite3*,const void*,int,int,void*,
|
||||
void (*xFunc)(sqlite3_context*,int,sqlite3_value**),
|
||||
void (*xStep)(sqlite3_context*,int,sqlite3_value**),
|
||||
void (*xFinal)(sqlite3_context*));
|
||||
int (*create_module)(sqlite3*,const char*,const sqlite3_module*,void*);
|
||||
int (*data_count)(sqlite3_stmt*pStmt);
|
||||
sqlite3 * (*db_handle)(sqlite3_stmt*);
|
||||
int (*declare_vtab)(sqlite3*,const char*);
|
||||
int (*enable_shared_cache)(int);
|
||||
int (*errcode)(sqlite3*db);
|
||||
const char * (*errmsg)(sqlite3*);
|
||||
const void * (*errmsg16)(sqlite3*);
|
||||
int (*exec)(sqlite3*,const char*,sqlite3_callback,void*,char**);
|
||||
int (*expired)(sqlite3_stmt*);
|
||||
int (*finalize)(sqlite3_stmt*pStmt);
|
||||
void (*free)(void*);
|
||||
void (*free_table)(char**result);
|
||||
int (*get_autocommit)(sqlite3*);
|
||||
void * (*get_auxdata)(sqlite3_context*,int);
|
||||
int (*get_table)(sqlite3*,const char*,char***,int*,int*,char**);
|
||||
int (*global_recover)(void);
|
||||
void (*interruptx)(sqlite3*);
|
||||
sqlite_int64 (*last_insert_rowid)(sqlite3*);
|
||||
const char * (*libversion)(void);
|
||||
int (*libversion_number)(void);
|
||||
void *(*malloc)(int);
|
||||
char * (*mprintf)(const char*,...);
|
||||
int (*open)(const char*,sqlite3**);
|
||||
int (*open16)(const void*,sqlite3**);
|
||||
int (*prepare)(sqlite3*,const char*,int,sqlite3_stmt**,const char**);
|
||||
int (*prepare16)(sqlite3*,const void*,int,sqlite3_stmt**,const void**);
|
||||
void * (*profile)(sqlite3*,void(*)(void*,const char*,sqlite_uint64),void*);
|
||||
void (*progress_handler)(sqlite3*,int,int(*)(void*),void*);
|
||||
void *(*realloc)(void*,int);
|
||||
int (*reset)(sqlite3_stmt*pStmt);
|
||||
void (*result_blob)(sqlite3_context*,const void*,int,void(*)(void*));
|
||||
void (*result_double)(sqlite3_context*,double);
|
||||
void (*result_error)(sqlite3_context*,const char*,int);
|
||||
void (*result_error16)(sqlite3_context*,const void*,int);
|
||||
void (*result_int)(sqlite3_context*,int);
|
||||
void (*result_int64)(sqlite3_context*,sqlite_int64);
|
||||
void (*result_null)(sqlite3_context*);
|
||||
void (*result_text)(sqlite3_context*,const char*,int,void(*)(void*));
|
||||
void (*result_text16)(sqlite3_context*,const void*,int,void(*)(void*));
|
||||
void (*result_text16be)(sqlite3_context*,const void*,int,void(*)(void*));
|
||||
void (*result_text16le)(sqlite3_context*,const void*,int,void(*)(void*));
|
||||
void (*result_value)(sqlite3_context*,sqlite3_value*);
|
||||
void * (*rollback_hook)(sqlite3*,void(*)(void*),void*);
|
||||
int (*set_authorizer)(sqlite3*,int(*)(void*,int,const char*,const char*,
|
||||
const char*,const char*),void*);
|
||||
void (*set_auxdata)(sqlite3_context*,int,void*,void (*)(void*));
|
||||
char * (*snprintf)(int,char*,const char*,...);
|
||||
int (*step)(sqlite3_stmt*);
|
||||
int (*table_column_metadata)(sqlite3*,const char*,const char*,const char*,
|
||||
char const**,char const**,int*,int*,int*);
|
||||
void (*thread_cleanup)(void);
|
||||
int (*total_changes)(sqlite3*);
|
||||
void * (*trace)(sqlite3*,void(*xTrace)(void*,const char*),void*);
|
||||
int (*transfer_bindings)(sqlite3_stmt*,sqlite3_stmt*);
|
||||
void * (*update_hook)(sqlite3*,void(*)(void*,int ,char const*,char const*,
|
||||
sqlite_int64),void*);
|
||||
void * (*user_data)(sqlite3_context*);
|
||||
const void * (*value_blob)(sqlite3_value*);
|
||||
int (*value_bytes)(sqlite3_value*);
|
||||
int (*value_bytes16)(sqlite3_value*);
|
||||
double (*value_double)(sqlite3_value*);
|
||||
int (*value_int)(sqlite3_value*);
|
||||
sqlite_int64 (*value_int64)(sqlite3_value*);
|
||||
int (*value_numeric_type)(sqlite3_value*);
|
||||
const unsigned char * (*value_text)(sqlite3_value*);
|
||||
const void * (*value_text16)(sqlite3_value*);
|
||||
const void * (*value_text16be)(sqlite3_value*);
|
||||
const void * (*value_text16le)(sqlite3_value*);
|
||||
int (*value_type)(sqlite3_value*);
|
||||
char *(*vmprintf)(const char*,va_list);
|
||||
/* Added ??? */
|
||||
int (*overload_function)(sqlite3*, const char *zFuncName, int nArg);
|
||||
/* Added by 3.3.13 */
|
||||
int (*prepare_v2)(sqlite3*,const char*,int,sqlite3_stmt**,const char**);
|
||||
int (*prepare16_v2)(sqlite3*,const void*,int,sqlite3_stmt**,const void**);
|
||||
int (*clear_bindings)(sqlite3_stmt*);
|
||||
/* Added by 3.4.1 */
|
||||
int (*create_module_v2)(sqlite3*,const char*,const sqlite3_module*,void*,
|
||||
void (*xDestroy)(void *));
|
||||
/* Added by 3.5.0 */
|
||||
int (*bind_zeroblob)(sqlite3_stmt*,int,int);
|
||||
int (*blob_bytes)(sqlite3_blob*);
|
||||
int (*blob_close)(sqlite3_blob*);
|
||||
int (*blob_open)(sqlite3*,const char*,const char*,const char*,sqlite3_int64,
|
||||
int,sqlite3_blob**);
|
||||
int (*blob_read)(sqlite3_blob*,void*,int,int);
|
||||
int (*blob_write)(sqlite3_blob*,const void*,int,int);
|
||||
int (*create_collation_v2)(sqlite3*,const char*,int,void*,
|
||||
int(*)(void*,int,const void*,int,const void*),
|
||||
void(*)(void*));
|
||||
int (*file_control)(sqlite3*,const char*,int,void*);
|
||||
sqlite3_int64 (*memory_highwater)(int);
|
||||
sqlite3_int64 (*memory_used)(void);
|
||||
sqlite3_mutex *(*mutex_alloc)(int);
|
||||
void (*mutex_enter)(sqlite3_mutex*);
|
||||
void (*mutex_free)(sqlite3_mutex*);
|
||||
void (*mutex_leave)(sqlite3_mutex*);
|
||||
int (*mutex_try)(sqlite3_mutex*);
|
||||
int (*open_v2)(const char*,sqlite3**,int,const char*);
|
||||
int (*release_memory)(int);
|
||||
void (*result_error_nomem)(sqlite3_context*);
|
||||
void (*result_error_toobig)(sqlite3_context*);
|
||||
int (*sleep)(int);
|
||||
void (*soft_heap_limit)(int);
|
||||
sqlite3_vfs *(*vfs_find)(const char*);
|
||||
int (*vfs_register)(sqlite3_vfs*,int);
|
||||
int (*vfs_unregister)(sqlite3_vfs*);
|
||||
int (*xthreadsafe)(void);
|
||||
void (*result_zeroblob)(sqlite3_context*,int);
|
||||
void (*result_error_code)(sqlite3_context*,int);
|
||||
int (*test_control)(int, ...);
|
||||
void (*randomness)(int,void*);
|
||||
sqlite3 *(*context_db_handle)(sqlite3_context*);
|
||||
int (*extended_result_codes)(sqlite3*,int);
|
||||
int (*limit)(sqlite3*,int,int);
|
||||
sqlite3_stmt *(*next_stmt)(sqlite3*,sqlite3_stmt*);
|
||||
const char *(*sql)(sqlite3_stmt*);
|
||||
int (*status)(int,int*,int*,int);
|
||||
int (*backup_finish)(sqlite3_backup*);
|
||||
sqlite3_backup *(*backup_init)(sqlite3*,const char*,sqlite3*,const char*);
|
||||
int (*backup_pagecount)(sqlite3_backup*);
|
||||
int (*backup_remaining)(sqlite3_backup*);
|
||||
int (*backup_step)(sqlite3_backup*,int);
|
||||
const char *(*compileoption_get)(int);
|
||||
int (*compileoption_used)(const char*);
|
||||
int (*create_function_v2)(sqlite3*,const char*,int,int,void*,
|
||||
void (*xFunc)(sqlite3_context*,int,sqlite3_value**),
|
||||
void (*xStep)(sqlite3_context*,int,sqlite3_value**),
|
||||
void (*xFinal)(sqlite3_context*),
|
||||
void(*xDestroy)(void*));
|
||||
int (*db_config)(sqlite3*,int,...);
|
||||
sqlite3_mutex *(*db_mutex)(sqlite3*);
|
||||
int (*db_status)(sqlite3*,int,int*,int*,int);
|
||||
int (*extended_errcode)(sqlite3*);
|
||||
void (*log)(int,const char*,...);
|
||||
sqlite3_int64 (*soft_heap_limit64)(sqlite3_int64);
|
||||
const char *(*sourceid)(void);
|
||||
int (*stmt_status)(sqlite3_stmt*,int,int);
|
||||
int (*strnicmp)(const char*,const char*,int);
|
||||
int (*unlock_notify)(sqlite3*,void(*)(void**,int),void*);
|
||||
int (*wal_autocheckpoint)(sqlite3*,int);
|
||||
int (*wal_checkpoint)(sqlite3*,const char*);
|
||||
void *(*wal_hook)(sqlite3*,int(*)(void*,sqlite3*,const char*,int),void*);
|
||||
int (*blob_reopen)(sqlite3_blob*,sqlite3_int64);
|
||||
int (*vtab_config)(sqlite3*,int op,...);
|
||||
int (*vtab_on_conflict)(sqlite3*);
|
||||
/* Version 3.7.16 and later */
|
||||
int (*close_v2)(sqlite3*);
|
||||
const char *(*db_filename)(sqlite3*,const char*);
|
||||
int (*db_readonly)(sqlite3*,const char*);
|
||||
int (*db_release_memory)(sqlite3*);
|
||||
const char *(*errstr)(int);
|
||||
int (*stmt_busy)(sqlite3_stmt*);
|
||||
int (*stmt_readonly)(sqlite3_stmt*);
|
||||
int (*stricmp)(const char*,const char*);
|
||||
int (*uri_boolean)(const char*,const char*,int);
|
||||
sqlite3_int64 (*uri_int64)(const char*,const char*,sqlite3_int64);
|
||||
const char *(*uri_parameter)(const char*,const char*);
|
||||
char *(*vsnprintf)(int,char*,const char*,va_list);
|
||||
int (*wal_checkpoint_v2)(sqlite3*,const char*,int,int*,int*);
|
||||
};
|
||||
|
||||
/*
|
||||
** The following macros redefine the API routines so that they are
|
||||
** redirected throught the global sqlite3_api structure.
|
||||
**
|
||||
** This header file is also used by the loadext.c source file
|
||||
** (part of the main SQLite library - not an extension) so that
|
||||
** it can get access to the sqlite3_api_routines structure
|
||||
** definition. But the main library does not want to redefine
|
||||
** the API. So the redefinition macros are only valid if the
|
||||
** SQLITE_CORE macros is undefined.
|
||||
*/
|
||||
#ifndef SQLITE_CORE
|
||||
#define sqlite3_aggregate_context sqlite3_api->aggregate_context
|
||||
#ifndef SQLITE_OMIT_DEPRECATED
|
||||
#define sqlite3_aggregate_count sqlite3_api->aggregate_count
|
||||
#endif
|
||||
#define sqlite3_bind_blob sqlite3_api->bind_blob
|
||||
#define sqlite3_bind_double sqlite3_api->bind_double
|
||||
#define sqlite3_bind_int sqlite3_api->bind_int
|
||||
#define sqlite3_bind_int64 sqlite3_api->bind_int64
|
||||
#define sqlite3_bind_null sqlite3_api->bind_null
|
||||
#define sqlite3_bind_parameter_count sqlite3_api->bind_parameter_count
|
||||
#define sqlite3_bind_parameter_index sqlite3_api->bind_parameter_index
|
||||
#define sqlite3_bind_parameter_name sqlite3_api->bind_parameter_name
|
||||
#define sqlite3_bind_text sqlite3_api->bind_text
|
||||
#define sqlite3_bind_text16 sqlite3_api->bind_text16
|
||||
#define sqlite3_bind_value sqlite3_api->bind_value
|
||||
#define sqlite3_busy_handler sqlite3_api->busy_handler
|
||||
#define sqlite3_busy_timeout sqlite3_api->busy_timeout
|
||||
#define sqlite3_changes sqlite3_api->changes
|
||||
#define sqlite3_close sqlite3_api->close
|
||||
#define sqlite3_collation_needed sqlite3_api->collation_needed
|
||||
#define sqlite3_collation_needed16 sqlite3_api->collation_needed16
|
||||
#define sqlite3_column_blob sqlite3_api->column_blob
|
||||
#define sqlite3_column_bytes sqlite3_api->column_bytes
|
||||
#define sqlite3_column_bytes16 sqlite3_api->column_bytes16
|
||||
#define sqlite3_column_count sqlite3_api->column_count
|
||||
#define sqlite3_column_database_name sqlite3_api->column_database_name
|
||||
#define sqlite3_column_database_name16 sqlite3_api->column_database_name16
|
||||
#define sqlite3_column_decltype sqlite3_api->column_decltype
|
||||
#define sqlite3_column_decltype16 sqlite3_api->column_decltype16
|
||||
#define sqlite3_column_double sqlite3_api->column_double
|
||||
#define sqlite3_column_int sqlite3_api->column_int
|
||||
#define sqlite3_column_int64 sqlite3_api->column_int64
|
||||
#define sqlite3_column_name sqlite3_api->column_name
|
||||
#define sqlite3_column_name16 sqlite3_api->column_name16
|
||||
#define sqlite3_column_origin_name sqlite3_api->column_origin_name
|
||||
#define sqlite3_column_origin_name16 sqlite3_api->column_origin_name16
|
||||
#define sqlite3_column_table_name sqlite3_api->column_table_name
|
||||
#define sqlite3_column_table_name16 sqlite3_api->column_table_name16
|
||||
#define sqlite3_column_text sqlite3_api->column_text
|
||||
#define sqlite3_column_text16 sqlite3_api->column_text16
|
||||
#define sqlite3_column_type sqlite3_api->column_type
|
||||
#define sqlite3_column_value sqlite3_api->column_value
|
||||
#define sqlite3_commit_hook sqlite3_api->commit_hook
|
||||
#define sqlite3_complete sqlite3_api->complete
|
||||
#define sqlite3_complete16 sqlite3_api->complete16
|
||||
#define sqlite3_create_collation sqlite3_api->create_collation
|
||||
#define sqlite3_create_collation16 sqlite3_api->create_collation16
|
||||
#define sqlite3_create_function sqlite3_api->create_function
|
||||
#define sqlite3_create_function16 sqlite3_api->create_function16
|
||||
#define sqlite3_create_module sqlite3_api->create_module
|
||||
#define sqlite3_create_module_v2 sqlite3_api->create_module_v2
|
||||
#define sqlite3_data_count sqlite3_api->data_count
|
||||
#define sqlite3_db_handle sqlite3_api->db_handle
|
||||
#define sqlite3_declare_vtab sqlite3_api->declare_vtab
|
||||
#define sqlite3_enable_shared_cache sqlite3_api->enable_shared_cache
|
||||
#define sqlite3_errcode sqlite3_api->errcode
|
||||
#define sqlite3_errmsg sqlite3_api->errmsg
|
||||
#define sqlite3_errmsg16 sqlite3_api->errmsg16
|
||||
#define sqlite3_exec sqlite3_api->exec
|
||||
#ifndef SQLITE_OMIT_DEPRECATED
|
||||
#define sqlite3_expired sqlite3_api->expired
|
||||
#endif
|
||||
#define sqlite3_finalize sqlite3_api->finalize
|
||||
#define sqlite3_free sqlite3_api->free
|
||||
#define sqlite3_free_table sqlite3_api->free_table
|
||||
#define sqlite3_get_autocommit sqlite3_api->get_autocommit
|
||||
#define sqlite3_get_auxdata sqlite3_api->get_auxdata
|
||||
#define sqlite3_get_table sqlite3_api->get_table
|
||||
#ifndef SQLITE_OMIT_DEPRECATED
|
||||
#define sqlite3_global_recover sqlite3_api->global_recover
|
||||
#endif
|
||||
#define sqlite3_interrupt sqlite3_api->interruptx
|
||||
#define sqlite3_last_insert_rowid sqlite3_api->last_insert_rowid
|
||||
#define sqlite3_libversion sqlite3_api->libversion
|
||||
#define sqlite3_libversion_number sqlite3_api->libversion_number
|
||||
#define sqlite3_malloc sqlite3_api->malloc
|
||||
#define sqlite3_mprintf sqlite3_api->mprintf
|
||||
#define sqlite3_open sqlite3_api->open
|
||||
#define sqlite3_open16 sqlite3_api->open16
|
||||
#define sqlite3_prepare sqlite3_api->prepare
|
||||
#define sqlite3_prepare16 sqlite3_api->prepare16
|
||||
#define sqlite3_prepare_v2 sqlite3_api->prepare_v2
|
||||
#define sqlite3_prepare16_v2 sqlite3_api->prepare16_v2
|
||||
#define sqlite3_profile sqlite3_api->profile
|
||||
#define sqlite3_progress_handler sqlite3_api->progress_handler
|
||||
#define sqlite3_realloc sqlite3_api->realloc
|
||||
#define sqlite3_reset sqlite3_api->reset
|
||||
#define sqlite3_result_blob sqlite3_api->result_blob
|
||||
#define sqlite3_result_double sqlite3_api->result_double
|
||||
#define sqlite3_result_error sqlite3_api->result_error
|
||||
#define sqlite3_result_error16 sqlite3_api->result_error16
|
||||
#define sqlite3_result_int sqlite3_api->result_int
|
||||
#define sqlite3_result_int64 sqlite3_api->result_int64
|
||||
#define sqlite3_result_null sqlite3_api->result_null
|
||||
#define sqlite3_result_text sqlite3_api->result_text
|
||||
#define sqlite3_result_text16 sqlite3_api->result_text16
|
||||
#define sqlite3_result_text16be sqlite3_api->result_text16be
|
||||
#define sqlite3_result_text16le sqlite3_api->result_text16le
|
||||
#define sqlite3_result_value sqlite3_api->result_value
|
||||
#define sqlite3_rollback_hook sqlite3_api->rollback_hook
|
||||
#define sqlite3_set_authorizer sqlite3_api->set_authorizer
|
||||
#define sqlite3_set_auxdata sqlite3_api->set_auxdata
|
||||
#define sqlite3_snprintf sqlite3_api->snprintf
|
||||
#define sqlite3_step sqlite3_api->step
|
||||
#define sqlite3_table_column_metadata sqlite3_api->table_column_metadata
|
||||
#define sqlite3_thread_cleanup sqlite3_api->thread_cleanup
|
||||
#define sqlite3_total_changes sqlite3_api->total_changes
|
||||
#define sqlite3_trace sqlite3_api->trace
|
||||
#ifndef SQLITE_OMIT_DEPRECATED
|
||||
#define sqlite3_transfer_bindings sqlite3_api->transfer_bindings
|
||||
#endif
|
||||
#define sqlite3_update_hook sqlite3_api->update_hook
|
||||
#define sqlite3_user_data sqlite3_api->user_data
|
||||
#define sqlite3_value_blob sqlite3_api->value_blob
|
||||
#define sqlite3_value_bytes sqlite3_api->value_bytes
|
||||
#define sqlite3_value_bytes16 sqlite3_api->value_bytes16
|
||||
#define sqlite3_value_double sqlite3_api->value_double
|
||||
#define sqlite3_value_int sqlite3_api->value_int
|
||||
#define sqlite3_value_int64 sqlite3_api->value_int64
|
||||
#define sqlite3_value_numeric_type sqlite3_api->value_numeric_type
|
||||
#define sqlite3_value_text sqlite3_api->value_text
|
||||
#define sqlite3_value_text16 sqlite3_api->value_text16
|
||||
#define sqlite3_value_text16be sqlite3_api->value_text16be
|
||||
#define sqlite3_value_text16le sqlite3_api->value_text16le
|
||||
#define sqlite3_value_type sqlite3_api->value_type
|
||||
#define sqlite3_vmprintf sqlite3_api->vmprintf
|
||||
#define sqlite3_overload_function sqlite3_api->overload_function
|
||||
#define sqlite3_prepare_v2 sqlite3_api->prepare_v2
|
||||
#define sqlite3_prepare16_v2 sqlite3_api->prepare16_v2
|
||||
#define sqlite3_clear_bindings sqlite3_api->clear_bindings
|
||||
#define sqlite3_bind_zeroblob sqlite3_api->bind_zeroblob
|
||||
#define sqlite3_blob_bytes sqlite3_api->blob_bytes
|
||||
#define sqlite3_blob_close sqlite3_api->blob_close
|
||||
#define sqlite3_blob_open sqlite3_api->blob_open
|
||||
#define sqlite3_blob_read sqlite3_api->blob_read
|
||||
#define sqlite3_blob_write sqlite3_api->blob_write
|
||||
#define sqlite3_create_collation_v2 sqlite3_api->create_collation_v2
|
||||
#define sqlite3_file_control sqlite3_api->file_control
|
||||
#define sqlite3_memory_highwater sqlite3_api->memory_highwater
|
||||
#define sqlite3_memory_used sqlite3_api->memory_used
|
||||
#define sqlite3_mutex_alloc sqlite3_api->mutex_alloc
|
||||
#define sqlite3_mutex_enter sqlite3_api->mutex_enter
|
||||
#define sqlite3_mutex_free sqlite3_api->mutex_free
|
||||
#define sqlite3_mutex_leave sqlite3_api->mutex_leave
|
||||
#define sqlite3_mutex_try sqlite3_api->mutex_try
|
||||
#define sqlite3_open_v2 sqlite3_api->open_v2
|
||||
#define sqlite3_release_memory sqlite3_api->release_memory
|
||||
#define sqlite3_result_error_nomem sqlite3_api->result_error_nomem
|
||||
#define sqlite3_result_error_toobig sqlite3_api->result_error_toobig
|
||||
#define sqlite3_sleep sqlite3_api->sleep
|
||||
#define sqlite3_soft_heap_limit sqlite3_api->soft_heap_limit
|
||||
#define sqlite3_vfs_find sqlite3_api->vfs_find
|
||||
#define sqlite3_vfs_register sqlite3_api->vfs_register
|
||||
#define sqlite3_vfs_unregister sqlite3_api->vfs_unregister
|
||||
#define sqlite3_threadsafe sqlite3_api->xthreadsafe
|
||||
#define sqlite3_result_zeroblob sqlite3_api->result_zeroblob
|
||||
#define sqlite3_result_error_code sqlite3_api->result_error_code
|
||||
#define sqlite3_test_control sqlite3_api->test_control
|
||||
#define sqlite3_randomness sqlite3_api->randomness
|
||||
#define sqlite3_context_db_handle sqlite3_api->context_db_handle
|
||||
#define sqlite3_extended_result_codes sqlite3_api->extended_result_codes
|
||||
#define sqlite3_limit sqlite3_api->limit
|
||||
#define sqlite3_next_stmt sqlite3_api->next_stmt
|
||||
#define sqlite3_sql sqlite3_api->sql
|
||||
#define sqlite3_status sqlite3_api->status
|
||||
#define sqlite3_backup_finish sqlite3_api->backup_finish
|
||||
#define sqlite3_backup_init sqlite3_api->backup_init
|
||||
#define sqlite3_backup_pagecount sqlite3_api->backup_pagecount
|
||||
#define sqlite3_backup_remaining sqlite3_api->backup_remaining
|
||||
#define sqlite3_backup_step sqlite3_api->backup_step
|
||||
#define sqlite3_compileoption_get sqlite3_api->compileoption_get
|
||||
#define sqlite3_compileoption_used sqlite3_api->compileoption_used
|
||||
#define sqlite3_create_function_v2 sqlite3_api->create_function_v2
|
||||
#define sqlite3_db_config sqlite3_api->db_config
|
||||
#define sqlite3_db_mutex sqlite3_api->db_mutex
|
||||
#define sqlite3_db_status sqlite3_api->db_status
|
||||
#define sqlite3_extended_errcode sqlite3_api->extended_errcode
|
||||
#define sqlite3_log sqlite3_api->log
|
||||
#define sqlite3_soft_heap_limit64 sqlite3_api->soft_heap_limit64
|
||||
#define sqlite3_sourceid sqlite3_api->sourceid
|
||||
#define sqlite3_stmt_status sqlite3_api->stmt_status
|
||||
#define sqlite3_strnicmp sqlite3_api->strnicmp
|
||||
#define sqlite3_unlock_notify sqlite3_api->unlock_notify
|
||||
#define sqlite3_wal_autocheckpoint sqlite3_api->wal_autocheckpoint
|
||||
#define sqlite3_wal_checkpoint sqlite3_api->wal_checkpoint
|
||||
#define sqlite3_wal_hook sqlite3_api->wal_hook
|
||||
#define sqlite3_blob_reopen sqlite3_api->blob_reopen
|
||||
#define sqlite3_vtab_config sqlite3_api->vtab_config
|
||||
#define sqlite3_vtab_on_conflict sqlite3_api->vtab_on_conflict
|
||||
/* Version 3.7.16 and later */
|
||||
#define sqlite3_close_v2 sqlite3_api->close_v2
|
||||
#define sqlite3_db_filename sqlite3_api->db_filename
|
||||
#define sqlite3_db_readonly sqlite3_api->db_readonly
|
||||
#define sqlite3_db_release_memory sqlite3_api->db_release_memory
|
||||
#define sqlite3_errstr sqlite3_api->errstr
|
||||
#define sqlite3_stmt_busy sqlite3_api->stmt_busy
|
||||
#define sqlite3_stmt_readonly sqlite3_api->stmt_readonly
|
||||
#define sqlite3_stricmp sqlite3_api->stricmp
|
||||
#define sqlite3_uri_boolean sqlite3_api->uri_boolean
|
||||
#define sqlite3_uri_int64 sqlite3_api->uri_int64
|
||||
#define sqlite3_uri_parameter sqlite3_api->uri_parameter
|
||||
#define sqlite3_uri_vsnprintf sqlite3_api->vsnprintf
|
||||
#define sqlite3_wal_checkpoint_v2 sqlite3_api->wal_checkpoint_v2
|
||||
#endif /* SQLITE_CORE */
|
||||
|
||||
#ifndef SQLITE_CORE
|
||||
/* This case when the file really is being compiled as a loadable
|
||||
** extension */
|
||||
# define SQLITE_EXTENSION_INIT1 const sqlite3_api_routines *sqlite3_api=0;
|
||||
# define SQLITE_EXTENSION_INIT2(v) sqlite3_api=v;
|
||||
# define SQLITE_EXTENSION_INIT3 \
|
||||
extern const sqlite3_api_routines *sqlite3_api;
|
||||
#else
|
||||
/* This case when the file is being statically linked into the
|
||||
** application */
|
||||
# define SQLITE_EXTENSION_INIT1 /*no-op*/
|
||||
# define SQLITE_EXTENSION_INIT2(v) (void)v; /* unused parameter */
|
||||
# define SQLITE_EXTENSION_INIT3 /*no-op*/
|
||||
#endif
|
||||
|
||||
#endif /* _SQLITE3EXT_H_ */
|
112
Godeps/_workspace/src/github.com/mattn/go-sqlite3/doc.go
generated
vendored
Normal file
112
Godeps/_workspace/src/github.com/mattn/go-sqlite3/doc.go
generated
vendored
Normal file
|
@ -0,0 +1,112 @@
|
|||
/*
|
||||
Package sqlite3 provides interface to SQLite3 databases.
|
||||
|
||||
This works as driver for database/sql.
|
||||
|
||||
Installation
|
||||
|
||||
go get github.com/mattn/go-sqlite3
|
||||
|
||||
Supported Types
|
||||
|
||||
Currently, go-sqlite3 support following data types.
|
||||
|
||||
+------------------------------+
|
||||
|go | sqlite3 |
|
||||
|----------|-------------------|
|
||||
|nil | null |
|
||||
|int | integer |
|
||||
|int64 | integer |
|
||||
|float64 | float |
|
||||
|bool | integer |
|
||||
|[]byte | blob |
|
||||
|string | text |
|
||||
|time.Time | timestamp/datetime|
|
||||
+------------------------------+
|
||||
|
||||
SQLite3 Extension
|
||||
|
||||
You can write your own extension module for sqlite3. For example, below is a
|
||||
extension for Regexp matcher operation.
|
||||
|
||||
#include <pcre.h>
|
||||
#include <string.h>
|
||||
#include <stdio.h>
|
||||
#include <sqlite3ext.h>
|
||||
|
||||
SQLITE_EXTENSION_INIT1
|
||||
static void regexp_func(sqlite3_context *context, int argc, sqlite3_value **argv) {
|
||||
if (argc >= 2) {
|
||||
const char *target = (const char *)sqlite3_value_text(argv[1]);
|
||||
const char *pattern = (const char *)sqlite3_value_text(argv[0]);
|
||||
const char* errstr = NULL;
|
||||
int erroff = 0;
|
||||
int vec[500];
|
||||
int n, rc;
|
||||
pcre* re = pcre_compile(pattern, 0, &errstr, &erroff, NULL);
|
||||
rc = pcre_exec(re, NULL, target, strlen(target), 0, 0, vec, 500);
|
||||
if (rc <= 0) {
|
||||
sqlite3_result_error(context, errstr, 0);
|
||||
return;
|
||||
}
|
||||
sqlite3_result_int(context, 1);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
__declspec(dllexport)
|
||||
#endif
|
||||
int sqlite3_extension_init(sqlite3 *db, char **errmsg,
|
||||
const sqlite3_api_routines *api) {
|
||||
SQLITE_EXTENSION_INIT2(api);
|
||||
return sqlite3_create_function(db, "regexp", 2, SQLITE_UTF8,
|
||||
(void*)db, regexp_func, NULL, NULL);
|
||||
}
|
||||
|
||||
It need to build as so/dll shared library. And you need to register
|
||||
extension module like below.
|
||||
|
||||
sql.Register("sqlite3_with_extensions",
|
||||
&sqlite3.SQLiteDriver{
|
||||
Extensions: []string{
|
||||
"sqlite3_mod_regexp",
|
||||
},
|
||||
})
|
||||
|
||||
Then, you can use this extension.
|
||||
|
||||
rows, err := db.Query("select text from mytable where name regexp '^golang'")
|
||||
|
||||
Connection Hook
|
||||
|
||||
You can hook and inject your codes when connection established. database/sql
|
||||
doesn't provide the way to get native go-sqlite3 interfaces. So if you want,
|
||||
you need to hook ConnectHook and get the SQLiteConn.
|
||||
|
||||
sql.Register("sqlite3_with_hook_example",
|
||||
&sqlite3.SQLiteDriver{
|
||||
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
|
||||
sqlite3conn = append(sqlite3conn, conn)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
Go SQlite3 Extensions
|
||||
|
||||
If you want to register Go functions as SQLite extension functions,
|
||||
call RegisterFunction from ConnectHook.
|
||||
|
||||
regex = func(re, s string) (bool, error) {
|
||||
return regexp.MatchString(re, s)
|
||||
}
|
||||
sql.Register("sqlite3_with_go_func",
|
||||
&sqlite3.SQLiteDriver{
|
||||
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
|
||||
return conn.RegisterFunc("regex", regex, true)
|
||||
},
|
||||
})
|
||||
|
||||
See the documentation of RegisterFunc for more details.
|
||||
|
||||
*/
|
||||
package sqlite3
|
128
Godeps/_workspace/src/github.com/mattn/go-sqlite3/error.go
generated
vendored
Normal file
128
Godeps/_workspace/src/github.com/mattn/go-sqlite3/error.go
generated
vendored
Normal file
|
@ -0,0 +1,128 @@
|
|||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import "C"
|
||||
|
||||
type ErrNo int
|
||||
|
||||
const ErrNoMask C.int = 0xff
|
||||
|
||||
type ErrNoExtended int
|
||||
|
||||
type Error struct {
|
||||
Code ErrNo /* The error code returned by SQLite */
|
||||
ExtendedCode ErrNoExtended /* The extended error code returned by SQLite */
|
||||
err string /* The error string returned by sqlite3_errmsg(),
|
||||
this usually contains more specific details. */
|
||||
}
|
||||
|
||||
// result codes from http://www.sqlite.org/c3ref/c_abort.html
|
||||
var (
|
||||
ErrError = ErrNo(1) /* SQL error or missing database */
|
||||
ErrInternal = ErrNo(2) /* Internal logic error in SQLite */
|
||||
ErrPerm = ErrNo(3) /* Access permission denied */
|
||||
ErrAbort = ErrNo(4) /* Callback routine requested an abort */
|
||||
ErrBusy = ErrNo(5) /* The database file is locked */
|
||||
ErrLocked = ErrNo(6) /* A table in the database is locked */
|
||||
ErrNomem = ErrNo(7) /* A malloc() failed */
|
||||
ErrReadonly = ErrNo(8) /* Attempt to write a readonly database */
|
||||
ErrInterrupt = ErrNo(9) /* Operation terminated by sqlite3_interrupt() */
|
||||
ErrIoErr = ErrNo(10) /* Some kind of disk I/O error occurred */
|
||||
ErrCorrupt = ErrNo(11) /* The database disk image is malformed */
|
||||
ErrNotFound = ErrNo(12) /* Unknown opcode in sqlite3_file_control() */
|
||||
ErrFull = ErrNo(13) /* Insertion failed because database is full */
|
||||
ErrCantOpen = ErrNo(14) /* Unable to open the database file */
|
||||
ErrProtocol = ErrNo(15) /* Database lock protocol error */
|
||||
ErrEmpty = ErrNo(16) /* Database is empty */
|
||||
ErrSchema = ErrNo(17) /* The database schema changed */
|
||||
ErrTooBig = ErrNo(18) /* String or BLOB exceeds size limit */
|
||||
ErrConstraint = ErrNo(19) /* Abort due to constraint violation */
|
||||
ErrMismatch = ErrNo(20) /* Data type mismatch */
|
||||
ErrMisuse = ErrNo(21) /* Library used incorrectly */
|
||||
ErrNoLFS = ErrNo(22) /* Uses OS features not supported on host */
|
||||
ErrAuth = ErrNo(23) /* Authorization denied */
|
||||
ErrFormat = ErrNo(24) /* Auxiliary database format error */
|
||||
ErrRange = ErrNo(25) /* 2nd parameter to sqlite3_bind out of range */
|
||||
ErrNotADB = ErrNo(26) /* File opened that is not a database file */
|
||||
ErrNotice = ErrNo(27) /* Notifications from sqlite3_log() */
|
||||
ErrWarning = ErrNo(28) /* Warnings from sqlite3_log() */
|
||||
)
|
||||
|
||||
func (err ErrNo) Error() string {
|
||||
return Error{Code: err}.Error()
|
||||
}
|
||||
|
||||
func (err ErrNo) Extend(by int) ErrNoExtended {
|
||||
return ErrNoExtended(int(err) | (by << 8))
|
||||
}
|
||||
|
||||
func (err ErrNoExtended) Error() string {
|
||||
return Error{Code: ErrNo(C.int(err) & ErrNoMask), ExtendedCode: err}.Error()
|
||||
}
|
||||
|
||||
func (err Error) Error() string {
|
||||
if err.err != "" {
|
||||
return err.err
|
||||
}
|
||||
return errorString(err)
|
||||
}
|
||||
|
||||
// result codes from http://www.sqlite.org/c3ref/c_abort_rollback.html
|
||||
var (
|
||||
ErrIoErrRead = ErrIoErr.Extend(1)
|
||||
ErrIoErrShortRead = ErrIoErr.Extend(2)
|
||||
ErrIoErrWrite = ErrIoErr.Extend(3)
|
||||
ErrIoErrFsync = ErrIoErr.Extend(4)
|
||||
ErrIoErrDirFsync = ErrIoErr.Extend(5)
|
||||
ErrIoErrTruncate = ErrIoErr.Extend(6)
|
||||
ErrIoErrFstat = ErrIoErr.Extend(7)
|
||||
ErrIoErrUnlock = ErrIoErr.Extend(8)
|
||||
ErrIoErrRDlock = ErrIoErr.Extend(9)
|
||||
ErrIoErrDelete = ErrIoErr.Extend(10)
|
||||
ErrIoErrBlocked = ErrIoErr.Extend(11)
|
||||
ErrIoErrNoMem = ErrIoErr.Extend(12)
|
||||
ErrIoErrAccess = ErrIoErr.Extend(13)
|
||||
ErrIoErrCheckReservedLock = ErrIoErr.Extend(14)
|
||||
ErrIoErrLock = ErrIoErr.Extend(15)
|
||||
ErrIoErrClose = ErrIoErr.Extend(16)
|
||||
ErrIoErrDirClose = ErrIoErr.Extend(17)
|
||||
ErrIoErrSHMOpen = ErrIoErr.Extend(18)
|
||||
ErrIoErrSHMSize = ErrIoErr.Extend(19)
|
||||
ErrIoErrSHMLock = ErrIoErr.Extend(20)
|
||||
ErrIoErrSHMMap = ErrIoErr.Extend(21)
|
||||
ErrIoErrSeek = ErrIoErr.Extend(22)
|
||||
ErrIoErrDeleteNoent = ErrIoErr.Extend(23)
|
||||
ErrIoErrMMap = ErrIoErr.Extend(24)
|
||||
ErrIoErrGetTempPath = ErrIoErr.Extend(25)
|
||||
ErrIoErrConvPath = ErrIoErr.Extend(26)
|
||||
ErrLockedSharedCache = ErrLocked.Extend(1)
|
||||
ErrBusyRecovery = ErrBusy.Extend(1)
|
||||
ErrBusySnapshot = ErrBusy.Extend(2)
|
||||
ErrCantOpenNoTempDir = ErrCantOpen.Extend(1)
|
||||
ErrCantOpenIsDir = ErrCantOpen.Extend(2)
|
||||
ErrCantOpenFullPath = ErrCantOpen.Extend(3)
|
||||
ErrCantOpenConvPath = ErrCantOpen.Extend(4)
|
||||
ErrCorruptVTab = ErrCorrupt.Extend(1)
|
||||
ErrReadonlyRecovery = ErrReadonly.Extend(1)
|
||||
ErrReadonlyCantLock = ErrReadonly.Extend(2)
|
||||
ErrReadonlyRollback = ErrReadonly.Extend(3)
|
||||
ErrReadonlyDbMoved = ErrReadonly.Extend(4)
|
||||
ErrAbortRollback = ErrAbort.Extend(2)
|
||||
ErrConstraintCheck = ErrConstraint.Extend(1)
|
||||
ErrConstraintCommitHook = ErrConstraint.Extend(2)
|
||||
ErrConstraintForeignKey = ErrConstraint.Extend(3)
|
||||
ErrConstraintFunction = ErrConstraint.Extend(4)
|
||||
ErrConstraintNotNull = ErrConstraint.Extend(5)
|
||||
ErrConstraintPrimaryKey = ErrConstraint.Extend(6)
|
||||
ErrConstraintTrigger = ErrConstraint.Extend(7)
|
||||
ErrConstraintUnique = ErrConstraint.Extend(8)
|
||||
ErrConstraintVTab = ErrConstraint.Extend(9)
|
||||
ErrConstraintRowId = ErrConstraint.Extend(10)
|
||||
ErrNoticeRecoverWAL = ErrNotice.Extend(1)
|
||||
ErrNoticeRecoverRollback = ErrNotice.Extend(2)
|
||||
ErrWarningAutoIndex = ErrWarning.Extend(1)
|
||||
)
|
4
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3-binding.c
generated
vendored
Normal file
4
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3-binding.c
generated
vendored
Normal file
|
@ -0,0 +1,4 @@
|
|||
#ifndef USE_LIBSQLITE3
|
||||
# include "code/sqlite3-binding.c"
|
||||
#endif
|
||||
|
5
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3-binding.h
generated
vendored
Normal file
5
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3-binding.h
generated
vendored
Normal file
|
@ -0,0 +1,5 @@
|
|||
#ifndef USE_LIBSQLITE3
|
||||
#include "code/sqlite3-binding.h"
|
||||
#else
|
||||
#include <sqlite3.h>
|
||||
#endif
|
977
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3.go
generated
vendored
Normal file
977
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3.go
generated
vendored
Normal file
|
@ -0,0 +1,977 @@
|
|||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -std=gnu99
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_FTS4_UNICODE61
|
||||
#include <sqlite3-binding.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#ifdef __CYGWIN__
|
||||
# include <errno.h>
|
||||
#endif
|
||||
|
||||
#ifndef SQLITE_OPEN_READWRITE
|
||||
# define SQLITE_OPEN_READWRITE 0
|
||||
#endif
|
||||
|
||||
#ifndef SQLITE_OPEN_FULLMUTEX
|
||||
# define SQLITE_OPEN_FULLMUTEX 0
|
||||
#endif
|
||||
|
||||
static int
|
||||
_sqlite3_open_v2(const char *filename, sqlite3 **ppDb, int flags, const char *zVfs) {
|
||||
#ifdef SQLITE_OPEN_URI
|
||||
return sqlite3_open_v2(filename, ppDb, flags | SQLITE_OPEN_URI, zVfs);
|
||||
#else
|
||||
return sqlite3_open_v2(filename, ppDb, flags, zVfs);
|
||||
#endif
|
||||
}
|
||||
|
||||
static int
|
||||
_sqlite3_bind_text(sqlite3_stmt *stmt, int n, char *p, int np) {
|
||||
return sqlite3_bind_text(stmt, n, p, np, SQLITE_TRANSIENT);
|
||||
}
|
||||
|
||||
static int
|
||||
_sqlite3_bind_blob(sqlite3_stmt *stmt, int n, void *p, int np) {
|
||||
return sqlite3_bind_blob(stmt, n, p, np, SQLITE_TRANSIENT);
|
||||
}
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
|
||||
static int
|
||||
_sqlite3_exec(sqlite3* db, const char* pcmd, long long* rowid, long long* changes)
|
||||
{
|
||||
int rv = sqlite3_exec(db, pcmd, 0, 0, 0);
|
||||
*rowid = (long long) sqlite3_last_insert_rowid(db);
|
||||
*changes = (long long) sqlite3_changes(db);
|
||||
return rv;
|
||||
}
|
||||
|
||||
static int
|
||||
_sqlite3_step(sqlite3_stmt* stmt, long long* rowid, long long* changes)
|
||||
{
|
||||
int rv = sqlite3_step(stmt);
|
||||
sqlite3* db = sqlite3_db_handle(stmt);
|
||||
*rowid = (long long) sqlite3_last_insert_rowid(db);
|
||||
*changes = (long long) sqlite3_changes(db);
|
||||
return rv;
|
||||
}
|
||||
|
||||
void _sqlite3_result_text(sqlite3_context* ctx, const char* s) {
|
||||
sqlite3_result_text(ctx, s, -1, &free);
|
||||
}
|
||||
|
||||
void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l) {
|
||||
sqlite3_result_blob(ctx, b, l, SQLITE_TRANSIENT);
|
||||
}
|
||||
|
||||
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
|
||||
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
|
||||
void doneTrampoline(sqlite3_context*);
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Timestamp formats understood by both this module and SQLite.
|
||||
// The first format in the slice will be used when saving time values
|
||||
// into the database. When parsing a string from a timestamp or
|
||||
// datetime column, the formats are tried in order.
|
||||
var SQLiteTimestampFormats = []string{
|
||||
// By default, store timestamps with whatever timezone they come with.
|
||||
// When parsed, they will be returned with the same timezone.
|
||||
"2006-01-02 15:04:05.999999999-07:00",
|
||||
"2006-01-02T15:04:05.999999999-07:00",
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02T15:04:05.999999999",
|
||||
"2006-01-02 15:04:05",
|
||||
"2006-01-02T15:04:05",
|
||||
"2006-01-02 15:04",
|
||||
"2006-01-02T15:04",
|
||||
"2006-01-02",
|
||||
}
|
||||
|
||||
func init() {
|
||||
sql.Register("sqlite3", &SQLiteDriver{})
|
||||
}
|
||||
|
||||
// Return SQLite library Version information.
|
||||
func Version() (libVersion string, libVersionNumber int, sourceId string) {
|
||||
libVersion = C.GoString(C.sqlite3_libversion())
|
||||
libVersionNumber = int(C.sqlite3_libversion_number())
|
||||
sourceId = C.GoString(C.sqlite3_sourceid())
|
||||
return libVersion, libVersionNumber, sourceId
|
||||
}
|
||||
|
||||
// Driver struct.
|
||||
type SQLiteDriver struct {
|
||||
Extensions []string
|
||||
ConnectHook func(*SQLiteConn) error
|
||||
}
|
||||
|
||||
// Conn struct.
|
||||
type SQLiteConn struct {
|
||||
db *C.sqlite3
|
||||
loc *time.Location
|
||||
txlock string
|
||||
funcs []*functionInfo
|
||||
aggregators []*aggInfo
|
||||
}
|
||||
|
||||
// Tx struct.
|
||||
type SQLiteTx struct {
|
||||
c *SQLiteConn
|
||||
}
|
||||
|
||||
// Stmt struct.
|
||||
type SQLiteStmt struct {
|
||||
c *SQLiteConn
|
||||
s *C.sqlite3_stmt
|
||||
nv int
|
||||
nn []string
|
||||
t string
|
||||
closed bool
|
||||
cls bool
|
||||
}
|
||||
|
||||
// Result struct.
|
||||
type SQLiteResult struct {
|
||||
id int64
|
||||
changes int64
|
||||
}
|
||||
|
||||
// Rows struct.
|
||||
type SQLiteRows struct {
|
||||
s *SQLiteStmt
|
||||
nc int
|
||||
cols []string
|
||||
decltype []string
|
||||
cls bool
|
||||
}
|
||||
|
||||
type functionInfo struct {
|
||||
f reflect.Value
|
||||
argConverters []callbackArgConverter
|
||||
variadicConverter callbackArgConverter
|
||||
retConverter callbackRetConverter
|
||||
}
|
||||
|
||||
func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
|
||||
args, err := callbackConvertArgs(argv, fi.argConverters, fi.variadicConverter)
|
||||
if err != nil {
|
||||
callbackError(ctx, err)
|
||||
return
|
||||
}
|
||||
|
||||
ret := fi.f.Call(args)
|
||||
|
||||
if len(ret) == 2 && ret[1].Interface() != nil {
|
||||
callbackError(ctx, ret[1].Interface().(error))
|
||||
return
|
||||
}
|
||||
|
||||
err = fi.retConverter(ctx, ret[0])
|
||||
if err != nil {
|
||||
callbackError(ctx, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type aggInfo struct {
|
||||
constructor reflect.Value
|
||||
|
||||
// Active aggregator objects for aggregations in flight. The
|
||||
// aggregators are indexed by a counter stored in the aggregation
|
||||
// user data space provided by sqlite.
|
||||
active map[int64]reflect.Value
|
||||
next int64
|
||||
|
||||
stepArgConverters []callbackArgConverter
|
||||
stepVariadicConverter callbackArgConverter
|
||||
|
||||
doneRetConverter callbackRetConverter
|
||||
}
|
||||
|
||||
func (ai *aggInfo) agg(ctx *C.sqlite3_context) (int64, reflect.Value, error) {
|
||||
aggIdx := (*int64)(C.sqlite3_aggregate_context(ctx, C.int(8)))
|
||||
if *aggIdx == 0 {
|
||||
*aggIdx = ai.next
|
||||
ret := ai.constructor.Call(nil)
|
||||
if len(ret) == 2 && ret[1].Interface() != nil {
|
||||
return 0, reflect.Value{}, ret[1].Interface().(error)
|
||||
}
|
||||
if ret[0].IsNil() {
|
||||
return 0, reflect.Value{}, errors.New("aggregator constructor returned nil state")
|
||||
}
|
||||
ai.next++
|
||||
ai.active[*aggIdx] = ret[0]
|
||||
}
|
||||
return *aggIdx, ai.active[*aggIdx], nil
|
||||
}
|
||||
|
||||
func (ai *aggInfo) Step(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
|
||||
_, agg, err := ai.agg(ctx)
|
||||
if err != nil {
|
||||
callbackError(ctx, err)
|
||||
return
|
||||
}
|
||||
|
||||
args, err := callbackConvertArgs(argv, ai.stepArgConverters, ai.stepVariadicConverter)
|
||||
if err != nil {
|
||||
callbackError(ctx, err)
|
||||
return
|
||||
}
|
||||
|
||||
ret := agg.MethodByName("Step").Call(args)
|
||||
if len(ret) == 1 && ret[0].Interface() != nil {
|
||||
callbackError(ctx, ret[0].Interface().(error))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (ai *aggInfo) Done(ctx *C.sqlite3_context) {
|
||||
idx, agg, err := ai.agg(ctx)
|
||||
if err != nil {
|
||||
callbackError(ctx, err)
|
||||
return
|
||||
}
|
||||
defer func() { delete(ai.active, idx) }()
|
||||
|
||||
ret := agg.MethodByName("Done").Call(nil)
|
||||
if len(ret) == 2 && ret[1].Interface() != nil {
|
||||
callbackError(ctx, ret[1].Interface().(error))
|
||||
return
|
||||
}
|
||||
|
||||
err = ai.doneRetConverter(ctx, ret[0])
|
||||
if err != nil {
|
||||
callbackError(ctx, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Commit transaction.
|
||||
func (tx *SQLiteTx) Commit() error {
|
||||
_, err := tx.c.exec("COMMIT")
|
||||
return err
|
||||
}
|
||||
|
||||
// Rollback transaction.
|
||||
func (tx *SQLiteTx) Rollback() error {
|
||||
_, err := tx.c.exec("ROLLBACK")
|
||||
return err
|
||||
}
|
||||
|
||||
// RegisterFunc makes a Go function available as a SQLite function.
|
||||
//
|
||||
// The Go function can have arguments of the following types: any
|
||||
// numeric type except complex, bool, []byte, string and
|
||||
// interface{}. interface{} arguments are given the direct translation
|
||||
// of the SQLite data type: int64 for INTEGER, float64 for FLOAT,
|
||||
// []byte for BLOB, string for TEXT.
|
||||
//
|
||||
// The function can additionally be variadic, as long as the type of
|
||||
// the variadic argument is one of the above.
|
||||
//
|
||||
// If pure is true. SQLite will assume that the function's return
|
||||
// value depends only on its inputs, and make more aggressive
|
||||
// optimizations in its queries.
|
||||
//
|
||||
// See _example/go_custom_funcs for a detailed example.
|
||||
func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) error {
|
||||
var fi functionInfo
|
||||
fi.f = reflect.ValueOf(impl)
|
||||
t := fi.f.Type()
|
||||
if t.Kind() != reflect.Func {
|
||||
return errors.New("Non-function passed to RegisterFunc")
|
||||
}
|
||||
if t.NumOut() != 1 && t.NumOut() != 2 {
|
||||
return errors.New("SQLite functions must return 1 or 2 values")
|
||||
}
|
||||
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||
return errors.New("Second return value of SQLite function must be error")
|
||||
}
|
||||
|
||||
numArgs := t.NumIn()
|
||||
if t.IsVariadic() {
|
||||
numArgs--
|
||||
}
|
||||
|
||||
for i := 0; i < numArgs; i++ {
|
||||
conv, err := callbackArg(t.In(i))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fi.argConverters = append(fi.argConverters, conv)
|
||||
}
|
||||
|
||||
if t.IsVariadic() {
|
||||
conv, err := callbackArg(t.In(numArgs).Elem())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fi.variadicConverter = conv
|
||||
// Pass -1 to sqlite so that it allows any number of
|
||||
// arguments. The call helper verifies that the minimum number
|
||||
// of arguments is present for variadic functions.
|
||||
numArgs = -1
|
||||
}
|
||||
|
||||
conv, err := callbackRet(t.Out(0))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fi.retConverter = conv
|
||||
|
||||
// fi must outlast the database connection, or we'll have dangling pointers.
|
||||
c.funcs = append(c.funcs, &fi)
|
||||
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
opts := C.SQLITE_UTF8
|
||||
if pure {
|
||||
opts |= C.SQLITE_DETERMINISTIC
|
||||
}
|
||||
rv := C.sqlite3_create_function(c.db, cname, C.int(numArgs), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil)
|
||||
if rv != C.SQLITE_OK {
|
||||
return c.lastError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterAggregator makes a Go type available as a SQLite aggregation function.
|
||||
//
|
||||
// Because aggregation is incremental, it's implemented in Go with a
|
||||
// type that has 2 methods: func Step(values) accumulates one row of
|
||||
// data into the accumulator, and func Done() ret finalizes and
|
||||
// returns the aggregate value. "values" and "ret" may be any type
|
||||
// supported by RegisterFunc.
|
||||
//
|
||||
// RegisterAggregator takes as implementation a constructor function
|
||||
// that constructs an instance of the aggregator type each time an
|
||||
// aggregation begins. The constructor must return a pointer to a
|
||||
// type, or an interface that implements Step() and Done().
|
||||
//
|
||||
// The constructor function and the Step/Done methods may optionally
|
||||
// return an error in addition to their other return values.
|
||||
//
|
||||
// See _example/go_custom_funcs for a detailed example.
|
||||
func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error {
|
||||
var ai aggInfo
|
||||
ai.constructor = reflect.ValueOf(impl)
|
||||
t := ai.constructor.Type()
|
||||
if t.Kind() != reflect.Func {
|
||||
return errors.New("non-function passed to RegisterAggregator")
|
||||
}
|
||||
if t.NumOut() != 1 && t.NumOut() != 2 {
|
||||
return errors.New("SQLite aggregator constructors must return 1 or 2 values")
|
||||
}
|
||||
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||
return errors.New("Second return value of SQLite function must be error")
|
||||
}
|
||||
if t.NumIn() != 0 {
|
||||
return errors.New("SQLite aggregator constructors must not have arguments")
|
||||
}
|
||||
|
||||
agg := t.Out(0)
|
||||
switch agg.Kind() {
|
||||
case reflect.Ptr, reflect.Interface:
|
||||
default:
|
||||
return errors.New("SQlite aggregator constructor must return a pointer object")
|
||||
}
|
||||
stepFn, found := agg.MethodByName("Step")
|
||||
if !found {
|
||||
return errors.New("SQlite aggregator doesn't have a Step() function")
|
||||
}
|
||||
step := stepFn.Type
|
||||
if step.NumOut() != 0 && step.NumOut() != 1 {
|
||||
return errors.New("SQlite aggregator Step() function must return 0 or 1 values")
|
||||
}
|
||||
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||
return errors.New("type of SQlite aggregator Step() return value must be error")
|
||||
}
|
||||
|
||||
stepNArgs := step.NumIn()
|
||||
start := 0
|
||||
if agg.Kind() == reflect.Ptr {
|
||||
// Skip over the method receiver
|
||||
stepNArgs--
|
||||
start++
|
||||
}
|
||||
if step.IsVariadic() {
|
||||
stepNArgs--
|
||||
}
|
||||
for i := start; i < start+stepNArgs; i++ {
|
||||
conv, err := callbackArg(step.In(i))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ai.stepArgConverters = append(ai.stepArgConverters, conv)
|
||||
}
|
||||
if step.IsVariadic() {
|
||||
conv, err := callbackArg(t.In(start + stepNArgs).Elem())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ai.stepVariadicConverter = conv
|
||||
// Pass -1 to sqlite so that it allows any number of
|
||||
// arguments. The call helper verifies that the minimum number
|
||||
// of arguments is present for variadic functions.
|
||||
stepNArgs = -1
|
||||
}
|
||||
|
||||
doneFn, found := agg.MethodByName("Done")
|
||||
if !found {
|
||||
return errors.New("SQlite aggregator doesn't have a Done() function")
|
||||
}
|
||||
done := doneFn.Type
|
||||
doneNArgs := done.NumIn()
|
||||
if agg.Kind() == reflect.Ptr {
|
||||
// Skip over the method receiver
|
||||
doneNArgs--
|
||||
}
|
||||
if doneNArgs != 0 {
|
||||
return errors.New("SQlite aggregator Done() function must have no arguments")
|
||||
}
|
||||
if done.NumOut() != 1 && done.NumOut() != 2 {
|
||||
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
|
||||
}
|
||||
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||
return errors.New("second return value of SQLite aggregator Done() function must be error")
|
||||
}
|
||||
|
||||
conv, err := callbackRet(done.Out(0))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ai.doneRetConverter = conv
|
||||
ai.active = make(map[int64]reflect.Value)
|
||||
ai.next = 1
|
||||
|
||||
// ai must outlast the database connection, or we'll have dangling pointers.
|
||||
c.aggregators = append(c.aggregators, &ai)
|
||||
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
opts := C.SQLITE_UTF8
|
||||
if pure {
|
||||
opts |= C.SQLITE_DETERMINISTIC
|
||||
}
|
||||
rv := C.sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), unsafe.Pointer(&ai), nil, (*[0]byte)(unsafe.Pointer(C.stepTrampoline)), (*[0]byte)(unsafe.Pointer(C.doneTrampoline)))
|
||||
if rv != C.SQLITE_OK {
|
||||
return c.lastError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AutoCommit return which currently auto commit or not.
|
||||
func (c *SQLiteConn) AutoCommit() bool {
|
||||
return int(C.sqlite3_get_autocommit(c.db)) != 0
|
||||
}
|
||||
|
||||
func (c *SQLiteConn) lastError() Error {
|
||||
return Error{
|
||||
Code: ErrNo(C.sqlite3_errcode(c.db)),
|
||||
ExtendedCode: ErrNoExtended(C.sqlite3_extended_errcode(c.db)),
|
||||
err: C.GoString(C.sqlite3_errmsg(c.db)),
|
||||
}
|
||||
}
|
||||
|
||||
// Implements Execer
|
||||
func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) {
|
||||
if len(args) == 0 {
|
||||
return c.exec(query)
|
||||
}
|
||||
|
||||
for {
|
||||
s, err := c.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var res driver.Result
|
||||
if s.(*SQLiteStmt).s != nil {
|
||||
na := s.NumInput()
|
||||
if len(args) < na {
|
||||
return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args))
|
||||
}
|
||||
res, err = s.Exec(args[:na])
|
||||
if err != nil && err != driver.ErrSkip {
|
||||
s.Close()
|
||||
return nil, err
|
||||
}
|
||||
args = args[na:]
|
||||
}
|
||||
tail := s.(*SQLiteStmt).t
|
||||
s.Close()
|
||||
if tail == "" {
|
||||
return res, nil
|
||||
}
|
||||
query = tail
|
||||
}
|
||||
}
|
||||
|
||||
// Implements Queryer
|
||||
func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) {
|
||||
for {
|
||||
s, err := c.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.(*SQLiteStmt).cls = true
|
||||
na := s.NumInput()
|
||||
if len(args) < na {
|
||||
return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args))
|
||||
}
|
||||
rows, err := s.Query(args[:na])
|
||||
if err != nil && err != driver.ErrSkip {
|
||||
s.Close()
|
||||
return nil, err
|
||||
}
|
||||
args = args[na:]
|
||||
tail := s.(*SQLiteStmt).t
|
||||
if tail == "" {
|
||||
return rows, nil
|
||||
}
|
||||
rows.Close()
|
||||
s.Close()
|
||||
query = tail
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SQLiteConn) exec(cmd string) (driver.Result, error) {
|
||||
pcmd := C.CString(cmd)
|
||||
defer C.free(unsafe.Pointer(pcmd))
|
||||
|
||||
var rowid, changes C.longlong
|
||||
rv := C._sqlite3_exec(c.db, pcmd, &rowid, &changes)
|
||||
if rv != C.SQLITE_OK {
|
||||
return nil, c.lastError()
|
||||
}
|
||||
return &SQLiteResult{int64(rowid), int64(changes)}, nil
|
||||
}
|
||||
|
||||
// Begin transaction.
|
||||
func (c *SQLiteConn) Begin() (driver.Tx, error) {
|
||||
if _, err := c.exec(c.txlock); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SQLiteTx{c}, nil
|
||||
}
|
||||
|
||||
func errorString(err Error) string {
|
||||
return C.GoString(C.sqlite3_errstr(C.int(err.Code)))
|
||||
}
|
||||
|
||||
// Open database and return a new connection.
|
||||
// You can specify DSN string with URI filename.
|
||||
// test.db
|
||||
// file:test.db?cache=shared&mode=memory
|
||||
// :memory:
|
||||
// file::memory:
|
||||
// go-sqlite3 adds the following query parameters to those used by SQLite:
|
||||
// _loc=XXX
|
||||
// Specify location of time format. It's possible to specify "auto".
|
||||
// _busy_timeout=XXX
|
||||
// Specify value for sqlite3_busy_timeout.
|
||||
// _txlock=XXX
|
||||
// Specify locking behavior for transactions. XXX can be "immediate",
|
||||
// "deferred", "exclusive".
|
||||
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||
if C.sqlite3_threadsafe() == 0 {
|
||||
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
|
||||
}
|
||||
|
||||
var loc *time.Location
|
||||
txlock := "BEGIN"
|
||||
busy_timeout := 5000
|
||||
pos := strings.IndexRune(dsn, '?')
|
||||
if pos >= 1 {
|
||||
params, err := url.ParseQuery(dsn[pos+1:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// _loc
|
||||
if val := params.Get("_loc"); val != "" {
|
||||
if val == "auto" {
|
||||
loc = time.Local
|
||||
} else {
|
||||
loc, err = time.LoadLocation(val)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Invalid _loc: %v: %v", val, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// _busy_timeout
|
||||
if val := params.Get("_busy_timeout"); val != "" {
|
||||
iv, err := strconv.ParseInt(val, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Invalid _busy_timeout: %v: %v", val, err)
|
||||
}
|
||||
busy_timeout = int(iv)
|
||||
}
|
||||
|
||||
// _txlock
|
||||
if val := params.Get("_txlock"); val != "" {
|
||||
switch val {
|
||||
case "immediate":
|
||||
txlock = "BEGIN IMMEDIATE"
|
||||
case "exclusive":
|
||||
txlock = "BEGIN EXCLUSIVE"
|
||||
case "deferred":
|
||||
txlock = "BEGIN"
|
||||
default:
|
||||
return nil, fmt.Errorf("Invalid _txlock: %v", val)
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(dsn, "file:") {
|
||||
dsn = dsn[:pos]
|
||||
}
|
||||
}
|
||||
|
||||
var db *C.sqlite3
|
||||
name := C.CString(dsn)
|
||||
defer C.free(unsafe.Pointer(name))
|
||||
rv := C._sqlite3_open_v2(name, &db,
|
||||
C.SQLITE_OPEN_FULLMUTEX|
|
||||
C.SQLITE_OPEN_READWRITE|
|
||||
C.SQLITE_OPEN_CREATE,
|
||||
nil)
|
||||
if rv != 0 {
|
||||
return nil, Error{Code: ErrNo(rv)}
|
||||
}
|
||||
if db == nil {
|
||||
return nil, errors.New("sqlite succeeded without returning a database")
|
||||
}
|
||||
|
||||
rv = C.sqlite3_busy_timeout(db, C.int(busy_timeout))
|
||||
if rv != C.SQLITE_OK {
|
||||
return nil, Error{Code: ErrNo(rv)}
|
||||
}
|
||||
|
||||
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
|
||||
|
||||
if len(d.Extensions) > 0 {
|
||||
if err := conn.loadExtensions(d.Extensions); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if d.ConnectHook != nil {
|
||||
if err := d.ConnectHook(conn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
runtime.SetFinalizer(conn, (*SQLiteConn).Close)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Close the connection.
|
||||
func (c *SQLiteConn) Close() error {
|
||||
rv := C.sqlite3_close_v2(c.db)
|
||||
if rv != C.SQLITE_OK {
|
||||
return c.lastError()
|
||||
}
|
||||
c.db = nil
|
||||
runtime.SetFinalizer(c, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Prepare query string. Return a new statement.
|
||||
func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
|
||||
pquery := C.CString(query)
|
||||
defer C.free(unsafe.Pointer(pquery))
|
||||
var s *C.sqlite3_stmt
|
||||
var tail *C.char
|
||||
rv := C.sqlite3_prepare_v2(c.db, pquery, -1, &s, &tail)
|
||||
if rv != C.SQLITE_OK {
|
||||
return nil, c.lastError()
|
||||
}
|
||||
var t string
|
||||
if tail != nil && *tail != '\000' {
|
||||
t = strings.TrimSpace(C.GoString(tail))
|
||||
}
|
||||
nv := int(C.sqlite3_bind_parameter_count(s))
|
||||
var nn []string
|
||||
for i := 0; i < nv; i++ {
|
||||
pn := C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1)))
|
||||
if len(pn) > 1 && pn[0] == '$' && 48 <= pn[1] && pn[1] <= 57 {
|
||||
nn = append(nn, C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1))))
|
||||
}
|
||||
}
|
||||
ss := &SQLiteStmt{c: c, s: s, nv: nv, nn: nn, t: t}
|
||||
runtime.SetFinalizer(ss, (*SQLiteStmt).Close)
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
// Close the statement.
|
||||
func (s *SQLiteStmt) Close() error {
|
||||
if s.closed {
|
||||
return nil
|
||||
}
|
||||
s.closed = true
|
||||
if s.c == nil || s.c.db == nil {
|
||||
return errors.New("sqlite statement with already closed database connection")
|
||||
}
|
||||
rv := C.sqlite3_finalize(s.s)
|
||||
if rv != C.SQLITE_OK {
|
||||
return s.c.lastError()
|
||||
}
|
||||
runtime.SetFinalizer(s, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return a number of parameters.
|
||||
func (s *SQLiteStmt) NumInput() int {
|
||||
return s.nv
|
||||
}
|
||||
|
||||
type bindArg struct {
|
||||
n int
|
||||
v driver.Value
|
||||
}
|
||||
|
||||
func (s *SQLiteStmt) bind(args []driver.Value) error {
|
||||
rv := C.sqlite3_reset(s.s)
|
||||
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
|
||||
return s.c.lastError()
|
||||
}
|
||||
|
||||
var vargs []bindArg
|
||||
narg := len(args)
|
||||
vargs = make([]bindArg, narg)
|
||||
if len(s.nn) > 0 {
|
||||
for i, v := range s.nn {
|
||||
if pi, err := strconv.Atoi(v[1:]); err == nil {
|
||||
vargs[i] = bindArg{pi, args[i]}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i, v := range args {
|
||||
vargs[i] = bindArg{i + 1, v}
|
||||
}
|
||||
}
|
||||
|
||||
for _, varg := range vargs {
|
||||
n := C.int(varg.n)
|
||||
v := varg.v
|
||||
switch v := v.(type) {
|
||||
case nil:
|
||||
rv = C.sqlite3_bind_null(s.s, n)
|
||||
case string:
|
||||
if len(v) == 0 {
|
||||
b := []byte{0}
|
||||
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(0))
|
||||
} else {
|
||||
b := []byte(v)
|
||||
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
|
||||
}
|
||||
case int64:
|
||||
rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
|
||||
case bool:
|
||||
if bool(v) {
|
||||
rv = C.sqlite3_bind_int(s.s, n, 1)
|
||||
} else {
|
||||
rv = C.sqlite3_bind_int(s.s, n, 0)
|
||||
}
|
||||
case float64:
|
||||
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
|
||||
case []byte:
|
||||
var p *byte
|
||||
if len(v) > 0 {
|
||||
p = &v[0]
|
||||
}
|
||||
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v)))
|
||||
case time.Time:
|
||||
b := []byte(v.Format(SQLiteTimestampFormats[0]))
|
||||
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
|
||||
}
|
||||
if rv != C.SQLITE_OK {
|
||||
return s.c.lastError()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query the statement with arguments. Return records.
|
||||
func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
if err := s.bind(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SQLiteRows{s, int(C.sqlite3_column_count(s.s)), nil, nil, s.cls}, nil
|
||||
}
|
||||
|
||||
// Return last inserted ID.
|
||||
func (r *SQLiteResult) LastInsertId() (int64, error) {
|
||||
return r.id, nil
|
||||
}
|
||||
|
||||
// Return how many rows affected.
|
||||
func (r *SQLiteResult) RowsAffected() (int64, error) {
|
||||
return r.changes, nil
|
||||
}
|
||||
|
||||
// Execute the statement with arguments. Return result object.
|
||||
func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
if err := s.bind(args); err != nil {
|
||||
C.sqlite3_reset(s.s)
|
||||
C.sqlite3_clear_bindings(s.s)
|
||||
return nil, err
|
||||
}
|
||||
var rowid, changes C.longlong
|
||||
rv := C._sqlite3_step(s.s, &rowid, &changes)
|
||||
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
|
||||
err := s.c.lastError()
|
||||
C.sqlite3_reset(s.s)
|
||||
C.sqlite3_clear_bindings(s.s)
|
||||
return nil, err
|
||||
}
|
||||
return &SQLiteResult{int64(rowid), int64(changes)}, nil
|
||||
}
|
||||
|
||||
// Close the rows.
|
||||
func (rc *SQLiteRows) Close() error {
|
||||
if rc.s.closed {
|
||||
return nil
|
||||
}
|
||||
if rc.cls {
|
||||
return rc.s.Close()
|
||||
}
|
||||
rv := C.sqlite3_reset(rc.s.s)
|
||||
if rv != C.SQLITE_OK {
|
||||
return rc.s.c.lastError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return column names.
|
||||
func (rc *SQLiteRows) Columns() []string {
|
||||
if rc.nc != len(rc.cols) {
|
||||
rc.cols = make([]string, rc.nc)
|
||||
for i := 0; i < rc.nc; i++ {
|
||||
rc.cols[i] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(i)))
|
||||
}
|
||||
}
|
||||
return rc.cols
|
||||
}
|
||||
|
||||
// Move cursor to next.
|
||||
func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
||||
rv := C.sqlite3_step(rc.s.s)
|
||||
if rv == C.SQLITE_DONE {
|
||||
return io.EOF
|
||||
}
|
||||
if rv != C.SQLITE_ROW {
|
||||
rv = C.sqlite3_reset(rc.s.s)
|
||||
if rv != C.SQLITE_OK {
|
||||
return rc.s.c.lastError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if rc.decltype == nil {
|
||||
rc.decltype = make([]string, rc.nc)
|
||||
for i := 0; i < rc.nc; i++ {
|
||||
rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
|
||||
}
|
||||
}
|
||||
|
||||
for i := range dest {
|
||||
switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
|
||||
case C.SQLITE_INTEGER:
|
||||
val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i)))
|
||||
switch rc.decltype[i] {
|
||||
case "timestamp", "datetime", "date":
|
||||
var t time.Time
|
||||
// Assume a millisecond unix timestamp if it's 13 digits -- too
|
||||
// large to be a reasonable timestamp in seconds.
|
||||
if val > 1e12 || val < -1e12 {
|
||||
val *= int64(time.Millisecond) // convert ms to nsec
|
||||
} else {
|
||||
val *= int64(time.Second) // convert sec to nsec
|
||||
}
|
||||
t = time.Unix(0, val).UTC()
|
||||
if rc.s.c.loc != nil {
|
||||
t = t.In(rc.s.c.loc)
|
||||
}
|
||||
dest[i] = t
|
||||
case "boolean":
|
||||
dest[i] = val > 0
|
||||
default:
|
||||
dest[i] = val
|
||||
}
|
||||
case C.SQLITE_FLOAT:
|
||||
dest[i] = float64(C.sqlite3_column_double(rc.s.s, C.int(i)))
|
||||
case C.SQLITE_BLOB:
|
||||
p := C.sqlite3_column_blob(rc.s.s, C.int(i))
|
||||
if p == nil {
|
||||
dest[i] = nil
|
||||
continue
|
||||
}
|
||||
n := int(C.sqlite3_column_bytes(rc.s.s, C.int(i)))
|
||||
switch dest[i].(type) {
|
||||
case sql.RawBytes:
|
||||
dest[i] = (*[1 << 30]byte)(unsafe.Pointer(p))[0:n]
|
||||
default:
|
||||
slice := make([]byte, n)
|
||||
copy(slice[:], (*[1 << 30]byte)(unsafe.Pointer(p))[0:n])
|
||||
dest[i] = slice
|
||||
}
|
||||
case C.SQLITE_NULL:
|
||||
dest[i] = nil
|
||||
case C.SQLITE_TEXT:
|
||||
var err error
|
||||
var timeVal time.Time
|
||||
|
||||
n := int(C.sqlite3_column_bytes(rc.s.s, C.int(i)))
|
||||
s := C.GoStringN((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i)))), C.int(n))
|
||||
|
||||
switch rc.decltype[i] {
|
||||
case "timestamp", "datetime", "date":
|
||||
var t time.Time
|
||||
s = strings.TrimSuffix(s, "Z")
|
||||
for _, format := range SQLiteTimestampFormats {
|
||||
if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil {
|
||||
t = timeVal
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// The column is a time value, so return the zero time on parse failure.
|
||||
t = time.Time{}
|
||||
}
|
||||
if rc.s.c.loc != nil {
|
||||
t = t.In(rc.s.c.loc)
|
||||
}
|
||||
dest[i] = t
|
||||
default:
|
||||
dest[i] = []byte(s)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
13
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3_icu.go
generated
vendored
Normal file
13
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3_icu.go
generated
vendored
Normal file
|
@ -0,0 +1,13 @@
|
|||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// +build icu
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -licuuc -licui18n
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_ICU
|
||||
*/
|
||||
import "C"
|
13
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3_libsqlite3.go
generated
vendored
Normal file
13
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3_libsqlite3.go
generated
vendored
Normal file
|
@ -0,0 +1,13 @@
|
|||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// +build libsqlite3
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -DUSE_LIBSQLITE3
|
||||
#cgo LDFLAGS: -lsqlite3
|
||||
*/
|
||||
import "C"
|
39
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3_load_extension.go
generated
vendored
Normal file
39
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3_load_extension.go
generated
vendored
Normal file
|
@ -0,0 +1,39 @@
|
|||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// +build !sqlite_omit_load_extension
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#include <sqlite3-binding.h>
|
||||
#include <stdlib.h>
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"errors"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func (c *SQLiteConn) loadExtensions(extensions []string) error {
|
||||
rv := C.sqlite3_enable_load_extension(c.db, 1)
|
||||
if rv != C.SQLITE_OK {
|
||||
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
|
||||
}
|
||||
|
||||
for _, extension := range extensions {
|
||||
cext := C.CString(extension)
|
||||
defer C.free(unsafe.Pointer(cext))
|
||||
rv = C.sqlite3_load_extension(c.db, cext, nil, nil)
|
||||
if rv != C.SQLITE_OK {
|
||||
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
|
||||
}
|
||||
}
|
||||
|
||||
rv = C.sqlite3_enable_load_extension(c.db, 0)
|
||||
if rv != C.SQLITE_OK {
|
||||
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
|
||||
}
|
||||
return nil
|
||||
}
|
19
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3_omit_load_extension.go
generated
vendored
Normal file
19
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3_omit_load_extension.go
generated
vendored
Normal file
|
@ -0,0 +1,19 @@
|
|||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// +build sqlite_omit_load_extension
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -DSQLITE_OMIT_LOAD_EXTENSION
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
func (c *SQLiteConn) loadExtensions(extensions []string) error {
|
||||
return errors.New("Extensions have been disabled for static builds")
|
||||
}
|
13
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3_other.go
generated
vendored
Normal file
13
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3_other.go
generated
vendored
Normal file
|
@ -0,0 +1,13 @@
|
|||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// +build !windows
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -I.
|
||||
#cgo linux LDFLAGS: -ldl
|
||||
*/
|
||||
import "C"
|
409
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3_test/sqltest.go
generated
vendored
Normal file
409
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3_test/sqltest.go
generated
vendored
Normal file
|
@ -0,0 +1,409 @@
|
|||
package sqlite3_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Dialect int
|
||||
|
||||
const (
|
||||
SQLITE Dialect = iota
|
||||
POSTGRESQL
|
||||
MYSQL
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
*testing.T
|
||||
*sql.DB
|
||||
dialect Dialect
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
var db *DB
|
||||
|
||||
// the following tables will be created and dropped during the test
|
||||
var testTables = []string{"foo", "bar", "t", "bench"}
|
||||
|
||||
var tests = []testing.InternalTest{
|
||||
{"TestBlobs", TestBlobs},
|
||||
{"TestManyQueryRow", TestManyQueryRow},
|
||||
{"TestTxQuery", TestTxQuery},
|
||||
{"TestPreparedStmt", TestPreparedStmt},
|
||||
}
|
||||
|
||||
var benchmarks = []testing.InternalBenchmark{
|
||||
{"BenchmarkExec", BenchmarkExec},
|
||||
{"BenchmarkQuery", BenchmarkQuery},
|
||||
{"BenchmarkParams", BenchmarkParams},
|
||||
{"BenchmarkStmt", BenchmarkStmt},
|
||||
{"BenchmarkRows", BenchmarkRows},
|
||||
{"BenchmarkStmtRows", BenchmarkStmtRows},
|
||||
}
|
||||
|
||||
// RunTests runs the SQL test suite
|
||||
func RunTests(t *testing.T, d *sql.DB, dialect Dialect) {
|
||||
db = &DB{t, d, dialect, sync.Once{}}
|
||||
testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests)
|
||||
|
||||
if !testing.Short() {
|
||||
for _, b := range benchmarks {
|
||||
fmt.Printf("%-20s", b.Name)
|
||||
r := testing.Benchmark(b.F)
|
||||
fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds())
|
||||
}
|
||||
}
|
||||
db.tearDown()
|
||||
}
|
||||
|
||||
func (db *DB) mustExec(sql string, args ...interface{}) sql.Result {
|
||||
res, err := db.Exec(sql, args...)
|
||||
if err != nil {
|
||||
db.Fatalf("Error running %q: %v", sql, err)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (db *DB) tearDown() {
|
||||
for _, tbl := range testTables {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
db.mustExec("drop table if exists " + tbl)
|
||||
case MYSQL, POSTGRESQL:
|
||||
db.mustExec("drop table if exists " + tbl)
|
||||
default:
|
||||
db.Fatal("unkown dialect")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// q replaces ? parameters if needed
|
||||
func (db *DB) q(sql string) string {
|
||||
switch db.dialect {
|
||||
case POSTGRESQL: // repace with $1, $2, ..
|
||||
qrx := regexp.MustCompile(`\?`)
|
||||
n := 0
|
||||
return qrx.ReplaceAllStringFunc(sql, func(string) string {
|
||||
n++
|
||||
return "$" + strconv.Itoa(n)
|
||||
})
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
||||
func (db *DB) blobType(size int) string {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
return fmt.Sprintf("blob[%d]", size)
|
||||
case POSTGRESQL:
|
||||
return "bytea"
|
||||
case MYSQL:
|
||||
return fmt.Sprintf("VARBINARY(%d)", size)
|
||||
}
|
||||
panic("unkown dialect")
|
||||
}
|
||||
|
||||
func (db *DB) serialPK() string {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
return "integer primary key autoincrement"
|
||||
case POSTGRESQL:
|
||||
return "serial primary key"
|
||||
case MYSQL:
|
||||
return "integer primary key auto_increment"
|
||||
}
|
||||
panic("unkown dialect")
|
||||
}
|
||||
|
||||
func (db *DB) now() string {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
return "datetime('now')"
|
||||
case POSTGRESQL:
|
||||
return "now()"
|
||||
case MYSQL:
|
||||
return "now()"
|
||||
}
|
||||
panic("unkown dialect")
|
||||
}
|
||||
|
||||
func makeBench() {
|
||||
if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer st.Close()
|
||||
for i := 0; i < 100; i++ {
|
||||
if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResult(t *testing.T) {
|
||||
db.tearDown()
|
||||
db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))")
|
||||
|
||||
for i := 1; i < 3; i++ {
|
||||
r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i))
|
||||
n, err := r.RowsAffected()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Errorf("got %v, want %v", n, 1)
|
||||
}
|
||||
n, err = r.LastInsertId()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != int64(i) {
|
||||
t.Errorf("got %v, want %v", n, i)
|
||||
}
|
||||
}
|
||||
if _, err := db.Exec("error!"); err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlobs(t *testing.T) {
|
||||
db.tearDown()
|
||||
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
||||
db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
|
||||
db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob)
|
||||
|
||||
want := fmt.Sprintf("%x", blob)
|
||||
|
||||
b := make([]byte, 16)
|
||||
err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b)
|
||||
got := fmt.Sprintf("%x", b)
|
||||
if err != nil {
|
||||
t.Errorf("[]byte scan: %v", err)
|
||||
} else if got != want {
|
||||
t.Errorf("for []byte, got %q; want %q", got, want)
|
||||
}
|
||||
|
||||
err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got)
|
||||
want = string(blob)
|
||||
if err != nil {
|
||||
t.Errorf("string scan: %v", err)
|
||||
} else if got != want {
|
||||
t.Errorf("for string, got %q; want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManyQueryRow(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Log("skipping in short mode")
|
||||
return
|
||||
}
|
||||
db.tearDown()
|
||||
db.mustExec("create table foo (id integer primary key, name varchar(50))")
|
||||
db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
|
||||
var name string
|
||||
for i := 0; i < 10000; i++ {
|
||||
err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name)
|
||||
if err != nil || name != "bob" {
|
||||
t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTxQuery(t *testing.T) {
|
||||
db.tearDown()
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = tx.Exec("create table foo (id integer primary key, name varchar(50))")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r, err := tx.Query(db.q("select name from foo where id = ?"), 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
if !r.Next() {
|
||||
if r.Err() != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Fatal("expected one rows")
|
||||
}
|
||||
|
||||
var name string
|
||||
err = r.Scan(&name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreparedStmt(t *testing.T) {
|
||||
db.tearDown()
|
||||
db.mustExec("CREATE TABLE t (count INT)")
|
||||
sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC")
|
||||
if err != nil {
|
||||
t.Fatalf("prepare 1: %v", err)
|
||||
}
|
||||
ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)"))
|
||||
if err != nil {
|
||||
t.Fatalf("prepare 2: %v", err)
|
||||
}
|
||||
|
||||
for n := 1; n <= 3; n++ {
|
||||
if _, err := ins.Exec(n); err != nil {
|
||||
t.Fatalf("insert(%d) = %v", n, err)
|
||||
}
|
||||
}
|
||||
|
||||
const nRuns = 10
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < nRuns; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 10; j++ {
|
||||
count := 0
|
||||
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
|
||||
t.Errorf("Query: %v", err)
|
||||
return
|
||||
}
|
||||
if _, err := ins.Exec(rand.Intn(100)); err != nil {
|
||||
t.Errorf("Insert: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Benchmarks need to use panic() since b.Error errors are lost when
|
||||
// running via testing.Benchmark() I would like to run these via go
|
||||
// test -bench but calling Benchmark() from a benchmark test
|
||||
// currently hangs go.
|
||||
|
||||
func BenchmarkExec(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := db.Exec("select 1"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkQuery(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
// var t time.Time
|
||||
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParams(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
// var t time.Time
|
||||
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStmt(b *testing.B) {
|
||||
st, err := db.Prepare("select ?, ?, ?, ?")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer st.Close()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
// var t time.Time
|
||||
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRows(b *testing.B) {
|
||||
db.once.Do(makeBench)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
var t time.Time
|
||||
r, err := db.Query("select * from bench")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for r.Next() {
|
||||
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
if err = r.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStmtRows(b *testing.B) {
|
||||
db.once.Do(makeBench)
|
||||
|
||||
st, err := db.Prepare("select * from bench")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer st.Close()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
var t time.Time
|
||||
r, err := st.Query()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for r.Next() {
|
||||
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
if err = r.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
14
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3_windows.go
generated
vendored
Normal file
14
Godeps/_workspace/src/github.com/mattn/go-sqlite3/sqlite3_windows.go
generated
vendored
Normal file
|
@ -0,0 +1,14 @@
|
|||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// +build windows
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -I. -fno-stack-check -fno-stack-protector -mno-stack-arg-probe
|
||||
#cgo windows,386 CFLAGS: -D_localtime32=localtime
|
||||
#cgo LDFLAGS: -lmingwex -lmingw32
|
||||
*/
|
||||
import "C"
|
|
@ -4,7 +4,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/schema/adminschema"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
|
@ -22,7 +22,9 @@ type testFixtures struct {
|
|||
func makeTestFixtures() *testFixtures {
|
||||
f := &testFixtures{}
|
||||
|
||||
f.ur = user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{
|
||||
dbMap := db.NewMemDB()
|
||||
f.ur = func() user.UserRepo {
|
||||
repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{
|
||||
{
|
||||
User: user.User{
|
||||
ID: "ID-1",
|
||||
|
@ -38,16 +40,35 @@ func makeTestFixtures() *testFixtures {
|
|||
},
|
||||
},
|
||||
})
|
||||
f.pwr = user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
|
||||
if err != nil {
|
||||
panic("Failed to create user repo: " + err.Error())
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
f.pwr = func() user.PasswordInfoRepo {
|
||||
repo, err := db.NewPasswordInfoRepoFromPasswordInfos(dbMap, []user.PasswordInfo{
|
||||
{
|
||||
UserID: "ID-1",
|
||||
Password: []byte("hi."),
|
||||
},
|
||||
})
|
||||
ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
|
||||
&connector.LocalConnectorConfig{ID: "local"},
|
||||
})
|
||||
f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||
if err != nil {
|
||||
panic("Failed to create user repo: " + err.Error())
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
ccr := func() connector.ConnectorConfigRepo {
|
||||
c := []connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}}
|
||||
repo := db.NewConnectorConfigRepo(dbMap)
|
||||
if err := repo.Set(c); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
|
||||
f.adAPI = NewAdminAPI(f.mgr, f.ur, f.pwr, "local")
|
||||
|
||||
return f
|
||||
|
|
18
build
18
build
|
@ -1,18 +1,10 @@
|
|||
#!/bin/bash -e
|
||||
|
||||
export GOPATH=${PWD}/Godeps/_workspace
|
||||
export GOBIN=${PWD}/bin
|
||||
source ./env
|
||||
|
||||
rm -rf $GOPATH/src/github.com/coreos/dex
|
||||
mkdir -p $GOPATH/src/github.com/coreos/
|
||||
|
||||
# Only attempt to link dex into godeps if it isn't already there
|
||||
[ -d $GOPATH/src/github.com/coreos/dex ] || ln -s ${PWD} $GOPATH/src/github.com/coreos/dex
|
||||
|
||||
LD_FLAGS="-X main.version=$(./git-version)"
|
||||
go build -o bin/dex-worker -ldflags="$LD_FLAGS" github.com/coreos/dex/cmd/dex-worker
|
||||
go build -o bin/dexctl -ldflags="$LD_FLAGS" github.com/coreos/dex/cmd/dexctl
|
||||
go build -o bin/dex-overlord -ldflags="$LD_FLAGS" github.com/coreos/dex/cmd/dex-overlord
|
||||
go install -ldflags="$LD_FLAGS" github.com/coreos/dex/cmd/dex-worker
|
||||
go install -ldflags="$LD_FLAGS" github.com/coreos/dex/cmd/dexctl
|
||||
go install -ldflags="$LD_FLAGS" github.com/coreos/dex/cmd/dex-overlord
|
||||
go build -o bin/example-app github.com/coreos/dex/examples/app
|
||||
go build -o bin/example-cli github.com/coreos/dex/examples/cli
|
||||
go build -o bin/gendoc github.com/coreos/dex/cmd/gendoc
|
||||
go install github.com/coreos/dex/cmd/gendoc
|
||||
|
|
146
client/client.go
146
client/client.go
|
@ -1,16 +1,10 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"sort"
|
||||
|
||||
pcrypto "github.com/coreos/dex/pkg/crypto"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
)
|
||||
|
||||
|
@ -46,146 +40,6 @@ type ClientIdentityRepo interface {
|
|||
IsDexAdmin(clientID string) (bool, error)
|
||||
}
|
||||
|
||||
func NewClientIdentityRepo(cs []oidc.ClientIdentity) ClientIdentityRepo {
|
||||
cr := memClientIdentityRepo{
|
||||
idents: make(map[string]oidc.ClientIdentity, len(cs)),
|
||||
admins: make(map[string]bool),
|
||||
}
|
||||
|
||||
for _, c := range cs {
|
||||
c := c
|
||||
cr.idents[c.Credentials.ID] = c
|
||||
}
|
||||
|
||||
return &cr
|
||||
}
|
||||
|
||||
type memClientIdentityRepo struct {
|
||||
idents map[string]oidc.ClientIdentity
|
||||
admins map[string]bool
|
||||
}
|
||||
|
||||
func (cr *memClientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.ClientCredentials, error) {
|
||||
if _, ok := cr.idents[id]; ok {
|
||||
return nil, errors.New("client ID already exists")
|
||||
}
|
||||
|
||||
secret, err := pcrypto.RandBytes(32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cc := oidc.ClientCredentials{
|
||||
ID: id,
|
||||
Secret: base64.URLEncoding.EncodeToString(secret),
|
||||
}
|
||||
|
||||
cr.idents[id] = oidc.ClientIdentity{
|
||||
Metadata: meta,
|
||||
Credentials: cc,
|
||||
}
|
||||
|
||||
return &cc, nil
|
||||
}
|
||||
|
||||
func (cr *memClientIdentityRepo) Metadata(clientID string) (*oidc.ClientMetadata, error) {
|
||||
ci, ok := cr.idents[clientID]
|
||||
if !ok {
|
||||
return nil, ErrorNotFound
|
||||
}
|
||||
return &ci.Metadata, nil
|
||||
}
|
||||
|
||||
func (cr *memClientIdentityRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) {
|
||||
ci, ok := cr.idents[creds.ID]
|
||||
ok = ok && ci.Credentials.Secret == creds.Secret
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (cr *memClientIdentityRepo) All() ([]oidc.ClientIdentity, error) {
|
||||
cs := make(sortableClientIdentities, 0, len(cr.idents))
|
||||
for _, ci := range cr.idents {
|
||||
ci := ci
|
||||
cs = append(cs, ci)
|
||||
}
|
||||
sort.Sort(cs)
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
func (cr *memClientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error {
|
||||
cr.admins[clientID] = isAdmin
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cr *memClientIdentityRepo) IsDexAdmin(clientID string) (bool, error) {
|
||||
return cr.admins[clientID], nil
|
||||
}
|
||||
|
||||
type sortableClientIdentities []oidc.ClientIdentity
|
||||
|
||||
func (s sortableClientIdentities) Len() int {
|
||||
return len([]oidc.ClientIdentity(s))
|
||||
}
|
||||
|
||||
func (s sortableClientIdentities) Less(i, j int) bool {
|
||||
return s[i].Credentials.ID < s[j].Credentials.ID
|
||||
}
|
||||
|
||||
func (s sortableClientIdentities) Swap(i, j int) {
|
||||
s[i], s[j] = s[j], s[i]
|
||||
}
|
||||
|
||||
func NewClientIdentityRepoFromReader(r io.Reader) (ClientIdentityRepo, error) {
|
||||
b, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cs []clientIdentity
|
||||
if err = json.Unmarshal(b, &cs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ocs := make([]oidc.ClientIdentity, len(cs))
|
||||
for i, c := range cs {
|
||||
ocs[i] = oidc.ClientIdentity(c)
|
||||
}
|
||||
|
||||
return NewClientIdentityRepo(ocs), nil
|
||||
}
|
||||
|
||||
type clientIdentity oidc.ClientIdentity
|
||||
|
||||
func (ci *clientIdentity) UnmarshalJSON(data []byte) error {
|
||||
c := struct {
|
||||
ID string `json:"id"`
|
||||
Secret string `json:"secret"`
|
||||
RedirectURLs []string `json:"redirectURLs"`
|
||||
}{}
|
||||
|
||||
if err := json.Unmarshal(data, &c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ci.Credentials = oidc.ClientCredentials{
|
||||
ID: c.ID,
|
||||
Secret: c.Secret,
|
||||
}
|
||||
ci.Metadata = oidc.ClientMetadata{
|
||||
RedirectURIs: make([]url.URL, len(c.RedirectURLs)),
|
||||
}
|
||||
|
||||
for i, us := range c.RedirectURLs {
|
||||
up, err := url.Parse(us)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ci.Metadata.RedirectURIs[i] = *up
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidRedirectURL returns the passed in URL if it is present in the redirectURLs list, and returns an error otherwise.
|
||||
// If nil is passed in as the rURL and there is only one URL in redirectURLs,
|
||||
// that URL will be returned. If nil is passed but theres >1 URL in the slice,
|
||||
|
|
|
@ -1,190 +0,0 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
)
|
||||
|
||||
func TestMemClientIdentityRepoNew(t *testing.T) {
|
||||
tests := []struct {
|
||||
id string
|
||||
meta oidc.ClientMetadata
|
||||
}{
|
||||
{
|
||||
id: "foo",
|
||||
meta: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{
|
||||
Scheme: "https",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
id: "bar",
|
||||
meta: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{Scheme: "https", Host: "example.com/foo"},
|
||||
url.URL{Scheme: "https", Host: "example.com/bar"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
cr := NewClientIdentityRepo(nil)
|
||||
creds, err := cr.New(tt.id, tt.meta)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
if creds.ID != tt.id {
|
||||
t.Errorf("case %d: expected non-empty Client ID", i)
|
||||
}
|
||||
|
||||
if creds.Secret == "" {
|
||||
t.Errorf("case %d: expected non-empty Client Secret", i)
|
||||
}
|
||||
|
||||
all, err := cr.All()
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
if len(all) != 1 {
|
||||
t.Errorf("case %d: expected repo to contain newly created Client", i)
|
||||
}
|
||||
|
||||
wantURLs := tt.meta.RedirectURIs
|
||||
gotURLs := all[0].Metadata.RedirectURIs
|
||||
if !reflect.DeepEqual(wantURLs, gotURLs) {
|
||||
t.Errorf("case %d: redirect url mismatch, want=%v, got=%v", i, wantURLs, gotURLs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemClientIdentityRepoNewDuplicate(t *testing.T) {
|
||||
cr := NewClientIdentityRepo(nil)
|
||||
|
||||
meta1 := oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{Scheme: "https", Host: "foo.example.com"},
|
||||
},
|
||||
}
|
||||
|
||||
if _, err := cr.New("foo", meta1); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
meta2 := oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{Scheme: "https", Host: "bar.example.com"},
|
||||
},
|
||||
}
|
||||
|
||||
if _, err := cr.New("foo", meta2); err == nil {
|
||||
t.Errorf("expected non-nil error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemClientIdentityRepoAll(t *testing.T) {
|
||||
tests := []struct {
|
||||
ids []string
|
||||
}{
|
||||
{
|
||||
ids: nil,
|
||||
},
|
||||
{
|
||||
ids: []string{"foo"},
|
||||
},
|
||||
{
|
||||
ids: []string{"foo", "bar"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
cs := make([]oidc.ClientIdentity, len(tt.ids))
|
||||
for i, s := range tt.ids {
|
||||
cs[i] = oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: s,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
cr := NewClientIdentityRepo(cs)
|
||||
|
||||
all, err := cr.All()
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
want := sortableClientIdentities(cs)
|
||||
sort.Sort(want)
|
||||
got := sortableClientIdentities(all)
|
||||
sort.Sort(got)
|
||||
|
||||
if len(got) != len(want) {
|
||||
t.Errorf("case %d: wrong length: %d", i, len(got))
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("case %d: want=%#v, got=%#v", i, want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIdentityUnmarshalJSON(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
json string
|
||||
expectedID string
|
||||
expectedSecret string
|
||||
expectedURLs []string
|
||||
}{
|
||||
{
|
||||
json: `{"id":"12345","secret":"rosebud","redirectURLs":["https://redirectone.com", "https://redirecttwo.com"]}`,
|
||||
expectedID: "12345",
|
||||
expectedSecret: "rosebud",
|
||||
expectedURLs: []string{
|
||||
"https://redirectone.com",
|
||||
"https://redirecttwo.com",
|
||||
},
|
||||
},
|
||||
} {
|
||||
var actual clientIdentity
|
||||
err := json.Unmarshal([]byte(test.json), &actual)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: error unmarshalling: %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if actual.Credentials.ID != test.expectedID {
|
||||
t.Errorf("case %d: actual.Credentials.ID == %v, want %v", i, actual.Credentials.ID, test.expectedID)
|
||||
}
|
||||
|
||||
if actual.Credentials.Secret != test.expectedSecret {
|
||||
t.Errorf("case %d: actual.Credentials.Secret == %v, want %v", i, actual.Credentials.Secret, test.expectedSecret)
|
||||
}
|
||||
expectedURLs := test.expectedURLs
|
||||
sort.Strings(expectedURLs)
|
||||
|
||||
actualURLs := make([]string, 0)
|
||||
for _, u := range actual.Metadata.RedirectURIs {
|
||||
actualURLs = append(actualURLs, u.String())
|
||||
}
|
||||
sort.Strings(actualURLs)
|
||||
if len(actualURLs) != len(expectedURLs) {
|
||||
t.Errorf("case %d: len(actualURLs) == %v, want %v", i, len(actualURLs), len(expectedURLs))
|
||||
}
|
||||
for ui, actualURL := range actualURLs {
|
||||
if actualURL != expectedURLs[ui] {
|
||||
t.Errorf("case %d: actualURLs[%d] == %q, want %q", i, ui, actualURL, expectedURLs[ui])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -12,6 +12,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/key"
|
||||
"github.com/go-gorp/gorp"
|
||||
|
||||
"github.com/coreos/dex/admin"
|
||||
"github.com/coreos/dex/db"
|
||||
|
@ -94,6 +95,9 @@ func main() {
|
|||
if err != nil {
|
||||
log.Fatalf(err.Error())
|
||||
}
|
||||
if _, ok := dbc.Dialect.(gorp.PostgresDialect); !ok {
|
||||
log.Fatal("only postgres backend supported for multi server configurations")
|
||||
}
|
||||
|
||||
if *dbMigrate {
|
||||
var sleep time.Duration
|
||||
|
|
|
@ -3,8 +3,6 @@ package connector
|
|||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
|
||||
"github.com/coreos/dex/repo"
|
||||
)
|
||||
|
||||
func ReadConfigs(r io.Reader) ([]ConnectorConfig, error) {
|
||||
|
@ -22,24 +20,3 @@ func ReadConfigs(r io.Reader) ([]ConnectorConfig, error) {
|
|||
}
|
||||
return cfgs, nil
|
||||
}
|
||||
|
||||
type memConnectorConfigRepo struct {
|
||||
configs []ConnectorConfig
|
||||
}
|
||||
|
||||
func NewConnectorConfigRepoFromConfigs(cfgs []ConnectorConfig) ConnectorConfigRepo {
|
||||
return &memConnectorConfigRepo{configs: cfgs}
|
||||
}
|
||||
|
||||
func (r *memConnectorConfigRepo) All() ([]ConnectorConfig, error) {
|
||||
return r.configs, nil
|
||||
}
|
||||
|
||||
func (r *memConnectorConfigRepo) GetConnectorByID(_ repo.Transaction, id string) (ConnectorConfig, error) {
|
||||
for _, cfg := range r.configs {
|
||||
if cfg.ConnectorID() == id {
|
||||
return cfg, nil
|
||||
}
|
||||
}
|
||||
return nil, ErrorNotFound
|
||||
}
|
||||
|
|
63
db/client.go
63
db/client.go
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/coreos/go-oidc/oidc"
|
||||
"github.com/go-gorp/gorp"
|
||||
"github.com/lib/pq"
|
||||
"github.com/mattn/go-sqlite3"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
|
@ -85,35 +86,47 @@ func (m *clientIdentityModel) ClientIdentity() (*oidc.ClientIdentity, error) {
|
|||
}
|
||||
|
||||
func NewClientIdentityRepo(dbm *gorp.DbMap) client.ClientIdentityRepo {
|
||||
return &clientIdentityRepo{dbMap: dbm}
|
||||
return newClientIdentityRepo(dbm)
|
||||
}
|
||||
|
||||
func newClientIdentityRepo(dbm *gorp.DbMap) *clientIdentityRepo {
|
||||
return &clientIdentityRepo{db: &db{dbm}}
|
||||
}
|
||||
|
||||
func NewClientIdentityRepoFromClients(dbm *gorp.DbMap, clients []oidc.ClientIdentity) (client.ClientIdentityRepo, error) {
|
||||
repo := NewClientIdentityRepo(dbm).(*clientIdentityRepo)
|
||||
repo := newClientIdentityRepo(dbm)
|
||||
tx, err := repo.begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
exec := repo.executor(tx)
|
||||
for _, c := range clients {
|
||||
dec, err := base64.URLEncoding.DecodeString(c.Credentials.Secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cm, err := newClientIdentityModel(c.Credentials.ID, dec, &c.Metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = repo.dbMap.Insert(cm)
|
||||
err = exec.Insert(cm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
type clientIdentityRepo struct {
|
||||
dbMap *gorp.DbMap
|
||||
*db
|
||||
}
|
||||
|
||||
func (r *clientIdentityRepo) Metadata(clientID string) (*oidc.ClientMetadata, error) {
|
||||
m, err := r.dbMap.Get(clientIdentityModel{}, clientID)
|
||||
m, err := r.executor(nil).Get(clientIdentityModel{}, clientID)
|
||||
if err == sql.ErrNoRows || m == nil {
|
||||
return nil, client.ErrorNotFound
|
||||
}
|
||||
|
@ -136,7 +149,7 @@ func (r *clientIdentityRepo) Metadata(clientID string) (*oidc.ClientMetadata, er
|
|||
}
|
||||
|
||||
func (r *clientIdentityRepo) IsDexAdmin(clientID string) (bool, error) {
|
||||
m, err := r.dbMap.Get(clientIdentityModel{}, clientID)
|
||||
m, err := r.executor(nil).Get(clientIdentityModel{}, clientID)
|
||||
if m == nil || err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -151,42 +164,35 @@ func (r *clientIdentityRepo) IsDexAdmin(clientID string) (bool, error) {
|
|||
}
|
||||
|
||||
func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error {
|
||||
tx, err := r.dbMap.Begin()
|
||||
tx, err := r.begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
exec := r.executor(tx)
|
||||
|
||||
m, err := r.dbMap.Get(clientIdentityModel{}, clientID)
|
||||
m, err := exec.Get(clientIdentityModel{}, clientID)
|
||||
if m == nil || err != nil {
|
||||
rollback(tx)
|
||||
return err
|
||||
}
|
||||
|
||||
cim, ok := m.(*clientIdentityModel)
|
||||
if !ok {
|
||||
rollback(tx)
|
||||
log.Errorf("expected clientIdentityModel but found %v", reflect.TypeOf(m))
|
||||
return errors.New("unrecognized model")
|
||||
}
|
||||
|
||||
cim.DexAdmin = isAdmin
|
||||
_, err = r.dbMap.Update(cim)
|
||||
_, err = exec.Update(cim)
|
||||
if err != nil {
|
||||
rollback(tx)
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
rollback(tx)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (r *clientIdentityRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) {
|
||||
m, err := r.dbMap.Get(clientIdentityModel{}, creds.ID)
|
||||
m, err := r.executor(nil).Get(clientIdentityModel{}, creds.ID)
|
||||
if m == nil || err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -222,10 +228,17 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err := r.dbMap.Insert(cim); err != nil {
|
||||
if perr, ok := err.(*pq.Error); ok && perr.Code == pgErrorCodeUniqueViolation {
|
||||
if err := r.executor(nil).Insert(cim); err != nil {
|
||||
switch sqlErr := err.(type) {
|
||||
case *pq.Error:
|
||||
if sqlErr.Code == pgErrorCodeUniqueViolation {
|
||||
err = errors.New("client ID already exists")
|
||||
}
|
||||
case *sqlite3.Error:
|
||||
if sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique {
|
||||
err = errors.New("client ID already exists")
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
@ -239,9 +252,9 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli
|
|||
}
|
||||
|
||||
func (r *clientIdentityRepo) All() ([]oidc.ClientIdentity, error) {
|
||||
qt := pq.QuoteIdentifier(clientIdentityTableName)
|
||||
qt := r.quote(clientIdentityTableName)
|
||||
q := fmt.Sprintf("SELECT * FROM %s", qt)
|
||||
objs, err := r.dbMap.Select(&clientIdentityModel{}, q)
|
||||
objs, err := r.executor(nil).Select(&clientIdentityModel{}, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
51
db/conn.go
51
db/conn.go
|
@ -4,13 +4,15 @@ import (
|
|||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"net/url"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
_ "github.com/lib/pq"
|
||||
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/repo"
|
||||
|
||||
// Import database drivers
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
type table struct {
|
||||
|
@ -43,22 +45,35 @@ type Config struct {
|
|||
}
|
||||
|
||||
func NewConnection(cfg Config) (*gorp.DbMap, error) {
|
||||
if !strings.HasPrefix(cfg.DSN, "postgres://") {
|
||||
return nil, errors.New("unrecognized database driver")
|
||||
u, err := url.Parse(cfg.DSN)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse DSN: %v", err)
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", cfg.DSN)
|
||||
var (
|
||||
db *sql.DB
|
||||
dialect gorp.Dialect
|
||||
)
|
||||
switch u.Scheme {
|
||||
case "postgres":
|
||||
db, err = sql.Open("postgres", cfg.DSN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db.SetMaxIdleConns(cfg.MaxIdleConnections)
|
||||
db.SetMaxOpenConns(cfg.MaxOpenConnections)
|
||||
|
||||
dbm := gorp.DbMap{
|
||||
Db: db,
|
||||
Dialect: gorp.PostgresDialect{},
|
||||
dialect = gorp.PostgresDialect{}
|
||||
case "sqlite3":
|
||||
db, err = sql.Open("sqlite3", u.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// NOTE(ericchiang): sqlite does NOT work with SetMaxIdleConns.
|
||||
dialect = gorp.SqliteDialect{}
|
||||
default:
|
||||
return nil, errors.New("unrecognized database driver")
|
||||
}
|
||||
|
||||
dbm := gorp.DbMap{Db: db, Dialect: dialect}
|
||||
|
||||
for _, t := range tables {
|
||||
tm := dbm.AddTableWithName(t.model, t.name).SetKeys(t.autoinc, t.pkey...)
|
||||
|
@ -70,7 +85,6 @@ func NewConnection(cfg Config) (*gorp.DbMap, error) {
|
|||
cm.SetUnique(true)
|
||||
}
|
||||
}
|
||||
|
||||
return &dbm, nil
|
||||
}
|
||||
|
||||
|
@ -80,9 +94,14 @@ func TransactionFactory(conn *gorp.DbMap) repo.TransactionFactory {
|
|||
}
|
||||
}
|
||||
|
||||
func rollback(tx *gorp.Transaction) {
|
||||
err := tx.Rollback()
|
||||
// NewMemDB creates a new in memory sqlite3 database.
|
||||
func NewMemDB() *gorp.DbMap {
|
||||
dbMap, err := NewConnection(Config{DSN: "sqlite3://:memory:"})
|
||||
if err != nil {
|
||||
log.Errorf("unable to rollback: %v", err)
|
||||
panic("Failed to create in memory database: " + err.Error())
|
||||
}
|
||||
if _, err := MigrateToLatest(dbMap); err != nil {
|
||||
panic("In memory database migration failed: " + err.Error())
|
||||
}
|
||||
return dbMap
|
||||
}
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/repo"
|
||||
|
@ -61,17 +60,17 @@ func (m *connectorConfigModel) ConnectorConfig() (connector.ConnectorConfig, err
|
|||
}
|
||||
|
||||
func NewConnectorConfigRepo(dbm *gorp.DbMap) *ConnectorConfigRepo {
|
||||
return &ConnectorConfigRepo{dbMap: dbm}
|
||||
return &ConnectorConfigRepo{&db{dbm}}
|
||||
}
|
||||
|
||||
type ConnectorConfigRepo struct {
|
||||
dbMap *gorp.DbMap
|
||||
*db
|
||||
}
|
||||
|
||||
func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
|
||||
qt := pq.QuoteIdentifier(connectorConfigTableName)
|
||||
qt := r.quote(connectorConfigTableName)
|
||||
q := fmt.Sprintf("SELECT * FROM %s", qt)
|
||||
objs, err := r.dbMap.Select(&connectorConfigModel{}, q)
|
||||
objs, err := r.executor(nil).Select(&connectorConfigModel{}, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -94,7 +93,7 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
|
|||
}
|
||||
|
||||
func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) {
|
||||
qt := pq.QuoteIdentifier(connectorConfigTableName)
|
||||
qt := r.quote(connectorConfigTableName)
|
||||
q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt)
|
||||
var c connectorConfigModel
|
||||
if err := r.executor(tx).SelectOne(&c, q, id); err != nil {
|
||||
|
@ -117,32 +116,22 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
|
|||
insert[i] = m
|
||||
}
|
||||
|
||||
tx, err := r.dbMap.Begin()
|
||||
tx, err := r.begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
exec := r.executor(tx)
|
||||
|
||||
qt := pq.QuoteIdentifier(connectorConfigTableName)
|
||||
qt := r.quote(connectorConfigTableName)
|
||||
q := fmt.Sprintf("DELETE FROM %s", qt)
|
||||
if _, err = r.dbMap.Exec(q); err != nil {
|
||||
if _, err = exec.Exec(q); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = r.dbMap.Insert(insert...); err != nil {
|
||||
if err = exec.Insert(insert...); err != nil {
|
||||
return fmt.Errorf("DB insert failed %#v: %v", insert, err)
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (r *ConnectorConfigRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
|
||||
if tx == nil {
|
||||
return r.dbMap
|
||||
}
|
||||
|
||||
gorpTx, ok := tx.(*gorp.Transaction)
|
||||
if !ok {
|
||||
panic("wrong kind of transaction passed to a DB repo")
|
||||
}
|
||||
return gorpTx
|
||||
}
|
||||
|
|
60
db/db.go
Normal file
60
db/db.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
// Package db provides SQL implementations of dex's storage interfaces.
|
||||
package db
|
||||
|
||||
import (
|
||||
"github.com/go-gorp/gorp"
|
||||
|
||||
"github.com/coreos/dex/db/translate"
|
||||
"github.com/coreos/dex/repo"
|
||||
)
|
||||
|
||||
// db is the connection type passed to repos.
|
||||
//
|
||||
// TODO(ericchiang): Eventually just return this instead of gorp.DbMap during Conn.
|
||||
// All actions should go through this type instead of dbMap.
|
||||
type db struct {
|
||||
dbMap *gorp.DbMap
|
||||
}
|
||||
|
||||
// executor returns a driver agnostic SQL executor.
|
||||
//
|
||||
// The expected flavor of all queries is the flavor used by github.com/lib/pq. All bind
|
||||
// parameters must be unique and in sequential order (e.g. $1, $2, ...).
|
||||
//
|
||||
// See github.com/coreos/dex/db/translate for details on the translation.
|
||||
//
|
||||
// If tx is nil, a non-transaction context is provided.
|
||||
func (conn *db) executor(tx repo.Transaction) gorp.SqlExecutor {
|
||||
var exec gorp.SqlExecutor
|
||||
if tx == nil {
|
||||
exec = conn.dbMap
|
||||
} else {
|
||||
gorpTx, ok := tx.(*gorp.Transaction)
|
||||
if !ok {
|
||||
panic("wrong kind of transaction passed to a DB repo")
|
||||
}
|
||||
|
||||
// Check if the underlying value of the pointer is nil.
|
||||
// This is not caught by the initial comparison (tx == nil).
|
||||
if gorpTx == nil {
|
||||
exec = conn.dbMap
|
||||
} else {
|
||||
exec = gorpTx
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := conn.dbMap.Dialect.(gorp.SqliteDialect); ok {
|
||||
exec = translate.NewTranslatingExecutor(exec, translate.PostgresToSQLite)
|
||||
}
|
||||
return exec
|
||||
}
|
||||
|
||||
// quote escapes a table name for a driver specific SQL query. quote uses the
|
||||
// gorp's package underlying quote logic and should NOT be used on untrusted input.
|
||||
func (conn *db) quote(tableName string) string {
|
||||
return conn.dbMap.Dialect.QuotedTableForQuery("", tableName)
|
||||
}
|
||||
|
||||
func (conn *db) begin() (repo.Transaction, error) {
|
||||
return conn.dbMap.Begin()
|
||||
}
|
23
db/key.go
23
db/key.go
|
@ -8,7 +8,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
"github.com/lib/pq"
|
||||
|
||||
pcrypto "github.com/coreos/dex/pkg/crypto"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
|
@ -99,7 +98,7 @@ func NewPrivateKeySetRepo(dbm *gorp.DbMap, useOldFormat bool, secrets ...[]byte)
|
|||
}
|
||||
|
||||
r := &PrivateKeySetRepo{
|
||||
dbMap: dbm,
|
||||
db: &db{dbm},
|
||||
useOldFormat: useOldFormat,
|
||||
secrets: secrets,
|
||||
}
|
||||
|
@ -108,17 +107,22 @@ func NewPrivateKeySetRepo(dbm *gorp.DbMap, useOldFormat bool, secrets ...[]byte)
|
|||
}
|
||||
|
||||
type PrivateKeySetRepo struct {
|
||||
dbMap *gorp.DbMap
|
||||
*db
|
||||
useOldFormat bool
|
||||
secrets [][]byte
|
||||
}
|
||||
|
||||
func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
|
||||
qt := pq.QuoteIdentifier(keyTableName)
|
||||
_, err := r.dbMap.Exec(fmt.Sprintf("DELETE FROM %s", qt))
|
||||
qt := r.quote(keyTableName)
|
||||
tx, err := r.begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
exec := r.executor(tx)
|
||||
if _, err := exec.Exec(fmt.Sprintf("DELETE FROM %s", qt)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pks, ok := ks.(*key.PrivateKeySet)
|
||||
if !ok {
|
||||
|
@ -148,12 +152,15 @@ func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
|
|||
}
|
||||
|
||||
b := &privateKeySetBlob{Value: v}
|
||||
return r.dbMap.Insert(b)
|
||||
if err := exec.Insert(b); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (r *PrivateKeySetRepo) Get() (key.KeySet, error) {
|
||||
qt := pq.QuoteIdentifier(keyTableName)
|
||||
objs, err := r.dbMap.Select(&privateKeySetBlob{}, fmt.Sprintf("SELECT * FROM %s", qt))
|
||||
qt := r.quote(keyTableName)
|
||||
objs, err := r.executor(nil).Select(&privateKeySetBlob{}, fmt.Sprintf("SELECT * FROM %s", qt))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -1,17 +1,16 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
"github.com/lib/pq"
|
||||
migrate "github.com/rubenv/sql-migrate"
|
||||
"github.com/rubenv/sql-migrate"
|
||||
|
||||
"github.com/coreos/dex/db/migrations"
|
||||
)
|
||||
|
||||
const (
|
||||
migrationDialect = "postgres"
|
||||
migrationTable = "dex_migrations"
|
||||
migrationDir = "db/migrations"
|
||||
)
|
||||
|
@ -21,32 +20,57 @@ func init() {
|
|||
}
|
||||
|
||||
func MigrateToLatest(dbMap *gorp.DbMap) (int, error) {
|
||||
source := getSource()
|
||||
|
||||
return migrate.Exec(dbMap.Db, migrationDialect, source, migrate.Up)
|
||||
source, dialect, err := migrationSource(dbMap)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return migrate.Exec(dbMap.Db, dialect, source, migrate.Up)
|
||||
}
|
||||
|
||||
func MigrateMaxMigrations(dbMap *gorp.DbMap, max int) (int, error) {
|
||||
source := getSource()
|
||||
|
||||
return migrate.ExecMax(dbMap.Db, migrationDialect, source, migrate.Up, max)
|
||||
source, dialect, err := migrationSource(dbMap)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return migrate.ExecMax(dbMap.Db, dialect, source, migrate.Up, max)
|
||||
}
|
||||
|
||||
func GetPlannedMigrations(dbMap *gorp.DbMap) ([]*migrate.PlannedMigration, error) {
|
||||
migrations, _, err := migrate.PlanMigration(dbMap.Db, migrationDialect, getSource(), migrate.Up, 0)
|
||||
source, dialect, err := migrationSource(dbMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
migrations, _, err := migrate.PlanMigration(dbMap.Db, dialect, source, migrate.Up, 0)
|
||||
return migrations, err
|
||||
}
|
||||
|
||||
func DropMigrationsTable(dbMap *gorp.DbMap) error {
|
||||
qt := pq.QuoteIdentifier(migrationTable)
|
||||
_, err := dbMap.Exec(fmt.Sprintf("drop table if exists %s ;", qt))
|
||||
qt := fmt.Sprintf("DROP TABLE IF EXISTS %s;", dbMap.Dialect.QuotedTableForQuery("", migrationTable))
|
||||
_, err := dbMap.Exec(qt)
|
||||
return err
|
||||
}
|
||||
|
||||
func getSource() migrate.MigrationSource {
|
||||
return &migrate.AssetMigrationSource{
|
||||
func migrationSource(dbMap *gorp.DbMap) (src migrate.MigrationSource, dialect string, err error) {
|
||||
switch dbMap.Dialect.(type) {
|
||||
case gorp.PostgresDialect:
|
||||
src = &migrate.AssetMigrationSource{
|
||||
Dir: migrationDir,
|
||||
Asset: migrations.Asset,
|
||||
AssetDir: migrations.AssetDir,
|
||||
}
|
||||
return src, "postgres", nil
|
||||
case gorp.SqliteDialect:
|
||||
src = &migrate.MemoryMigrationSource{
|
||||
Migrations: []*migrate.Migration{
|
||||
{
|
||||
Id: "dex.sql",
|
||||
Up: []string{sqlite3Migration},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return src, "sqlite3", nil
|
||||
default:
|
||||
return nil, "", errors.New("unsupported migration driver")
|
||||
}
|
||||
}
|
||||
|
|
73
db/migrate_sqlite3.go
Normal file
73
db/migrate_sqlite3.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package db
|
||||
|
||||
// SQLite3 is a test only database. There is only one migration because we do not support migrations.
|
||||
const sqlite3Migration = `
|
||||
CREATE TABLE authd_user (
|
||||
id text NOT NULL UNIQUE,
|
||||
email text,
|
||||
email_verified integer,
|
||||
display_name text,
|
||||
admin integer,
|
||||
created_at bigint,
|
||||
disabled integer
|
||||
);
|
||||
|
||||
CREATE TABLE client_identity (
|
||||
id text NOT NULL UNIQUE,
|
||||
secret blob,
|
||||
metadata text,
|
||||
dex_admin integer
|
||||
);
|
||||
|
||||
CREATE TABLE connector_config (
|
||||
id text NOT NULL UNIQUE,
|
||||
type text,
|
||||
config text
|
||||
);
|
||||
|
||||
CREATE TABLE key (
|
||||
value blob
|
||||
);
|
||||
|
||||
CREATE TABLE password_info (
|
||||
user_id text NOT NULL UNIQUE,
|
||||
password text,
|
||||
password_expires bigint
|
||||
);
|
||||
|
||||
CREATE TABLE refresh_token (
|
||||
id integer PRIMARY KEY,
|
||||
payload_hash blob,
|
||||
user_id text,
|
||||
client_id text
|
||||
);
|
||||
|
||||
CREATE TABLE remote_identity_mapping (
|
||||
connector_id text NOT NULL,
|
||||
user_id text,
|
||||
remote_id text NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE session (
|
||||
id text NOT NULL UNIQUE,
|
||||
state text,
|
||||
created_at bigint,
|
||||
expires_at bigint,
|
||||
client_id text,
|
||||
client_state text,
|
||||
redirect_url text,
|
||||
identity text,
|
||||
connector_id text,
|
||||
user_id text,
|
||||
register integer,
|
||||
nonce text,
|
||||
scope text
|
||||
);
|
||||
|
||||
CREATE TABLE session_key (
|
||||
key text NOT NULL,
|
||||
session_id text,
|
||||
expires_at bigint,
|
||||
stale integer
|
||||
);
|
||||
`
|
|
@ -105,7 +105,7 @@ func TestMigrateClientMetadata(t *testing.T) {
|
|||
id := strconv.Itoa(i)
|
||||
m, err := dbMap.Get(clientIdentityModel{}, id)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: failed to get model: %err", i, err)
|
||||
t.Errorf("case %d: failed to get model: %v", i, err)
|
||||
continue
|
||||
}
|
||||
cim, ok := m.(*clientIdentityModel)
|
||||
|
|
|
@ -5,10 +5,11 @@ import (
|
|||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/go-gorp/gorp"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -33,12 +34,22 @@ type passwordInfoModel struct {
|
|||
|
||||
func NewPasswordInfoRepo(dbm *gorp.DbMap) user.PasswordInfoRepo {
|
||||
return &passwordInfoRepo{
|
||||
dbMap: dbm,
|
||||
db: &db{dbm},
|
||||
}
|
||||
}
|
||||
|
||||
func NewPasswordInfoRepoFromPasswordInfos(dbm *gorp.DbMap, infos []user.PasswordInfo) (user.PasswordInfoRepo, error) {
|
||||
repo := NewPasswordInfoRepo(dbm)
|
||||
for _, info := range infos {
|
||||
if err := repo.Create(nil, info); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
type passwordInfoRepo struct {
|
||||
dbMap *gorp.DbMap
|
||||
*db
|
||||
}
|
||||
|
||||
func (r *passwordInfoRepo) Get(tx repo.Transaction, userID string) (user.PasswordInfo, error) {
|
||||
|
@ -89,18 +100,6 @@ func (r *passwordInfoRepo) Update(tx repo.Transaction, pw user.PasswordInfo) err
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *passwordInfoRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
|
||||
if tx == nil {
|
||||
return r.dbMap
|
||||
}
|
||||
|
||||
gorpTx, ok := tx.(*gorp.Transaction)
|
||||
if !ok {
|
||||
panic("wrong kind of transaction passed to a DB repo")
|
||||
}
|
||||
return gorpTx
|
||||
}
|
||||
|
||||
func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInfo, error) {
|
||||
ex := r.executor(tx)
|
||||
|
||||
|
|
|
@ -8,10 +8,12 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/refresh"
|
||||
"github.com/go-gorp/gorp"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/refresh"
|
||||
"github.com/coreos/dex/repo"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -28,7 +30,7 @@ func init() {
|
|||
}
|
||||
|
||||
type refreshTokenRepo struct {
|
||||
dbMap *gorp.DbMap
|
||||
*db
|
||||
tokenGenerator refresh.RefreshTokenGenerator
|
||||
}
|
||||
|
||||
|
@ -76,9 +78,13 @@ func checkTokenPayload(payloadHash, payload []byte) error {
|
|||
}
|
||||
|
||||
func NewRefreshTokenRepo(dbm *gorp.DbMap) refresh.RefreshTokenRepo {
|
||||
return NewRefreshTokenRepoWithGenerator(dbm, refresh.DefaultRefreshTokenGenerator)
|
||||
}
|
||||
|
||||
func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenGenerator) refresh.RefreshTokenRepo {
|
||||
return &refreshTokenRepo{
|
||||
dbMap: dbm,
|
||||
tokenGenerator: refresh.DefaultRefreshTokenGenerator,
|
||||
db: &db{dbm},
|
||||
tokenGenerator: gen,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -107,7 +113,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
|
|||
ClientID: clientID,
|
||||
}
|
||||
|
||||
if err := r.dbMap.Insert(record); err != nil {
|
||||
if err := r.executor(nil).Insert(record); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
|
@ -143,7 +149,13 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
record, err := r.get(nil, tokenID)
|
||||
tx, err := r.begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
exec := r.executor(tx)
|
||||
record, err := r.get(tx, tokenID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -156,7 +168,7 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
deleted, err := r.dbMap.Delete(record)
|
||||
deleted, err := exec.Delete(record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -164,17 +176,10 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error {
|
|||
return refresh.ErrorInvalidToken
|
||||
}
|
||||
|
||||
return nil
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (r *refreshTokenRepo) executor(tx *gorp.Transaction) gorp.SqlExecutor {
|
||||
if tx == nil {
|
||||
return r.dbMap
|
||||
}
|
||||
return tx
|
||||
}
|
||||
|
||||
func (r *refreshTokenRepo) get(tx *gorp.Transaction, tokenID int64) (*refreshTokenModel, error) {
|
||||
func (r *refreshTokenRepo) get(tx repo.Transaction, tokenID int64) (*refreshTokenModel, error) {
|
||||
ex := r.executor(tx)
|
||||
result, err := ex.Get(refreshTokenModel{}, tokenID)
|
||||
if err != nil {
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
|
||||
"github.com/go-gorp/gorp"
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/session"
|
||||
|
@ -124,16 +123,16 @@ func NewSessionRepo(dbm *gorp.DbMap) *SessionRepo {
|
|||
}
|
||||
|
||||
func NewSessionRepoWithClock(dbm *gorp.DbMap, clock clockwork.Clock) *SessionRepo {
|
||||
return &SessionRepo{dbMap: dbm, clock: clock}
|
||||
return &SessionRepo{db: &db{dbm}, clock: clock}
|
||||
}
|
||||
|
||||
type SessionRepo struct {
|
||||
dbMap *gorp.DbMap
|
||||
*db
|
||||
clock clockwork.Clock
|
||||
}
|
||||
|
||||
func (r *SessionRepo) Get(sessionID string) (*session.Session, error) {
|
||||
m, err := r.dbMap.Get(sessionModel{}, sessionID)
|
||||
m, err := r.executor(nil).Get(sessionModel{}, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -164,7 +163,7 @@ func (r *SessionRepo) Create(s session.Session) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.dbMap.Insert(sm)
|
||||
return r.executor(nil).Insert(sm)
|
||||
}
|
||||
|
||||
func (r *SessionRepo) Update(s session.Session) error {
|
||||
|
@ -172,7 +171,7 @@ func (r *SessionRepo) Update(s session.Session) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n, err := r.dbMap.Update(sm)
|
||||
n, err := r.executor(nil).Update(sm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -183,9 +182,9 @@ func (r *SessionRepo) Update(s session.Session) error {
|
|||
}
|
||||
|
||||
func (r *SessionRepo) purge() error {
|
||||
qt := pq.QuoteIdentifier(sessionTableName)
|
||||
qt := r.quote(sessionTableName)
|
||||
q := fmt.Sprintf("DELETE FROM %s WHERE expires_at < $1 OR state = $2", qt)
|
||||
res, err := r.dbMap.Exec(q, r.clock.Now().Unix(), string(session.SessionStateDead))
|
||||
res, err := r.executor(nil).Exec(q, r.clock.Now().Unix(), string(session.SessionStateDead))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
|
||||
"github.com/go-gorp/gorp"
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/session"
|
||||
|
@ -39,11 +38,11 @@ func NewSessionKeyRepo(dbm *gorp.DbMap) *SessionKeyRepo {
|
|||
}
|
||||
|
||||
func NewSessionKeyRepoWithClock(dbm *gorp.DbMap, clock clockwork.Clock) *SessionKeyRepo {
|
||||
return &SessionKeyRepo{dbMap: dbm, clock: clock}
|
||||
return &SessionKeyRepo{db: &db{dbm}, clock: clock}
|
||||
}
|
||||
|
||||
type SessionKeyRepo struct {
|
||||
dbMap *gorp.DbMap
|
||||
*db
|
||||
clock clockwork.Clock
|
||||
}
|
||||
|
||||
|
@ -54,11 +53,11 @@ func (r *SessionKeyRepo) Push(sk session.SessionKey, exp time.Duration) error {
|
|||
ExpiresAt: r.clock.Now().Unix() + int64(exp.Seconds()),
|
||||
Stale: false,
|
||||
}
|
||||
return r.dbMap.Insert(skm)
|
||||
return r.executor(nil).Insert(skm)
|
||||
}
|
||||
|
||||
func (r *SessionKeyRepo) Pop(key string) (string, error) {
|
||||
m, err := r.dbMap.Get(sessionKeyModel{}, key)
|
||||
m, err := r.executor(nil).Get(sessionKeyModel{}, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -77,9 +76,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) {
|
|||
return "", errors.New("invalid session key")
|
||||
}
|
||||
|
||||
qt := pq.QuoteIdentifier(sessionKeyTableName)
|
||||
qt := r.quote(sessionKeyTableName)
|
||||
q := fmt.Sprintf("UPDATE %s SET stale=$1 WHERE key=$2 AND stale=$3", qt)
|
||||
res, err := r.dbMap.Exec(q, true, key, false)
|
||||
res, err := r.executor(nil).Exec(q, true, key, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -95,9 +94,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) {
|
|||
}
|
||||
|
||||
func (r *SessionKeyRepo) purge() error {
|
||||
qt := pq.QuoteIdentifier(sessionKeyTableName)
|
||||
qt := r.quote(sessionKeyTableName)
|
||||
q := fmt.Sprintf("DELETE FROM %s WHERE stale = $1 OR expires_at < $2", qt)
|
||||
res, err := r.dbMap.Exec(q, true, r.clock.Now().Unix())
|
||||
res, err := r.executor(nil).Exec(q, true, r.clock.Now().Unix())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
72
db/translate/translate.go
Normal file
72
db/translate/translate.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
/*
|
||||
Package translate implements translation of driver specific SQL queries.
|
||||
*/
|
||||
package translate
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"regexp"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
)
|
||||
|
||||
var (
|
||||
bindRegexp = regexp.MustCompile(`\$\d+`)
|
||||
trueRegexp = regexp.MustCompile(`\btrue\b`)
|
||||
)
|
||||
|
||||
// PostgresToSQLite translates github.com/lib/pq flavored SQL quries to github.com/mattn/go-sqlite3's flavor.
|
||||
//
|
||||
// It assumes that possitional bind arguements ($1, $2, etc.) are unqiue and in sequential order.
|
||||
func PostgresToSQLite(query string) string {
|
||||
query = bindRegexp.ReplaceAllString(query, "?")
|
||||
query = trueRegexp.ReplaceAllString(query, "1")
|
||||
return query
|
||||
}
|
||||
|
||||
// NewTranslatingExecutor returns an executor wrapping the existing executor. All query strings passed to
|
||||
// the executor will be run through the translate function before begin passed to the driver.
|
||||
func NewTranslatingExecutor(exec gorp.SqlExecutor, translate func(string) string) gorp.SqlExecutor {
|
||||
return &executor{exec, translate}
|
||||
}
|
||||
|
||||
type executor struct {
|
||||
gorp.SqlExecutor
|
||||
Translate func(string) string
|
||||
}
|
||||
|
||||
func (e *executor) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
return e.SqlExecutor.Exec(e.Translate(query), args...)
|
||||
}
|
||||
|
||||
func (e *executor) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
|
||||
return e.SqlExecutor.Select(i, e.Translate(query), args...)
|
||||
}
|
||||
|
||||
func (e *executor) SelectInt(query string, args ...interface{}) (int64, error) {
|
||||
return e.SqlExecutor.SelectInt(e.Translate(query), args...)
|
||||
}
|
||||
|
||||
func (e *executor) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) {
|
||||
return e.SqlExecutor.SelectNullInt(e.Translate(query), args...)
|
||||
}
|
||||
|
||||
func (e *executor) SelectFloat(query string, args ...interface{}) (float64, error) {
|
||||
return e.SqlExecutor.SelectFloat(e.Translate(query), args...)
|
||||
}
|
||||
|
||||
func (e *executor) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) {
|
||||
return e.SqlExecutor.SelectNullFloat(e.Translate(query), args...)
|
||||
}
|
||||
|
||||
func (e *executor) SelectStr(query string, args ...interface{}) (string, error) {
|
||||
return e.SqlExecutor.SelectStr(e.Translate(query), args...)
|
||||
}
|
||||
|
||||
func (e *executor) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) {
|
||||
return e.SqlExecutor.SelectNullStr(e.Translate(query), args...)
|
||||
}
|
||||
|
||||
func (e *executor) SelectOne(holder interface{}, query string, args ...interface{}) error {
|
||||
return e.SqlExecutor.SelectOne(holder, e.Translate(query), args...)
|
||||
}
|
28
db/translate/translate_test.go
Normal file
28
db/translate/translate_test.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package translate
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestPostgresToSQLite(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
query string
|
||||
want string
|
||||
}{
|
||||
{"SELECT * FROM foo", "SELECT * FROM foo"},
|
||||
{"SELECT * FROM %s", "SELECT * FROM %s"},
|
||||
{"SELECT * FROM foo WHERE is_admin=true", "SELECT * FROM foo WHERE is_admin=1"},
|
||||
{"SELECT * FROM foo WHERE is_admin=true;", "SELECT * FROM foo WHERE is_admin=1;"},
|
||||
{"SELECT * FROM foo WHERE is_admin=$10", "SELECT * FROM foo WHERE is_admin=?"},
|
||||
{"SELECT * FROM foo WHERE is_admin=$10;", "SELECT * FROM foo WHERE is_admin=?;"},
|
||||
{"SELECT * FROM foo WHERE name=$1 AND is_admin=$2;", "SELECT * FROM foo WHERE name=? AND is_admin=?;"},
|
||||
{"$1", "?"},
|
||||
{"$", "$"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := PostgresToSQLite(tt.query)
|
||||
if got != tt.want {
|
||||
t.Errorf("PostgresToSQLite(%q): want=%q, got=%q", tt.query, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
39
db/user.go
39
db/user.go
|
@ -8,7 +8,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/repo"
|
||||
|
@ -42,7 +41,7 @@ func init() {
|
|||
|
||||
func NewUserRepo(dbm *gorp.DbMap) user.UserRepo {
|
||||
return &userRepo{
|
||||
dbMap: dbm,
|
||||
db: &db{dbm},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -53,7 +52,7 @@ func NewUserRepoFromUsers(dbm *gorp.DbMap, us []user.UserWithRemoteIdentities) (
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = repo.dbMap.Insert(um)
|
||||
err = repo.executor(nil).Insert(um)
|
||||
for _, ri := range u.RemoteIdentities {
|
||||
err = repo.AddRemoteIdentity(nil, u.User.ID, ri)
|
||||
if err != nil {
|
||||
|
@ -65,7 +64,7 @@ func NewUserRepoFromUsers(dbm *gorp.DbMap, us []user.UserWithRemoteIdentities) (
|
|||
}
|
||||
|
||||
type userRepo struct {
|
||||
dbMap *gorp.DbMap
|
||||
*db
|
||||
}
|
||||
|
||||
func (r *userRepo) Get(tx repo.Transaction, userID string) (user.User, error) {
|
||||
|
@ -107,9 +106,9 @@ func (r *userRepo) Disable(tx repo.Transaction, userID string, disable bool) err
|
|||
return user.ErrorInvalidID
|
||||
}
|
||||
|
||||
qt := pq.QuoteIdentifier(userTableName)
|
||||
qt := r.quote(userTableName)
|
||||
ex := r.executor(tx)
|
||||
result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $2 WHERE id = $1", qt), userID, disable)
|
||||
result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $1 WHERE id = $2;", qt), disable, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -241,9 +240,8 @@ func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]us
|
|||
return nil, user.ErrorInvalidID
|
||||
}
|
||||
|
||||
qt := pq.QuoteIdentifier(remoteIdentityMappingTableName)
|
||||
rims, err := ex.Select(&remoteIdentityMappingModel{},
|
||||
fmt.Sprintf("select * from %s where user_id = $1", qt), userID)
|
||||
qt := r.quote(remoteIdentityMappingTableName)
|
||||
rims, err := ex.Select(&remoteIdentityMappingModel{}, fmt.Sprintf("SELECT * FROM %s WHERE user_id = $1", qt), userID)
|
||||
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
|
@ -273,9 +271,9 @@ func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]us
|
|||
}
|
||||
|
||||
func (r *userRepo) GetAdminCount(tx repo.Transaction) (int, error) {
|
||||
qt := pq.QuoteIdentifier(userTableName)
|
||||
qt := r.quote(userTableName)
|
||||
ex := r.executor(tx)
|
||||
i, err := ex.SelectInt(fmt.Sprintf("SELECT count(*) FROM %s where admin=true", qt))
|
||||
i, err := ex.SelectInt(fmt.Sprintf("SELECT count(*) FROM %s WHERE admin=true;", qt))
|
||||
return int(i), err
|
||||
}
|
||||
|
||||
|
@ -290,12 +288,11 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults
|
|||
}
|
||||
ex := r.executor(tx)
|
||||
|
||||
qt := pq.QuoteIdentifier(userTableName)
|
||||
qt := r.quote(userTableName)
|
||||
|
||||
// Ask for one more than needed so we know if there's more results, and
|
||||
// hence, whether a nextPageToken is necessary.
|
||||
ums, err := ex.Select(&userModel{},
|
||||
fmt.Sprintf("SELECT * FROM %s ORDER BY email LIMIT $1 OFFSET $2 ", qt), maxResults+1, offset)
|
||||
ums, err := ex.Select(&userModel{}, fmt.Sprintf("SELECT * FROM %s ORDER BY email LIMIT $1 OFFSET $2", qt), maxResults+1, offset)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
@ -338,18 +335,6 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults
|
|||
|
||||
}
|
||||
|
||||
func (r *userRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
|
||||
if tx == nil {
|
||||
return r.dbMap
|
||||
}
|
||||
|
||||
gorpTx, ok := tx.(*gorp.Transaction)
|
||||
if !ok {
|
||||
panic("wrong kind of transaction passed to a DB repo")
|
||||
}
|
||||
return gorpTx
|
||||
}
|
||||
|
||||
func (r *userRepo) insert(tx repo.Transaction, usr user.User) error {
|
||||
ex := r.executor(tx)
|
||||
um, err := newUserModel(&usr)
|
||||
|
@ -412,7 +397,7 @@ func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.Remot
|
|||
}
|
||||
|
||||
func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, error) {
|
||||
qt := pq.QuoteIdentifier(userTableName)
|
||||
qt := r.quote(userTableName)
|
||||
ex := r.executor(tx)
|
||||
var um userModel
|
||||
err := ex.SelectOne(&um, fmt.Sprintf("select * from %s where email = $1", qt), email)
|
||||
|
|
10
env
Normal file
10
env
Normal file
|
@ -0,0 +1,10 @@
|
|||
export GOPATH=${PWD}/Godeps/_workspace
|
||||
export GOBIN=${PWD}/bin
|
||||
|
||||
rm -rf $GOPATH/src/github.com/coreos/dex
|
||||
mkdir -p $GOPATH/src/github.com/coreos/
|
||||
|
||||
# Only attempt to link dex into godeps if it isn't already there
|
||||
[ -d $GOPATH/src/github.com/coreos/dex ] || ln -s ${PWD} $GOPATH/src/github.com/coreos/dex
|
||||
|
||||
LD_FLAGS="-X main.version=$(./git-version)"
|
|
@ -36,17 +36,16 @@ func connect(t *testing.T) *gorp.DbMap {
|
|||
if err != nil {
|
||||
t.Fatalf("Unable to connect to database: %v", err)
|
||||
}
|
||||
|
||||
if err = c.DropTablesIfExists(); err != nil {
|
||||
t.Fatalf("Unable to drop database tables: %v", err)
|
||||
}
|
||||
|
||||
if err = db.DropMigrationsTable(c); err != nil {
|
||||
panic(fmt.Sprintf("Unable to drop migration table: %v", err))
|
||||
t.Fatalf("Unable to drop migration table: %v", err)
|
||||
}
|
||||
|
||||
if _, err = db.MigrateToLatest(c); err != nil {
|
||||
panic(fmt.Sprintf("Unable to migrate: %v", err))
|
||||
t.Fatalf("Unable to migrate: %v", err)
|
||||
}
|
||||
|
||||
return c
|
||||
|
@ -157,12 +156,13 @@ func TestDBPrivateKeySetRepoSetGet(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
setRepo, err := db.NewPrivateKeySetRepo(connect(t), false, tt.setSecrets...)
|
||||
dbMap := connect(t)
|
||||
setRepo, err := db.NewPrivateKeySetRepo(dbMap, false, tt.setSecrets...)
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
getRepo, err := db.NewPrivateKeySetRepo(connect(t), false, tt.getSecrets...)
|
||||
getRepo, err := db.NewPrivateKeySetRepo(dbMap, false, tt.getSecrets...)
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
@ -377,9 +377,24 @@ func TestDBRefreshRepoCreate(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
_, err := r.Create(tt.userID, tt.clientID)
|
||||
if err != tt.err {
|
||||
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
|
||||
token, err := r.Create(tt.userID, tt.clientID)
|
||||
if err != nil {
|
||||
if tt.err == nil {
|
||||
t.Errorf("case %d: create failed: %v", i, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if tt.err != nil {
|
||||
t.Errorf("case %d: expected error, didn't get one", i)
|
||||
continue
|
||||
}
|
||||
userID, err := r.Verify(tt.clientID, token)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: failed to verify good token: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if userID != tt.userID {
|
||||
t.Errorf("case %d: want userID=%s, got userID=%s", i, tt.userID, userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,6 +28,9 @@ var connConfigExample = []byte(`[
|
|||
]`)
|
||||
|
||||
func TestDexctlCommands(t *testing.T) {
|
||||
if strings.HasPrefix(dsn, "sqlite3://") {
|
||||
t.Skip("only test dexctl conmand with postgres")
|
||||
}
|
||||
tempFile, err := ioutil.TempFile("", "dexctl_functional_tests_")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
"gopkg.in/ldap.v2"
|
||||
|
@ -93,13 +94,17 @@ func TestConnectorLDAPConnectFail(t *testing.T) {
|
|||
|
||||
templates := template.New(connector.LDAPLoginPageTemplateName)
|
||||
|
||||
ccr := connector.NewConnectorConfigRepoFromConfigs(
|
||||
ccr := db.NewConnectorConfigRepo(db.NewMemDB())
|
||||
err := ccr.Set(
|
||||
[]connector.ConnectorConfig{&connector.LDAPConnectorConfig{
|
||||
ID: "ldap",
|
||||
ServerHost: ldapHost,
|
||||
ServerPort: ldapPort + 1,
|
||||
}},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cc, err := ccr.GetConnectorByID(tx, "ldap")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -121,13 +126,17 @@ func TestConnectorLDAPConnectSuccess(t *testing.T) {
|
|||
|
||||
templates := template.New(connector.LDAPLoginPageTemplateName)
|
||||
|
||||
ccr := connector.NewConnectorConfigRepoFromConfigs(
|
||||
ccr := db.NewConnectorConfigRepo(db.NewMemDB())
|
||||
err := ccr.Set(
|
||||
[]connector.ConnectorConfig{&connector.LDAPConnectorConfig{
|
||||
ID: "ldap",
|
||||
ServerHost: ldapHost,
|
||||
ServerPort: ldapPort,
|
||||
}},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cc, err := ccr.GetConnectorByID(tx, "ldap")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -149,7 +158,8 @@ func TestConnectorLDAPcaFilecertFileConnectTLS(t *testing.T) {
|
|||
|
||||
templates := template.New(connector.LDAPLoginPageTemplateName)
|
||||
|
||||
ccr := connector.NewConnectorConfigRepoFromConfigs(
|
||||
ccr := db.NewConnectorConfigRepo(db.NewMemDB())
|
||||
err := ccr.Set(
|
||||
[]connector.ConnectorConfig{&connector.LDAPConnectorConfig{
|
||||
ID: "ldap",
|
||||
ServerHost: ldapHost,
|
||||
|
@ -160,6 +170,9 @@ func TestConnectorLDAPcaFilecertFileConnectTLS(t *testing.T) {
|
|||
CaFile: "/tmp/openldap-ca.pem",
|
||||
}},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cc, err := ccr.GetConnectorByID(tx, "ldap")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -181,7 +194,8 @@ func TestConnectorLDAPcaFilecertFileConnectSSL(t *testing.T) {
|
|||
|
||||
templates := template.New(connector.LDAPLoginPageTemplateName)
|
||||
|
||||
ccr := connector.NewConnectorConfigRepoFromConfigs(
|
||||
ccr := db.NewConnectorConfigRepo(db.NewMemDB())
|
||||
err := ccr.Set(
|
||||
[]connector.ConnectorConfig{&connector.LDAPConnectorConfig{
|
||||
ID: "ldap",
|
||||
ServerHost: ldapHost,
|
||||
|
@ -192,6 +206,9 @@ func TestConnectorLDAPcaFilecertFileConnectSSL(t *testing.T) {
|
|||
CaFile: "/tmp/openldap-ca.pem",
|
||||
}},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cc, err := ccr.GetConnectorByID(tx, "ldap")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
|
@ -1,25 +1,24 @@
|
|||
package repo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"encoding/base64"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
"github.com/go-gorp/gorp"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/db"
|
||||
)
|
||||
|
||||
var makeTestClientIdentityRepoFromClients func(clients []oidc.ClientIdentity) client.ClientIdentityRepo
|
||||
|
||||
var (
|
||||
testClients = []oidc.ClientIdentity{
|
||||
oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "client1",
|
||||
Secret: "secret-1",
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secret-1")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
|
@ -33,7 +32,7 @@ var (
|
|||
oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "client2",
|
||||
Secret: "secret-2",
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
|
@ -47,34 +46,19 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
func newClientIdentityRepo(t *testing.T) client.ClientIdentityRepo {
|
||||
dsn := os.Getenv("DEX_TEST_DSN")
|
||||
var dbMap *gorp.DbMap
|
||||
if dsn == "" {
|
||||
makeTestClientIdentityRepoFromClients = makeTestClientIdentityRepoMem
|
||||
dbMap = db.NewMemDB()
|
||||
} else {
|
||||
makeTestClientIdentityRepoFromClients = makeTestClientIdentityRepoDB(dsn)
|
||||
dbMap = connect(t)
|
||||
}
|
||||
}
|
||||
|
||||
func makeTestClientIdentityRepoMem(clients []oidc.ClientIdentity) client.ClientIdentityRepo {
|
||||
return client.NewClientIdentityRepo(clients)
|
||||
}
|
||||
|
||||
func makeTestClientIdentityRepoDB(dsn string) func([]oidc.ClientIdentity) client.ClientIdentityRepo {
|
||||
return func(clients []oidc.ClientIdentity) client.ClientIdentityRepo {
|
||||
c := initDB(dsn)
|
||||
|
||||
repo, err := db.NewClientIdentityRepoFromClients(c, clients)
|
||||
repo, err := db.NewClientIdentityRepoFromClients(dbMap, testClients)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Unable to add clients: %v", err))
|
||||
t.Fatalf("failed to create client repo from clients: %v", err)
|
||||
}
|
||||
return repo
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func makeTestClientIdentityRepo() client.ClientIdentityRepo {
|
||||
return makeTestClientIdentityRepoFromClients(testClients)
|
||||
}
|
||||
|
||||
func TestGetSetAdminClient(t *testing.T) {
|
||||
|
@ -113,12 +97,14 @@ func TestGetSetAdminClient(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
Tests:
|
||||
for i, tt := range tests {
|
||||
repo := makeTestClientIdentityRepo()
|
||||
repo := newClientIdentityRepo(t)
|
||||
for _, cid := range startAdmins {
|
||||
err := repo.SetDexAdmin(cid, true)
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||
t.Errorf("case %d: failed to set dex admin: %v", i, err)
|
||||
continue Tests
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -130,7 +116,7 @@ func TestGetSetAdminClient(t *testing.T) {
|
|||
continue
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
if gotAdmin != tt.wantAdmin {
|
||||
t.Errorf("case %d: want=%v, got=%v", i, tt.wantAdmin, gotAdmin)
|
||||
|
@ -138,12 +124,12 @@ func TestGetSetAdminClient(t *testing.T) {
|
|||
|
||||
err = repo.SetDexAdmin(tt.cid, tt.setAdmin)
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
gotAdmin, err = repo.IsDexAdmin(tt.cid)
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
if gotAdmin != tt.setAdmin {
|
||||
t.Errorf("case %d: want=%v, got=%v", i, tt.setAdmin, gotAdmin)
|
||||
|
|
|
@ -1,36 +1,27 @@
|
|||
package repo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
)
|
||||
|
||||
type connectorConfigRepoFactory func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo
|
||||
|
||||
var makeTestConnectorConfigRepoFromConfigs connectorConfigRepoFactory
|
||||
|
||||
func init() {
|
||||
if dsn := os.Getenv("DEX_TEST_DSN"); dsn == "" {
|
||||
makeTestConnectorConfigRepoFromConfigs = connector.NewConnectorConfigRepoFromConfigs
|
||||
func newConnectorConfigRepo(t *testing.T, configs []connector.ConnectorConfig) connector.ConnectorConfigRepo {
|
||||
var dbMap *gorp.DbMap
|
||||
if os.Getenv("DEX_TEST_DSN") == "" {
|
||||
dbMap = db.NewMemDB()
|
||||
} else {
|
||||
makeTestConnectorConfigRepoFromConfigs = makeTestConnectorConfigRepoMem(dsn)
|
||||
dbMap = connect(t)
|
||||
}
|
||||
}
|
||||
|
||||
func makeTestConnectorConfigRepoMem(dsn string) connectorConfigRepoFactory {
|
||||
return func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo {
|
||||
dbMap := initDB(dsn)
|
||||
|
||||
repo := db.NewConnectorConfigRepo(dbMap)
|
||||
if err := repo.Set(cfgs); err != nil {
|
||||
panic(fmt.Sprintf("Unable to set connector configs: %v", err))
|
||||
if err := repo.Set(configs); err != nil {
|
||||
t.Fatalf("Unable to set connector configs: %v", err)
|
||||
}
|
||||
return repo
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectorConfigRepoGetByID(t *testing.T) {
|
||||
|
@ -63,7 +54,7 @@ func TestConnectorConfigRepoGetByID(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
repo := makeTestConnectorConfigRepoFromConfigs(tt.cfgs)
|
||||
repo := newConnectorConfigRepo(t, tt.cfgs)
|
||||
if _, err := repo.GetConnectorByID(nil, tt.id); err != tt.err {
|
||||
t.Errorf("case %d: want=%v, got=%v", i, tt.err, err)
|
||||
}
|
||||
|
|
|
@ -1,19 +1,17 @@
|
|||
package repo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/user"
|
||||
)
|
||||
|
||||
var makeTestPasswordInfoRepo func() user.PasswordInfoRepo
|
||||
|
||||
var (
|
||||
testPWs = []user.PasswordInfo{
|
||||
{
|
||||
|
@ -23,30 +21,18 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
dsn := os.Getenv("DEX_TEST_DSN")
|
||||
if dsn == "" {
|
||||
makeTestPasswordInfoRepo = makeTestPasswordInfoRepoMem
|
||||
func newPasswordInfoRepo(t *testing.T) user.PasswordInfoRepo {
|
||||
var dbMap *gorp.DbMap
|
||||
if os.Getenv("DEX_TEST_DSN") == "" {
|
||||
dbMap = db.NewMemDB()
|
||||
} else {
|
||||
makeTestPasswordInfoRepo = makeTestPasswordInfoRepoDB(dsn)
|
||||
dbMap = connect(t)
|
||||
}
|
||||
}
|
||||
|
||||
func makeTestPasswordInfoRepoMem() user.PasswordInfoRepo {
|
||||
return user.NewPasswordInfoRepoFromPasswordInfos(testPWs)
|
||||
}
|
||||
|
||||
func makeTestPasswordInfoRepoDB(dsn string) func() user.PasswordInfoRepo {
|
||||
return func() user.PasswordInfoRepo {
|
||||
c := initDB(dsn)
|
||||
|
||||
repo := db.NewPasswordInfoRepo(c)
|
||||
err := user.LoadPasswordInfos(repo, testPWs)
|
||||
repo, err := db.NewPasswordInfoRepoFromPasswordInfos(dbMap, testPWs)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Unable to add passwordInfos: %v", err))
|
||||
t.Fatalf("Unable to add password infos: %v", err)
|
||||
}
|
||||
return repo
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreatePasswordInfo(t *testing.T) {
|
||||
|
@ -87,7 +73,7 @@ func TestCreatePasswordInfo(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
repo := makeTestPasswordInfoRepo()
|
||||
repo := newPasswordInfoRepo(t)
|
||||
err := repo.Create(nil, tt.pw)
|
||||
if tt.err != nil {
|
||||
if err != tt.err {
|
||||
|
@ -142,7 +128,7 @@ func TestUpdatePasswordInfo(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
repo := makeTestPasswordInfoRepo()
|
||||
repo := newPasswordInfoRepo(t)
|
||||
err := repo.Update(nil, tt.pw)
|
||||
if tt.err != nil {
|
||||
if err != tt.err {
|
||||
|
|
|
@ -12,48 +12,26 @@ import (
|
|||
"github.com/coreos/dex/session"
|
||||
)
|
||||
|
||||
var makeTestSessionRepo func() (session.SessionRepo, clockwork.FakeClock)
|
||||
var makeTestSessionKeyRepo func() (session.SessionKeyRepo, clockwork.FakeClock)
|
||||
|
||||
func init() {
|
||||
dsn := os.Getenv("DEX_TEST_DSN")
|
||||
if dsn == "" {
|
||||
makeTestSessionRepo = makeTestSessionRepoMem
|
||||
makeTestSessionKeyRepo = makeTestSessionKeyRepoMem
|
||||
} else {
|
||||
makeTestSessionRepo = makeTestSessionRepoDB(dsn)
|
||||
makeTestSessionKeyRepo = makeTestSessionKeyRepoDB(dsn)
|
||||
func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) {
|
||||
clock := clockwork.NewFakeClock()
|
||||
if os.Getenv("DEX_TEST_DSN") == "" {
|
||||
return db.NewSessionRepoWithClock(db.NewMemDB(), clock), clock
|
||||
}
|
||||
dbMap := connect(t)
|
||||
return db.NewSessionRepoWithClock(dbMap, clock), clock
|
||||
}
|
||||
|
||||
func makeTestSessionRepoMem() (session.SessionRepo, clockwork.FakeClock) {
|
||||
fc := clockwork.NewFakeClock()
|
||||
return session.NewSessionRepoWithClock(fc), fc
|
||||
}
|
||||
|
||||
func makeTestSessionRepoDB(dsn string) func() (session.SessionRepo, clockwork.FakeClock) {
|
||||
return func() (session.SessionRepo, clockwork.FakeClock) {
|
||||
c := initDB(dsn)
|
||||
fc := clockwork.NewFakeClock()
|
||||
return db.NewSessionRepoWithClock(c, fc), fc
|
||||
}
|
||||
}
|
||||
|
||||
func makeTestSessionKeyRepoMem() (session.SessionKeyRepo, clockwork.FakeClock) {
|
||||
fc := clockwork.NewFakeClock()
|
||||
return session.NewSessionKeyRepoWithClock(fc), fc
|
||||
}
|
||||
|
||||
func makeTestSessionKeyRepoDB(dsn string) func() (session.SessionKeyRepo, clockwork.FakeClock) {
|
||||
return func() (session.SessionKeyRepo, clockwork.FakeClock) {
|
||||
c := initDB(dsn)
|
||||
fc := clockwork.NewFakeClock()
|
||||
return db.NewSessionKeyRepoWithClock(c, fc), fc
|
||||
func newSessionKeyRepo(t *testing.T) (session.SessionKeyRepo, clockwork.FakeClock) {
|
||||
clock := clockwork.NewFakeClock()
|
||||
if os.Getenv("DEX_TEST_DSN") == "" {
|
||||
return db.NewSessionKeyRepoWithClock(db.NewMemDB(), clock), clock
|
||||
}
|
||||
dbMap := connect(t)
|
||||
return db.NewSessionKeyRepoWithClock(dbMap, clock), clock
|
||||
}
|
||||
|
||||
func TestSessionKeyRepoPopNoExist(t *testing.T) {
|
||||
r, _ := makeTestSessionKeyRepo()
|
||||
r, _ := newSessionKeyRepo(t)
|
||||
|
||||
_, err := r.Pop("123")
|
||||
if err == nil {
|
||||
|
@ -62,7 +40,7 @@ func TestSessionKeyRepoPopNoExist(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionKeyRepoPushPop(t *testing.T) {
|
||||
r, _ := makeTestSessionKeyRepo()
|
||||
r, _ := newSessionKeyRepo(t)
|
||||
|
||||
key := "123"
|
||||
sessionID := "456"
|
||||
|
@ -80,7 +58,7 @@ func TestSessionKeyRepoPushPop(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionKeyRepoExpired(t *testing.T) {
|
||||
r, fc := makeTestSessionKeyRepo()
|
||||
r, fc := newSessionKeyRepo(t)
|
||||
|
||||
key := "123"
|
||||
sessionID := "456"
|
||||
|
@ -96,7 +74,7 @@ func TestSessionKeyRepoExpired(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionRepoGetNoExist(t *testing.T) {
|
||||
r, _ := makeTestSessionRepo()
|
||||
r, _ := newSessionRepo(t)
|
||||
|
||||
ses, err := r.Get("123")
|
||||
if ses != nil {
|
||||
|
@ -129,7 +107,7 @@ func TestSessionRepoCreateGet(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
r, _ := makeTestSessionRepo()
|
||||
r, _ := newSessionRepo(t)
|
||||
|
||||
r.Create(tt)
|
||||
|
||||
|
@ -166,7 +144,7 @@ func TestSessionRepoCreateUpdate(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
r, _ := makeTestSessionRepo()
|
||||
r, _ := newSessionRepo(t)
|
||||
r.Create(tt.initial)
|
||||
|
||||
ses, _ := r.Get(tt.initial.ID)
|
||||
|
@ -186,7 +164,7 @@ func TestSessionRepoCreateUpdate(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionRepoUpdateNoExist(t *testing.T) {
|
||||
r, _ := makeTestSessionRepo()
|
||||
r, _ := newSessionRepo(t)
|
||||
|
||||
err := r.Update(session.Session{ID: "123", ClientState: "boom"})
|
||||
if err == nil {
|
||||
|
|
|
@ -1,28 +1,38 @@
|
|||
package repo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/go-gorp/gorp"
|
||||
)
|
||||
|
||||
func initDB(dsn string) *gorp.DbMap {
|
||||
func connect(t *testing.T) *gorp.DbMap {
|
||||
dsn := os.Getenv("DEX_TEST_DSN")
|
||||
if dsn == "" {
|
||||
t.Fatal("DEX_TEST_DSN environment variable not set")
|
||||
}
|
||||
c, err := db.NewConnection(db.Config{DSN: dsn})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Unable to connect to database: %v", err))
|
||||
t.Fatalf("Unable to connect to database: %v", err)
|
||||
}
|
||||
|
||||
if err = c.DropTablesIfExists(); err != nil {
|
||||
panic(fmt.Sprintf("Unable to drop database tables: %v", err))
|
||||
t.Fatalf("Unable to drop database tables: %v", err)
|
||||
}
|
||||
|
||||
if err = db.DropMigrationsTable(c); err != nil {
|
||||
panic(fmt.Sprintf("Unable to drop migration table: %v", err))
|
||||
t.Fatalf("Unable to drop migration table: %v", err)
|
||||
}
|
||||
|
||||
if _, err = db.MigrateToLatest(c); err != nil {
|
||||
panic(fmt.Sprintf("Unable to migrate: %v", err))
|
||||
n, err := db.MigrateToLatest(c)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to migrate: %v", err)
|
||||
}
|
||||
if n == 0 {
|
||||
t.Fatalf("No migrations performed")
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
|
|
@ -7,14 +7,13 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/user"
|
||||
)
|
||||
|
||||
var makeTestUserRepoFromUsers func(users []user.UserWithRemoteIdentities) user.UserRepo
|
||||
|
||||
var (
|
||||
testUsers = []user.UserWithRemoteIdentities{
|
||||
{
|
||||
|
@ -47,34 +46,21 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
dsn := os.Getenv("DEX_TEST_DSN")
|
||||
if dsn == "" {
|
||||
makeTestUserRepoFromUsers = makeTestUserRepoMem
|
||||
} else {
|
||||
makeTestUserRepoFromUsers = makeTestUserRepoDB(dsn)
|
||||
func newUserRepo(t *testing.T, users []user.UserWithRemoteIdentities) user.UserRepo {
|
||||
if users == nil {
|
||||
users = []user.UserWithRemoteIdentities{}
|
||||
}
|
||||
}
|
||||
|
||||
func makeTestUserRepoMem(users []user.UserWithRemoteIdentities) user.UserRepo {
|
||||
return user.NewUserRepoFromUsers(users)
|
||||
}
|
||||
|
||||
func makeTestUserRepoDB(dsn string) func([]user.UserWithRemoteIdentities) user.UserRepo {
|
||||
return func(users []user.UserWithRemoteIdentities) user.UserRepo {
|
||||
c := initDB(dsn)
|
||||
|
||||
repo, err := db.NewUserRepoFromUsers(c, users)
|
||||
var dbMap *gorp.DbMap
|
||||
if os.Getenv("DEX_TEST_DSN") == "" {
|
||||
dbMap = db.NewMemDB()
|
||||
} else {
|
||||
dbMap = connect(t)
|
||||
}
|
||||
repo, err := db.NewUserRepoFromUsers(dbMap, users)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Unable to add users: %v", err))
|
||||
t.Fatalf("Unable to add users: %v", err)
|
||||
}
|
||||
return repo
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func makeTestUserRepo() user.UserRepo {
|
||||
return makeTestUserRepoFromUsers(testUsers)
|
||||
}
|
||||
|
||||
func TestNewUser(t *testing.T) {
|
||||
|
@ -137,7 +123,7 @@ func TestNewUser(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
repo := makeTestUserRepo()
|
||||
repo := newUserRepo(t, testUsers)
|
||||
err := repo.Create(nil, tt.user)
|
||||
if tt.err != nil {
|
||||
if err != tt.err {
|
||||
|
@ -209,7 +195,7 @@ func TestUpdateUser(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
repo := makeTestUserRepo()
|
||||
repo := newUserRepo(t, testUsers)
|
||||
err := repo.Update(nil, tt.user)
|
||||
if tt.err != nil {
|
||||
if err != tt.err {
|
||||
|
@ -269,7 +255,7 @@ func TestDisableUser(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
repo := makeTestUserRepo()
|
||||
repo := newUserRepo(t, testUsers)
|
||||
err := repo.Disable(nil, tt.id, tt.disable)
|
||||
switch {
|
||||
case err != tt.err:
|
||||
|
@ -320,7 +306,7 @@ func TestAttachRemoteIdentity(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
repo := makeTestUserRepo()
|
||||
repo := newUserRepo(t, testUsers)
|
||||
err := repo.AddRemoteIdentity(nil, tt.id, tt.rid)
|
||||
if tt.err != nil {
|
||||
if err != tt.err {
|
||||
|
@ -390,7 +376,7 @@ func TestRemoveRemoteIdentity(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
repo := makeTestUserRepo()
|
||||
repo := newUserRepo(t, testUsers)
|
||||
err := repo.RemoveRemoteIdentity(nil, tt.id, tt.rid)
|
||||
if tt.err != nil {
|
||||
if err != tt.err {
|
||||
|
@ -433,59 +419,6 @@ func findRemoteIdentity(rids []user.RemoteIdentity, rid user.RemoteIdentity) int
|
|||
return -1
|
||||
}
|
||||
|
||||
func TestNewUserRepoFromUsers(t *testing.T) {
|
||||
tests := []struct {
|
||||
users []user.UserWithRemoteIdentities
|
||||
}{
|
||||
{
|
||||
users: []user.UserWithRemoteIdentities{
|
||||
{
|
||||
User: user.User{
|
||||
ID: "123",
|
||||
Email: "email123@example.com",
|
||||
},
|
||||
RemoteIdentities: []user.RemoteIdentity{},
|
||||
},
|
||||
{
|
||||
User: user.User{
|
||||
ID: "456",
|
||||
Email: "email456@example.com",
|
||||
},
|
||||
RemoteIdentities: []user.RemoteIdentity{
|
||||
{
|
||||
ID: "remoteID",
|
||||
ConnectorID: "connID",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
repo := user.NewUserRepoFromUsers(tt.users)
|
||||
for _, want := range tt.users {
|
||||
gotUser, err := repo.Get(nil, want.User.ID)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: want nil err: %v", i, err)
|
||||
}
|
||||
|
||||
gotRIDs, err := repo.GetRemoteIdentities(nil, want.User.ID)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: want nil err: %v", i, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(want.User, gotUser) {
|
||||
t.Errorf("case %d: want=%#v got=%#v", i, want.User, gotUser)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(want.RemoteIdentities, gotRIDs) {
|
||||
t.Errorf("case %d: want=%#v got=%#v", i, want.RemoteIdentities, gotRIDs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetByEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
email string
|
||||
|
@ -502,7 +435,7 @@ func TestGetByEmail(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
repo := makeTestUserRepo()
|
||||
repo := newUserRepo(t, testUsers)
|
||||
gotUser, gotErr := repo.GetByEmail(nil, tt.email)
|
||||
if tt.wantErr != nil {
|
||||
if tt.wantErr != gotErr {
|
||||
|
@ -566,7 +499,7 @@ func TestGetAdminCount(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
repo := makeTestUserRepo()
|
||||
repo := newUserRepo(t, testUsers)
|
||||
for _, addUser := range tt.addUsers {
|
||||
err := repo.Create(nil, addUser)
|
||||
if err != nil {
|
||||
|
@ -621,7 +554,7 @@ func TestList(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
repo := makeTestUserRepoFromUsers(repoUsers)
|
||||
repo := newUserRepo(t, repoUsers)
|
||||
var tok string
|
||||
gotIDs := [][]string{}
|
||||
done := false
|
||||
|
@ -651,7 +584,7 @@ func TestList(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestListErrorNotFound(t *testing.T) {
|
||||
repo := makeTestUserRepoFromUsers(nil)
|
||||
repo := newUserRepo(t, nil)
|
||||
_, _, err := repo.List(nil, user.UserFilter{}, 10, "")
|
||||
if err != user.ErrorNotFound {
|
||||
t.Errorf("want=%q, got=%q", user.ErrorNotFound, err)
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
package integration
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
|
@ -13,7 +15,12 @@ func TestClientCreate(t *testing.T) {
|
|||
ci := oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "72de74a9",
|
||||
Secret: "XXX",
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https://", Host: "authn.example.com", Path: "/callback"},
|
||||
},
|
||||
},
|
||||
}
|
||||
cis := []oidc.ClientIdentity{ci}
|
||||
|
@ -54,7 +61,7 @@ func TestClientCreate(t *testing.T) {
|
|||
call := svc.Clients.Create(newClientInput)
|
||||
newClient, err := call.Do()
|
||||
if err != nil {
|
||||
t.Errorf("Call to create client API failed: %v", err)
|
||||
t.Fatalf("Call to create client API failed: %v", err)
|
||||
}
|
||||
|
||||
if newClient.Id == "" {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package integration
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
@ -11,7 +12,7 @@ import (
|
|||
"github.com/jonboulle/clockwork"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
@ -21,7 +22,7 @@ var (
|
|||
|
||||
testIssuerURL = url.URL{Scheme: "https", Host: "auth.example.com"}
|
||||
testClientID = "XXX"
|
||||
testClientSecret = "yyy"
|
||||
testClientSecret = base64.URLEncoding.EncodeToString([]byte("yyy"))
|
||||
testRedirectURL = url.URL{Scheme: "https", Host: "client.example.com", Path: "/redirect"}
|
||||
testResetPasswordURL = url.URL{Scheme: "https", Host: "auth.example.com", Path: "/resetPassword"}
|
||||
testPrivKey, _ = key.GeneratePrivateKey()
|
||||
|
@ -45,13 +46,32 @@ func (t *tokenHandlerTransport) RoundTrip(r *http.Request) (*http.Response, erro
|
|||
}
|
||||
|
||||
func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.PasswordInfo) (user.UserRepo, user.PasswordInfoRepo, *manager.UserManager) {
|
||||
ur := user.NewUserRepoFromUsers(users)
|
||||
pwr := user.NewPasswordInfoRepoFromPasswordInfos(passwords)
|
||||
dbMap := db.NewMemDB()
|
||||
ur := func() user.UserRepo {
|
||||
repo, err := db.NewUserRepoFromUsers(dbMap, users)
|
||||
if err != nil {
|
||||
panic("Failed to create user repo: " + err.Error())
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
pwr := func() user.PasswordInfoRepo {
|
||||
repo, err := db.NewPasswordInfoRepoFromPasswordInfos(dbMap, passwords)
|
||||
if err != nil {
|
||||
panic("Failed to create password info repo: " + err.Error())
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
ccr := connector.NewConnectorConfigRepoFromConfigs(
|
||||
[]connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}},
|
||||
)
|
||||
um := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||
ccr := func() connector.ConnectorConfigRepo {
|
||||
repo := db.NewConnectorConfigRepo(dbMap)
|
||||
c := []connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}}
|
||||
if err := repo.Set(c); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
um := manager.NewUserManager(ur, pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
|
||||
um.Clock = clock
|
||||
return ur, pwr, um
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package integration
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
|
@ -8,12 +9,12 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
phttp "github.com/coreos/dex/pkg/http"
|
||||
"github.com/coreos/dex/refresh/refreshtest"
|
||||
"github.com/coreos/dex/server"
|
||||
"github.com/coreos/dex/session"
|
||||
"github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
|
@ -22,6 +23,7 @@ import (
|
|||
)
|
||||
|
||||
func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) {
|
||||
dbMap := db.NewMemDB()
|
||||
k, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to generate private key: %v", err)
|
||||
|
@ -32,12 +34,16 @@ func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientIdentityRepo, err := db.NewClientIdentityRepoFromClients(dbMap, cis)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm := manager.NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap))
|
||||
srv := &server.Server{
|
||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
KeyManager: km,
|
||||
ClientIdentityRepo: client.NewClientIdentityRepo(cis),
|
||||
ClientIdentityRepo: clientIdentityRepo,
|
||||
SessionManager: sm,
|
||||
}
|
||||
|
||||
|
@ -113,14 +119,18 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
|
|||
ci := oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "72de74a9",
|
||||
Secret: "XXX",
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
|
||||
},
|
||||
}
|
||||
|
||||
cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci})
|
||||
dbMap := db.NewMemDB()
|
||||
cir, err := db.NewClientIdentityRepoFromClients(dbMap, []oidc.ClientIdentity{ci})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity repo: " + err.Error())
|
||||
}
|
||||
|
||||
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
|
||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm := manager.NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap))
|
||||
|
||||
k, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
|
@ -138,16 +148,13 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
|
|||
Email: "testemail@example.com",
|
||||
DisplayName: "displayname",
|
||||
}
|
||||
userRepo := user.NewUserRepo()
|
||||
userRepo := db.NewUserRepo(db.NewMemDB())
|
||||
if err := userRepo.Create(nil, usr); err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
passwordInfoRepo := user.NewPasswordInfoRepo()
|
||||
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
passwordInfoRepo := db.NewPasswordInfoRepo(db.NewMemDB())
|
||||
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
|
||||
|
||||
srv := &server.Server{
|
||||
IssuerURL: issuerURL,
|
||||
|
@ -255,7 +262,7 @@ func TestHTTPClientCredsToken(t *testing.T) {
|
|||
ci := oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "72de74a9",
|
||||
Secret: "XXX",
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
|
||||
},
|
||||
}
|
||||
cis := []oidc.ClientIdentity{ci}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package integration
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
@ -15,6 +16,7 @@ import (
|
|||
"google.golang.org/api/googleapi"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/db"
|
||||
schema "github.com/coreos/dex/schema/workerschema"
|
||||
"github.com/coreos/dex/server"
|
||||
"github.com/coreos/dex/user"
|
||||
|
@ -97,7 +99,8 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
|
|||
|
||||
_, _, um := makeUserObjects(userUsers, userPasswords)
|
||||
|
||||
cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{
|
||||
cir := func() client.ClientIdentityRepo {
|
||||
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
|
||||
oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: testClientID,
|
||||
|
@ -112,7 +115,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
|
|||
oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: userBadClientID,
|
||||
Secret: "secret",
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
|
@ -121,6 +124,11 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
|
|||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
panic("Failed to create client identity repo: " + err.Error())
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
cir.SetDexAdmin(testClientID, true)
|
||||
|
||||
|
|
|
@ -3,16 +3,17 @@ package refreshtest
|
|||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/refresh"
|
||||
)
|
||||
|
||||
// NewTestRefreshTokenRepo returns a test repo whose tokens monotonically increase.
|
||||
// The tokens are in the form { refresh-1, refresh-2 ... refresh-n}.
|
||||
func NewTestRefreshTokenRepo() (refresh.RefreshTokenRepo, error) {
|
||||
func NewTestRefreshTokenRepo() refresh.RefreshTokenRepo {
|
||||
var tokenIdx int
|
||||
tokenGenerator := func() ([]byte, error) {
|
||||
tokenIdx++
|
||||
return []byte(fmt.Sprintf("refresh-%d", tokenIdx)), nil
|
||||
}
|
||||
return refresh.NewRefreshTokenRepoWithTokenGenerator(tokenGenerator), nil
|
||||
return db.NewRefreshTokenRepoWithGenerator(db.NewMemDB(), tokenGenerator)
|
||||
}
|
||||
|
|
123
refresh/repo.go
123
refresh/repo.go
|
@ -1,13 +1,8 @@
|
|||
package refresh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -53,121 +48,3 @@ type RefreshTokenRepo interface {
|
|||
// Revoke deletes the refresh token if the token belongs to the given userID.
|
||||
Revoke(userID, token string) error
|
||||
}
|
||||
|
||||
type refreshToken struct {
|
||||
payload []byte
|
||||
userID string
|
||||
clientID string
|
||||
}
|
||||
|
||||
type memRefreshTokenRepo struct {
|
||||
store map[int]refreshToken
|
||||
tokenGenerator RefreshTokenGenerator
|
||||
}
|
||||
|
||||
// buildToken combines the token ID and token payload to create a new token.
|
||||
func buildToken(tokenID int, tokenPayload []byte) string {
|
||||
return fmt.Sprintf("%d%s%s", tokenID, TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
|
||||
}
|
||||
|
||||
// parseToken parses a token and returns the token ID and token payload.
|
||||
func parseToken(token string) (int, []byte, error) {
|
||||
parts := strings.SplitN(token, TokenDelimer, 2)
|
||||
if len(parts) != 2 {
|
||||
return -1, nil, ErrorInvalidToken
|
||||
}
|
||||
id, err := strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
return -1, nil, ErrorInvalidToken
|
||||
}
|
||||
tokenPayload, err := base64.URLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return -1, nil, ErrorInvalidToken
|
||||
}
|
||||
return id, tokenPayload, nil
|
||||
}
|
||||
|
||||
// NewRefreshTokenRepo returns an in-memory RefreshTokenRepo useful for development.
|
||||
func NewRefreshTokenRepo() RefreshTokenRepo {
|
||||
return NewRefreshTokenRepoWithTokenGenerator(DefaultRefreshTokenGenerator)
|
||||
}
|
||||
|
||||
func NewRefreshTokenRepoWithTokenGenerator(tokenGenerator RefreshTokenGenerator) RefreshTokenRepo {
|
||||
repo := &memRefreshTokenRepo{}
|
||||
repo.store = make(map[int]refreshToken)
|
||||
repo.tokenGenerator = tokenGenerator
|
||||
return repo
|
||||
}
|
||||
|
||||
func (r *memRefreshTokenRepo) Create(userID, clientID string) (string, error) {
|
||||
// Validate userID.
|
||||
if userID == "" {
|
||||
return "", ErrorInvalidUserID
|
||||
}
|
||||
|
||||
// Validate clientID.
|
||||
if clientID == "" {
|
||||
return "", ErrorInvalidClientID
|
||||
}
|
||||
|
||||
// Generate and store token.
|
||||
tokenPayload, err := r.tokenGenerator.Generate()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
tokenID := len(r.store) // Should only be used in single threaded tests.
|
||||
|
||||
// No limits on the number of tokens per user/client for this in-memory repo.
|
||||
r.store[tokenID] = refreshToken{
|
||||
payload: tokenPayload,
|
||||
userID: userID,
|
||||
clientID: clientID,
|
||||
}
|
||||
return buildToken(tokenID, tokenPayload), nil
|
||||
}
|
||||
|
||||
func (r *memRefreshTokenRepo) Verify(clientID, token string) (string, error) {
|
||||
tokenID, tokenPayload, err := parseToken(token)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
record, ok := r.store[tokenID]
|
||||
if !ok {
|
||||
return "", ErrorInvalidToken
|
||||
}
|
||||
|
||||
if !bytes.Equal(record.payload, tokenPayload) {
|
||||
return "", ErrorInvalidToken
|
||||
}
|
||||
|
||||
if record.clientID != clientID {
|
||||
return "", ErrorInvalidClientID
|
||||
}
|
||||
|
||||
return record.userID, nil
|
||||
}
|
||||
|
||||
func (r *memRefreshTokenRepo) Revoke(userID, token string) error {
|
||||
tokenID, tokenPayload, err := parseToken(token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
record, ok := r.store[tokenID]
|
||||
if !ok {
|
||||
return ErrorInvalidToken
|
||||
}
|
||||
|
||||
if !bytes.Equal(record.payload, tokenPayload) {
|
||||
return ErrorInvalidToken
|
||||
}
|
||||
|
||||
if record.userID != userID {
|
||||
return ErrorInvalidUserID
|
||||
}
|
||||
|
||||
delete(r.store, tokenID)
|
||||
return nil
|
||||
}
|
||||
|
|
28
repo/repo.go
28
repo/repo.go
|
@ -1,7 +1,5 @@
|
|||
package repo
|
||||
|
||||
import "errors"
|
||||
|
||||
// Transaction is an abstraction of transactions typically found in database systems.
|
||||
// One of Commit() or Rollback() must be called on each transaction.
|
||||
type Transaction interface {
|
||||
|
@ -13,29 +11,3 @@ type Transaction interface {
|
|||
}
|
||||
|
||||
type TransactionFactory func() (Transaction, error)
|
||||
|
||||
// InMemTransaction satisifies the Transaction interface for in-memory systems.
|
||||
// However, the only thing it really does is ensure that the same transaction is
|
||||
// can't be committed/rolled back more than once. As such, this can lead to data
|
||||
// corruption and should not be used in production systems.
|
||||
type InMemTransaction bool
|
||||
|
||||
func InMemTransactionFactory() (Transaction, error) {
|
||||
return new(InMemTransaction), nil
|
||||
}
|
||||
|
||||
func (i *InMemTransaction) Commit() error {
|
||||
return i.commitOrRollback()
|
||||
}
|
||||
|
||||
func (i *InMemTransaction) Rollback() error {
|
||||
return i.commitOrRollback()
|
||||
}
|
||||
|
||||
func (i *InMemTransaction) commitOrRollback() error {
|
||||
if *i {
|
||||
return errors.New("Already committed/rolled-back.")
|
||||
}
|
||||
*i = true
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,13 +1,16 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
|
@ -26,9 +29,18 @@ func TestClientToken(t *testing.T) {
|
|||
ci := oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: validClientID,
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "authn.example.com", Path: "/callback"},
|
||||
},
|
||||
},
|
||||
}
|
||||
repo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci})
|
||||
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
||||
}
|
||||
|
||||
privKey, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
|
@ -102,7 +114,7 @@ func TestClientToken(t *testing.T) {
|
|||
// empty repo
|
||||
{
|
||||
keys: []key.PublicKey{pubKey},
|
||||
repo: client.NewClientIdentityRepo(nil),
|
||||
repo: db.NewClientIdentityRepo(db.NewMemDB()),
|
||||
header: fmt.Sprintf("BEARER %s", validJWT),
|
||||
wantCode: http.StatusUnauthorized,
|
||||
},
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -9,12 +10,14 @@ import (
|
|||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/db"
|
||||
schema "github.com/coreos/dex/schema/workerschema"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
)
|
||||
|
||||
func makeBody(s string) io.ReadCloser {
|
||||
|
@ -24,7 +27,7 @@ func makeBody(s string) io.ReadCloser {
|
|||
func TestCreateInvalidRequest(t *testing.T) {
|
||||
u := &url.URL{Scheme: "http", Host: "example.com", Path: "clients"}
|
||||
h := http.Header{"Content-Type": []string{"application/json"}}
|
||||
repo := client.NewClientIdentityRepo(nil)
|
||||
repo := db.NewClientIdentityRepo(db.NewMemDB())
|
||||
res := &clientResource{repo: repo}
|
||||
tests := []struct {
|
||||
req *http.Request
|
||||
|
@ -115,7 +118,7 @@ func TestCreateInvalidRequest(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
repo := client.NewClientIdentityRepo(nil)
|
||||
repo := db.NewClientIdentityRepo(db.NewMemDB())
|
||||
res := &clientResource{repo: repo}
|
||||
tests := [][]string{
|
||||
[]string{"http://example.com"},
|
||||
|
@ -168,6 +171,11 @@ func TestCreate(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestList(t *testing.T) {
|
||||
|
||||
b64Encode := func(s string) string {
|
||||
return base64.URLEncoding.EncodeToString([]byte(s))
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
cs []oidc.ClientIdentity
|
||||
want []*schema.Client
|
||||
|
@ -181,7 +189,7 @@ func TestList(t *testing.T) {
|
|||
{
|
||||
cs: []oidc.ClientIdentity{
|
||||
oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{ID: "foo", Secret: "bar"},
|
||||
Credentials: oidc.ClientCredentials{ID: "foo", Secret: b64Encode("bar")},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{Scheme: "http", Host: "example.com"},
|
||||
|
@ -200,7 +208,7 @@ func TestList(t *testing.T) {
|
|||
{
|
||||
cs: []oidc.ClientIdentity{
|
||||
oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{ID: "foo", Secret: "bar"},
|
||||
Credentials: oidc.ClientCredentials{ID: "foo", Secret: b64Encode("bar")},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{Scheme: "http", Host: "example.com"},
|
||||
|
@ -208,7 +216,7 @@ func TestList(t *testing.T) {
|
|||
},
|
||||
},
|
||||
oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{ID: "biz", Secret: "bang"},
|
||||
Credentials: oidc.ClientCredentials{ID: "biz", Secret: b64Encode("bang")},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{Scheme: "https", Host: "example.com", Path: "one/two/three"},
|
||||
|
@ -230,7 +238,11 @@ func TestList(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
repo := client.NewClientIdentityRepo(tt.cs)
|
||||
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), tt.cs)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: failed to create client identity repo: %v", i, err)
|
||||
continue
|
||||
}
|
||||
res := &clientResource{repo: repo}
|
||||
|
||||
r, err := http.NewRequest("GET", "http://example.com/clients", nil)
|
||||
|
@ -248,9 +260,17 @@ func TestList(t *testing.T) {
|
|||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Errorf("case %d: unexpected error=%v", i, err)
|
||||
}
|
||||
sort.Sort(byClientId(tt.want))
|
||||
sort.Sort(byClientId(resp.Clients))
|
||||
|
||||
if !reflect.DeepEqual(tt.want, resp.Clients) {
|
||||
t.Errorf("case %d: invalid response body, want=%#v, got=%#v", i, tt.want, resp.Clients)
|
||||
if diff := pretty.Compare(tt.want, resp.Clients); diff != "" {
|
||||
t.Errorf("case %d: invalid response body: %s", i, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type byClientId []*schema.Client
|
||||
|
||||
func (b byClientId) Len() int { return len(b) }
|
||||
func (b byClientId) Less(i, j int) bool { return b[i].Id < b[j].Id }
|
||||
func (b byClientId) Swap(i, j int) { b[i], b[j] = b[j], b[i] }
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
|
@ -11,18 +12,17 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/key"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
"github.com/coreos/pkg/health"
|
||||
"github.com/go-gorp/gorp"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/email"
|
||||
"github.com/coreos/dex/refresh"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/session"
|
||||
sessionmanager "github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/dex/user"
|
||||
useremail "github.com/coreos/dex/user/email"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
usermanager "github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
type ServerConfig struct {
|
||||
|
@ -101,20 +101,21 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
|||
return err
|
||||
}
|
||||
|
||||
dbMap := db.NewMemDB()
|
||||
|
||||
ks := key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(24*time.Hour))
|
||||
kRepo := key.NewPrivateKeySetRepo()
|
||||
if err = kRepo.Set(ks); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cf, err := os.Open(cfg.ClientsFile)
|
||||
clients, err := loadClients(cfg.ClientsFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read clients from file %s: %v", cfg.ClientsFile, err)
|
||||
}
|
||||
defer cf.Close()
|
||||
ciRepo, err := client.NewClientIdentityRepoFromReader(cf)
|
||||
ciRepo, err := db.NewClientIdentityRepoFromClients(dbMap, clients)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read client identities from file %s: %v", cfg.ClientsFile, err)
|
||||
return fmt.Errorf("failed to create client identity repo: %v", err)
|
||||
}
|
||||
|
||||
f, err := os.Open(cfg.ConnectorsFile)
|
||||
|
@ -126,23 +127,30 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("decoding connector configs: %v", err)
|
||||
}
|
||||
cfgRepo := connector.NewConnectorConfigRepoFromConfigs(cfgs)
|
||||
cfgRepo := db.NewConnectorConfigRepo(dbMap)
|
||||
if err := cfgRepo.Set(cfgs); err != nil {
|
||||
return fmt.Errorf("failed to set connectors: %v", err)
|
||||
}
|
||||
|
||||
sRepo := session.NewSessionRepo()
|
||||
skRepo := session.NewSessionKeyRepo()
|
||||
sm := session.NewSessionManager(sRepo, skRepo)
|
||||
sRepo := db.NewSessionRepo(dbMap)
|
||||
skRepo := db.NewSessionKeyRepo(dbMap)
|
||||
sm := sessionmanager.NewSessionManager(sRepo, skRepo)
|
||||
|
||||
userRepo, err := user.NewUserRepoFromFile(cfg.UsersFile)
|
||||
users, err := loadUsers(cfg.UsersFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read users from file: %v", err)
|
||||
}
|
||||
userRepo, err := db.NewUserRepoFromUsers(dbMap, users)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pwiRepo := user.NewPasswordInfoRepo()
|
||||
pwiRepo := db.NewPasswordInfoRepo(dbMap)
|
||||
|
||||
refTokRepo := refresh.NewRefreshTokenRepo()
|
||||
refTokRepo := db.NewRefreshTokenRepo(dbMap)
|
||||
|
||||
txnFactory := repo.InMemTransactionFactory
|
||||
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, manager.ManagerOptions{})
|
||||
txnFactory := db.TransactionFactory(dbMap)
|
||||
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
|
||||
srv.ClientIdentityRepo = ciRepo
|
||||
srv.KeySetRepo = kRepo
|
||||
srv.ConnectorConfigRepo = cfgRepo
|
||||
|
@ -152,7 +160,54 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
|||
srv.SessionManager = sm
|
||||
srv.RefreshTokenRepo = refTokRepo
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadUsers(filepath string) (users []user.UserWithRemoteIdentities, err error) {
|
||||
f, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
err = json.NewDecoder(f).Decode(&users)
|
||||
return
|
||||
}
|
||||
|
||||
func loadClients(filepath string) ([]oidc.ClientIdentity, error) {
|
||||
f, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
var c []struct {
|
||||
ID string `json:"id"`
|
||||
Secret string `json:"secret"`
|
||||
RedirectURLs []string `json:"redirectURLs"`
|
||||
}
|
||||
if err := json.NewDecoder(f).Decode(&c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clients := make([]oidc.ClientIdentity, len(c))
|
||||
for i, client := range c {
|
||||
redirectURIs := make([]url.URL, len(client.RedirectURLs))
|
||||
for j, u := range client.RedirectURLs {
|
||||
uri, err := url.Parse(u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
redirectURIs[j] = *uri
|
||||
}
|
||||
|
||||
clients[i] = oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: client.ID,
|
||||
Secret: client.Secret,
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: redirectURIs,
|
||||
},
|
||||
}
|
||||
}
|
||||
return clients, nil
|
||||
}
|
||||
|
||||
func (cfg *MultiServerConfig) Configure(srv *Server) error {
|
||||
|
@ -168,6 +223,9 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("unable to initialize database connection: %v", err)
|
||||
}
|
||||
if _, ok := dbc.Dialect.(gorp.PostgresDialect); !ok {
|
||||
return errors.New("only postgres backend supported for multi server configurations")
|
||||
}
|
||||
|
||||
kRepo, err := db.NewPrivateKeySetRepo(dbc, cfg.UseOldFormat, cfg.KeySecrets...)
|
||||
if err != nil {
|
||||
|
@ -180,10 +238,10 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
|
|||
cfgRepo := db.NewConnectorConfigRepo(dbc)
|
||||
userRepo := db.NewUserRepo(dbc)
|
||||
pwiRepo := db.NewPasswordInfoRepo(dbc)
|
||||
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
|
||||
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), usermanager.ManagerOptions{})
|
||||
refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
|
||||
|
||||
sm := session.NewSessionManager(sRepo, skRepo)
|
||||
sm := sessionmanager.NewSessionManager(sRepo, skRepo)
|
||||
|
||||
srv.ClientIdentityRepo = ciRepo
|
||||
srv.KeySetRepo = kRepo
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -17,7 +18,8 @@ import (
|
|||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/session"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/oauth2"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
|
@ -75,12 +77,13 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
|||
}
|
||||
srv := &Server{
|
||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()),
|
||||
ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{
|
||||
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
|
||||
ClientIdentityRepo: func() client.ClientIdentityRepo {
|
||||
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
|
||||
oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
Secret: "secrete",
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
|
@ -88,7 +91,12 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
||||
}
|
||||
return repo
|
||||
}(),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
|
@ -198,12 +206,13 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
|||
}
|
||||
srv := &Server{
|
||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()),
|
||||
ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{
|
||||
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
|
||||
ClientIdentityRepo: func() client.ClientIdentityRepo {
|
||||
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
|
||||
oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
Secret: "secrete",
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
|
@ -212,7 +221,12 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
||||
}
|
||||
return repo
|
||||
}(),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
|
|
|
@ -9,10 +9,10 @@ import (
|
|||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/session"
|
||||
sessionmanager "github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/dex/user"
|
||||
useremail "github.com/coreos/dex/user/email"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
usermanager "github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
type sendResetPasswordEmailData struct {
|
||||
|
@ -28,7 +28,7 @@ type sendResetPasswordEmailData struct {
|
|||
type SendResetPasswordEmailHandler struct {
|
||||
tpl *template.Template
|
||||
emailer *useremail.UserEmailer
|
||||
sm *session.SessionManager
|
||||
sm *sessionmanager.SessionManager
|
||||
cr client.ClientIdentityRepo
|
||||
}
|
||||
|
||||
|
@ -182,7 +182,7 @@ type resetPasswordTemplateData struct {
|
|||
type ResetPasswordHandler struct {
|
||||
tpl *template.Template
|
||||
issuerURL url.URL
|
||||
um *manager.UserManager
|
||||
um *usermanager.UserManager
|
||||
keysFunc func() ([]key.PublicKey, error)
|
||||
}
|
||||
|
||||
|
@ -238,7 +238,7 @@ func (r *resetPasswordRequest) handlePOST() {
|
|||
cbURL, err := r.h.um.ChangePassword(r.pwReset, plaintext)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case manager.ErrorPasswordAlreadyChanged:
|
||||
case usermanager.ErrorPasswordAlreadyChanged:
|
||||
r.data.Error = "Link Expired"
|
||||
r.data.Message = "The link in your email is no longer valid. If you need to change your password, generate a new email."
|
||||
r.data.DontShowForm = true
|
||||
|
|
|
@ -10,8 +10,9 @@ import (
|
|||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/session"
|
||||
sessionmanager "github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
usermanager "github.com/coreos/dex/user/manager"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
)
|
||||
|
||||
|
@ -274,7 +275,7 @@ func makeClientRedirectURL(baseRedirURL url.URL, code, clientState string) *url.
|
|||
return &ru
|
||||
}
|
||||
|
||||
func registerFromLocalConnector(userManager *manager.UserManager, sessionManager *session.SessionManager, ses *session.Session, email, password string) (string, error) {
|
||||
func registerFromLocalConnector(userManager *usermanager.UserManager, sessionManager *sessionmanager.SessionManager, ses *session.Session, email, password string) (string, error) {
|
||||
userID, err := userManager.RegisterWithPassword(email, password, ses.ConnectorID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -289,7 +290,7 @@ func registerFromLocalConnector(userManager *manager.UserManager, sessionManager
|
|||
return userID, nil
|
||||
}
|
||||
|
||||
func registerFromRemoteConnector(userManager *manager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) {
|
||||
func registerFromRemoteConnector(userManager *usermanager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) {
|
||||
if ses.Identity.ID == "" {
|
||||
return "", errors.New("No Identity found in session.")
|
||||
}
|
||||
|
|
|
@ -22,10 +22,11 @@ import (
|
|||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/refresh"
|
||||
"github.com/coreos/dex/session"
|
||||
sessionmanager "github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/dex/user"
|
||||
usersapi "github.com/coreos/dex/user/api"
|
||||
useremail "github.com/coreos/dex/user/email"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
usermanager "github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -57,7 +58,7 @@ type Server struct {
|
|||
IssuerURL url.URL
|
||||
KeyManager key.PrivateKeyManager
|
||||
KeySetRepo key.PrivateKeySetRepo
|
||||
SessionManager *session.SessionManager
|
||||
SessionManager *sessionmanager.SessionManager
|
||||
ClientIdentityRepo client.ClientIdentityRepo
|
||||
ConnectorConfigRepo connector.ConnectorConfigRepo
|
||||
Templates *template.Template
|
||||
|
@ -69,7 +70,7 @@ type Server struct {
|
|||
HealthChecks []health.Checkable
|
||||
Connectors []connector.Connector
|
||||
UserRepo user.UserRepo
|
||||
UserManager *manager.UserManager
|
||||
UserManager *usermanager.UserManager
|
||||
PasswordInfoRepo user.PasswordInfoRepo
|
||||
RefreshTokenRepo refresh.RefreshTokenRepo
|
||||
UserEmailer *useremail.UserEmailer
|
||||
|
|
|
@ -10,8 +10,9 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/refresh/refreshtest"
|
||||
"github.com/coreos/dex/session"
|
||||
"github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
|
@ -20,6 +21,8 @@ import (
|
|||
"github.com/kylelemons/godebug/pretty"
|
||||
)
|
||||
|
||||
var clientTestSecret = base64.URLEncoding.EncodeToString([]byte("secrete"))
|
||||
|
||||
type StaticKeyManager struct {
|
||||
key.PrivateKeyManager
|
||||
expiresAt time.Time
|
||||
|
@ -68,14 +71,14 @@ func (ss *StaticSigner) JWK() jose.JWK {
|
|||
return jose.JWK{}
|
||||
}
|
||||
|
||||
func staticGenerateCodeFunc(code string) session.GenerateCodeFunc {
|
||||
func staticGenerateCodeFunc(code string) manager.GenerateCodeFunc {
|
||||
return func() (string, error) {
|
||||
return code, nil
|
||||
}
|
||||
}
|
||||
|
||||
func makeNewUserRepo() (user.UserRepo, error) {
|
||||
userRepo := user.NewUserRepo()
|
||||
userRepo := db.NewUserRepo(db.NewMemDB())
|
||||
|
||||
id := "testid-1"
|
||||
err := userRepo.Create(nil, user.User{
|
||||
|
@ -120,7 +123,7 @@ func TestServerProviderConfig(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestServerNewSession(t *testing.T) {
|
||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
srv := &Server{
|
||||
SessionManager: sm,
|
||||
}
|
||||
|
@ -179,7 +182,7 @@ func TestServerLogin(t *testing.T) {
|
|||
ci := oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
Secret: "secrete",
|
||||
Secret: clientTestSecret,
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
|
@ -191,13 +194,19 @@ func TestServerLogin(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci})
|
||||
ciRepo := func() client.ClientIdentityRepo {
|
||||
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
km := &StaticKeyManager{
|
||||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||
}
|
||||
|
||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
sm.GenerateCode = staticGenerateCodeFunc("fakecode")
|
||||
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
|
@ -235,17 +244,24 @@ func TestServerLogin(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
|
||||
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{
|
||||
ciRepo := func() client.ClientIdentityRepo {
|
||||
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
|
||||
oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "XXX", Secret: "secrete",
|
||||
ID: "XXX", Secret: clientTestSecret,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
km := &StaticKeyManager{
|
||||
signer: &StaticSigner{sig: nil, err: errors.New("fail")},
|
||||
}
|
||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
srv := &Server{
|
||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
KeyManager: km,
|
||||
|
@ -268,7 +284,7 @@ func TestServerLoginDisabledUser(t *testing.T) {
|
|||
ci := oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
Secret: "secrete",
|
||||
Secret: clientTestSecret,
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
|
@ -280,13 +296,19 @@ func TestServerLoginDisabledUser(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci})
|
||||
ciRepo := func() client.ClientIdentityRepo {
|
||||
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
km := &StaticKeyManager{
|
||||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||
}
|
||||
|
||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
sm.GenerateCode = staticGenerateCodeFunc("fakecode")
|
||||
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
|
@ -336,24 +358,27 @@ func TestServerCodeToken(t *testing.T) {
|
|||
ci := oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
Secret: "secrete",
|
||||
Secret: clientTestSecret,
|
||||
},
|
||||
}
|
||||
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci})
|
||||
ciRepo := func() client.ClientIdentityRepo {
|
||||
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
km := &StaticKeyManager{
|
||||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||
}
|
||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
|
||||
userRepo, err := makeNewUserRepo()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
|
||||
|
||||
srv := &Server{
|
||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
|
@ -375,8 +400,10 @@ func TestServerCodeToken(t *testing.T) {
|
|||
},
|
||||
// Have 'offline_access' in scope, should get non-empty refresh token.
|
||||
{
|
||||
// NOTE(ericchiang): This test assumes that the database ID of the first
|
||||
// refresh token will be "1".
|
||||
scope: []string{"openid", "offline_access"},
|
||||
refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -417,14 +444,20 @@ func TestServerTokenUnrecognizedKey(t *testing.T) {
|
|||
ci := oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
Secret: "secrete",
|
||||
Secret: clientTestSecret,
|
||||
},
|
||||
}
|
||||
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci})
|
||||
ciRepo := func() client.ClientIdentityRepo {
|
||||
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
km := &StaticKeyManager{
|
||||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||
}
|
||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
|
||||
srv := &Server{
|
||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
|
@ -460,7 +493,7 @@ func TestServerTokenFail(t *testing.T) {
|
|||
keyFixture := "goodkey"
|
||||
ccFixture := oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
Secret: "secrete",
|
||||
Secret: clientTestSecret,
|
||||
}
|
||||
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
|
||||
|
||||
|
@ -474,11 +507,13 @@ func TestServerTokenFail(t *testing.T) {
|
|||
}{
|
||||
// control test case to make sure fixtures check out
|
||||
{
|
||||
// NOTE(ericchiang): This test assumes that the database ID of the first
|
||||
// refresh token will be "1".
|
||||
signer: signerFixture,
|
||||
argCC: ccFixture,
|
||||
argKey: keyFixture,
|
||||
scope: []string{"openid", "offline_access"},
|
||||
refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
},
|
||||
|
||||
// no 'offline_access' in 'scope', should get empty refresh token
|
||||
|
@ -518,7 +553,7 @@ func TestServerTokenFail(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
sm.GenerateCode = func() (string, error) { return keyFixture, nil }
|
||||
|
||||
sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, tt.scope)
|
||||
|
@ -534,9 +569,13 @@ func TestServerTokenFail(t *testing.T) {
|
|||
km := &StaticKeyManager{
|
||||
signer: tt.signer,
|
||||
}
|
||||
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{
|
||||
ciRepo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
|
||||
oidc.ClientIdentity{Credentials: ccFixture},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("case %d: failed to create client identity repo: %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = sm.AttachUser(sessionID, "testid-1")
|
||||
if err != nil {
|
||||
|
@ -548,10 +587,7 @@ func TestServerTokenFail(t *testing.T) {
|
|||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
|
||||
|
||||
srv := &Server{
|
||||
IssuerURL: issuerURL,
|
||||
|
@ -590,15 +626,17 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
|
||||
credXXX := oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
Secret: "secret",
|
||||
Secret: clientTestSecret,
|
||||
}
|
||||
credYYY := oidc.ClientCredentials{
|
||||
ID: "YYY",
|
||||
Secret: "secret",
|
||||
Secret: clientTestSecret,
|
||||
}
|
||||
|
||||
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
|
||||
|
||||
// NOTE(ericchiang): These tests assume that the database ID of the first
|
||||
// refresh token will be "1".
|
||||
tests := []struct {
|
||||
token string
|
||||
clientID string // The client that associates with the token.
|
||||
|
@ -608,7 +646,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
}{
|
||||
// Everything is good.
|
||||
{
|
||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
credXXX,
|
||||
signerFixture,
|
||||
|
@ -624,7 +662,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
},
|
||||
// Invalid refresh token(invalid payload content).
|
||||
{
|
||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
|
||||
"XXX",
|
||||
credXXX,
|
||||
signerFixture,
|
||||
|
@ -632,7 +670,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
},
|
||||
// Invalid refresh token(invalid ID content).
|
||||
{
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
credXXX,
|
||||
signerFixture,
|
||||
|
@ -640,7 +678,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
},
|
||||
// Invalid client(client is not associated with the token).
|
||||
{
|
||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
credYYY,
|
||||
signerFixture,
|
||||
|
@ -648,7 +686,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
},
|
||||
// Invalid client(no client ID).
|
||||
{
|
||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
oidc.ClientCredentials{ID: "", Secret: "aaa"},
|
||||
signerFixture,
|
||||
|
@ -656,7 +694,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
},
|
||||
// Invalid client(no such client).
|
||||
{
|
||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
|
||||
signerFixture,
|
||||
|
@ -664,7 +702,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
},
|
||||
// Invalid client(no secrets).
|
||||
{
|
||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
oidc.ClientCredentials{ID: "XXX"},
|
||||
signerFixture,
|
||||
|
@ -672,7 +710,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
},
|
||||
// Invalid client(invalid secret).
|
||||
{
|
||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"},
|
||||
signerFixture,
|
||||
|
@ -680,7 +718,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
},
|
||||
// Signing operation fails.
|
||||
{
|
||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
credXXX,
|
||||
&StaticSigner{sig: nil, err: errors.New("fail")},
|
||||
|
@ -693,20 +731,21 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
signer: tt.signer,
|
||||
}
|
||||
|
||||
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{
|
||||
ciRepo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
|
||||
oidc.ClientIdentity{Credentials: credXXX},
|
||||
oidc.ClientIdentity{Credentials: credYYY},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("case %d: failed to create client identity repo: %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
userRepo, err := makeNewUserRepo()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
|
||||
|
||||
srv := &Server{
|
||||
IssuerURL: issuerURL,
|
||||
|
@ -745,10 +784,13 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
signer: signerFixture,
|
||||
}
|
||||
|
||||
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{
|
||||
ciRepo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
|
||||
oidc.ClientIdentity{Credentials: credXXX},
|
||||
oidc.ClientIdentity{Credentials: credYYY},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create client identity repo: %v", err)
|
||||
}
|
||||
|
||||
userRepo, err := makeNewUserRepo()
|
||||
if err != nil {
|
||||
|
@ -763,10 +805,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
|
||||
|
||||
srv := &Server{
|
||||
IssuerURL: issuerURL,
|
||||
|
@ -787,7 +826,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
}
|
||||
srv.UserRepo = userRepo
|
||||
|
||||
_, err = srv.RefreshToken(credXXX, fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))))
|
||||
_, err = srv.RefreshToken(credXXX, fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))))
|
||||
if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) {
|
||||
t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
@ -10,12 +11,12 @@ import (
|
|||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/email"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/session"
|
||||
sessionmanager "github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/dex/user"
|
||||
useremail "github.com/coreos/dex/user/email"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
usermanager "github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -26,7 +27,6 @@ const (
|
|||
var (
|
||||
testIssuerURL = url.URL{Scheme: "http", Host: "server.example.com"}
|
||||
testClientID = "XXX"
|
||||
testClientSecret = "secrete"
|
||||
|
||||
testRedirectURL = url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"}
|
||||
|
||||
|
@ -75,13 +75,13 @@ var (
|
|||
type testFixtures struct {
|
||||
srv *Server
|
||||
userRepo user.UserRepo
|
||||
sessionManager *session.SessionManager
|
||||
sessionManager *sessionmanager.SessionManager
|
||||
emailer *email.TemplatizedEmailer
|
||||
redirectURL url.URL
|
||||
clientIdentityRepo client.ClientIdentityRepo
|
||||
}
|
||||
|
||||
func sequentialGenerateCodeFunc() session.GenerateCodeFunc {
|
||||
func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc {
|
||||
x := 0
|
||||
return func() (string, error) {
|
||||
x += 1
|
||||
|
@ -90,8 +90,15 @@ func sequentialGenerateCodeFunc() session.GenerateCodeFunc {
|
|||
}
|
||||
|
||||
func makeTestFixtures() (*testFixtures, error) {
|
||||
userRepo := user.NewUserRepoFromUsers(testUsers)
|
||||
pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos)
|
||||
dbMap := db.NewMemDB()
|
||||
userRepo, err := db.NewUserRepoFromUsers(dbMap, testUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pwRepo, err := db.NewPasswordInfoRepoFromPasswordInfos(dbMap, testPasswordInfos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
connConfigs := []connector.ConnectorConfig{
|
||||
&connector.OIDCConnectorConfig{
|
||||
|
@ -111,11 +118,14 @@ func makeTestFixtures() (*testFixtures, error) {
|
|||
ID: "local",
|
||||
},
|
||||
}
|
||||
connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs)
|
||||
connCfgRepo := db.NewConnectorConfigRepo(dbMap)
|
||||
if err := connCfgRepo.Set(connConfigs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
manager := manager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||
manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, db.TransactionFactory(dbMap), usermanager.ManagerOptions{})
|
||||
|
||||
sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
sessionManager.GenerateCode = sequentialGenerateCodeFunc()
|
||||
|
||||
emailer, err := email.NewTemplatizedEmailerFromGlobs(
|
||||
|
@ -126,11 +136,11 @@ func makeTestFixtures() (*testFixtures, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
clientIdentityRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{
|
||||
clientIdentityRepo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
|
||||
oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
Secret: testClientSecret,
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
|
@ -139,6 +149,9 @@ func makeTestFixtures() (*testFixtures, error) {
|
|||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
km := key.NewPrivateKeyManager()
|
||||
err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey}, time.Now().Add(time.Minute)))
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package session
|
||||
package manager
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/jonboulle/clockwork"
|
||||
|
||||
"github.com/coreos/dex/session"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
)
|
||||
|
||||
|
@ -27,11 +28,11 @@ func DefaultGenerateCode() (string, error) {
|
|||
return base64.URLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func NewSessionManager(sRepo SessionRepo, skRepo SessionKeyRepo) *SessionManager {
|
||||
func NewSessionManager(sRepo session.SessionRepo, skRepo session.SessionKeyRepo) *SessionManager {
|
||||
return &SessionManager{
|
||||
GenerateCode: DefaultGenerateCode,
|
||||
Clock: clockwork.NewRealClock(),
|
||||
ValidityWindow: DefaultSessionValidityWindow,
|
||||
ValidityWindow: session.DefaultSessionValidityWindow,
|
||||
sessions: sRepo,
|
||||
keys: skRepo,
|
||||
}
|
||||
|
@ -41,8 +42,8 @@ type SessionManager struct {
|
|||
GenerateCode GenerateCodeFunc
|
||||
Clock clockwork.Clock
|
||||
ValidityWindow time.Duration
|
||||
sessions SessionRepo
|
||||
keys SessionKeyRepo
|
||||
sessions session.SessionRepo
|
||||
keys session.SessionKeyRepo
|
||||
}
|
||||
|
||||
func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
|
||||
|
@ -52,10 +53,10 @@ func (m *SessionManager) NewSession(connectorID, clientID, clientState string, r
|
|||
}
|
||||
|
||||
now := m.Clock.Now()
|
||||
s := Session{
|
||||
s := session.Session{
|
||||
ConnectorID: connectorID,
|
||||
ID: sID,
|
||||
State: SessionStateNew,
|
||||
State: session.SessionStateNew,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(m.ValidityWindow),
|
||||
ClientID: clientID,
|
||||
|
@ -80,11 +81,12 @@ func (m *SessionManager) NewSessionKey(sessionID string) (string, error) {
|
|||
return "", err
|
||||
}
|
||||
|
||||
k := SessionKey{
|
||||
k := session.SessionKey{
|
||||
Key: key,
|
||||
SessionID: sessionID,
|
||||
}
|
||||
|
||||
sessionKeyValidityWindow := 10 * time.Minute //RFC6749
|
||||
err = m.keys.Push(k, sessionKeyValidityWindow)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -97,7 +99,7 @@ func (m *SessionManager) ExchangeKey(key string) (string, error) {
|
|||
return m.keys.Pop(key)
|
||||
}
|
||||
|
||||
func (m *SessionManager) getSessionInState(sessionID string, state SessionState) (*Session, error) {
|
||||
func (m *SessionManager) getSessionInState(sessionID string, state session.SessionState) (*session.Session, error) {
|
||||
s, err := m.sessions.Get(sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -110,14 +112,14 @@ func (m *SessionManager) getSessionInState(sessionID string, state SessionState)
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*Session, error) {
|
||||
s, err := m.getSessionInState(sessionID, SessionStateNew)
|
||||
func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*session.Session, error) {
|
||||
s, err := m.getSessionInState(sessionID, session.SessionStateNew)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.Identity = ident
|
||||
s.State = SessionStateRemoteAttached
|
||||
s.State = session.SessionStateRemoteAttached
|
||||
|
||||
if err = m.sessions.Update(*s); err != nil {
|
||||
return nil, err
|
||||
|
@ -126,14 +128,14 @@ func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Ident
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session, error) {
|
||||
s, err := m.getSessionInState(sessionID, SessionStateRemoteAttached)
|
||||
func (m *SessionManager) AttachUser(sessionID string, userID string) (*session.Session, error) {
|
||||
s, err := m.getSessionInState(sessionID, session.SessionStateRemoteAttached)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.UserID = userID
|
||||
s.State = SessionStateIdentified
|
||||
s.State = session.SessionStateIdentified
|
||||
|
||||
if err = m.sessions.Update(*s); err != nil {
|
||||
return nil, err
|
||||
|
@ -142,13 +144,13 @@ func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session,
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) Kill(sessionID string) (*Session, error) {
|
||||
func (m *SessionManager) Kill(sessionID string) (*session.Session, error) {
|
||||
s, err := m.sessions.Get(sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.State = SessionStateDead
|
||||
s.State = session.SessionStateDead
|
||||
|
||||
if err = m.sessions.Update(*s); err != nil {
|
||||
return nil, err
|
||||
|
@ -157,6 +159,6 @@ func (m *SessionManager) Kill(sessionID string) (*Session, error) {
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) Get(sessionID string) (*Session, error) {
|
||||
func (m *SessionManager) Get(sessionID string) (*session.Session, error) {
|
||||
return m.sessions.Get(sessionID)
|
||||
}
|
|
@ -1,9 +1,11 @@
|
|||
package session
|
||||
package manager
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/session"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
)
|
||||
|
||||
|
@ -13,8 +15,13 @@ func staticGenerateCodeFunc(code string) GenerateCodeFunc {
|
|||
}
|
||||
}
|
||||
|
||||
func newManager() *SessionManager {
|
||||
dbMap := db.NewMemDB()
|
||||
return NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap))
|
||||
}
|
||||
|
||||
func TestSessionManagerNewSession(t *testing.T) {
|
||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
||||
sm := newManager()
|
||||
sm.GenerateCode = staticGenerateCodeFunc("boo")
|
||||
got, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
|
@ -26,7 +33,7 @@ func TestSessionManagerNewSession(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionAttachRemoteIdentityTwice(t *testing.T) {
|
||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
||||
sm := newManager()
|
||||
sessionID, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
|
@ -43,7 +50,7 @@ func TestSessionAttachRemoteIdentityTwice(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionManagerExchangeKey(t *testing.T) {
|
||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
||||
sm := newManager()
|
||||
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
|
@ -68,8 +75,8 @@ func TestSessionManagerExchangeKey(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionManagerGetSessionInStateNoExist(t *testing.T) {
|
||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
||||
ses, err := sm.getSessionInState("123", SessionStateNew)
|
||||
sm := newManager()
|
||||
ses, err := sm.getSessionInState("123", session.SessionStateNew)
|
||||
if err == nil {
|
||||
t.Errorf("Expected non-nil error")
|
||||
}
|
||||
|
@ -79,12 +86,12 @@ func TestSessionManagerGetSessionInStateNoExist(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionManagerGetSessionInStateWrongState(t *testing.T) {
|
||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
||||
sm := newManager()
|
||||
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
ses, err := sm.getSessionInState(sessionID, SessionStateDead)
|
||||
ses, err := sm.getSessionInState(sessionID, session.SessionStateDead)
|
||||
if err == nil {
|
||||
t.Errorf("Expected non-nil error")
|
||||
}
|
||||
|
@ -94,7 +101,7 @@ func TestSessionManagerGetSessionInStateWrongState(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionManagerKill(t *testing.T) {
|
||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
||||
sm := newManager()
|
||||
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
|
@ -1,11 +1,6 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
)
|
||||
import "time"
|
||||
|
||||
type SessionRepo interface {
|
||||
Get(string) (*Session, error)
|
||||
|
@ -17,87 +12,3 @@ type SessionKeyRepo interface {
|
|||
Push(SessionKey, time.Duration) error
|
||||
Pop(string) (string, error)
|
||||
}
|
||||
|
||||
func NewSessionRepo() SessionRepo {
|
||||
return NewSessionRepoWithClock(clockwork.NewRealClock())
|
||||
}
|
||||
|
||||
func NewSessionRepoWithClock(clock clockwork.Clock) SessionRepo {
|
||||
return &memSessionRepo{
|
||||
store: make(map[string]Session),
|
||||
clock: clock,
|
||||
}
|
||||
}
|
||||
|
||||
type memSessionRepo struct {
|
||||
store map[string]Session
|
||||
clock clockwork.Clock
|
||||
}
|
||||
|
||||
func (m *memSessionRepo) Get(sessionID string) (*Session, error) {
|
||||
s, ok := m.store[sessionID]
|
||||
if !ok || s.ExpiresAt.Before(m.clock.Now()) {
|
||||
return nil, errors.New("unrecognized ID")
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
func (m *memSessionRepo) Create(s Session) error {
|
||||
if _, ok := m.store[s.ID]; ok {
|
||||
return errors.New("ID exists")
|
||||
}
|
||||
|
||||
m.store[s.ID] = s
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *memSessionRepo) Update(s Session) error {
|
||||
if _, ok := m.store[s.ID]; !ok {
|
||||
return errors.New("unrecognized ID")
|
||||
}
|
||||
m.store[s.ID] = s
|
||||
return nil
|
||||
}
|
||||
|
||||
type expiringSessionKey struct {
|
||||
SessionKey
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func NewSessionKeyRepo() SessionKeyRepo {
|
||||
return NewSessionKeyRepoWithClock(clockwork.NewRealClock())
|
||||
}
|
||||
|
||||
func NewSessionKeyRepoWithClock(clock clockwork.Clock) SessionKeyRepo {
|
||||
return &memSessionKeyRepo{
|
||||
store: make(map[string]expiringSessionKey),
|
||||
clock: clock,
|
||||
}
|
||||
}
|
||||
|
||||
type memSessionKeyRepo struct {
|
||||
store map[string]expiringSessionKey
|
||||
clock clockwork.Clock
|
||||
}
|
||||
|
||||
func (m *memSessionKeyRepo) Pop(key string) (string, error) {
|
||||
esk, ok := m.store[key]
|
||||
if !ok {
|
||||
return "", errors.New("unrecognized key")
|
||||
}
|
||||
defer delete(m.store, key)
|
||||
|
||||
if esk.expiresAt.Before(m.clock.Now()) {
|
||||
return "", errors.New("expired key")
|
||||
}
|
||||
|
||||
return esk.SessionKey.SessionID, nil
|
||||
}
|
||||
|
||||
func (m *memSessionKeyRepo) Push(sk SessionKey, ttl time.Duration) error {
|
||||
m.store[sk.Key] = expiringSessionKey{
|
||||
SessionKey: sk,
|
||||
expiresAt: m.clock.Now().Add(ttl),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,22 +1,22 @@
|
|||
[
|
||||
{
|
||||
"id": "XXX",
|
||||
"secret": "secrete",
|
||||
"secret": "c2VjcmV0ZQ==",
|
||||
"redirectURLs": ["http://127.0.0.1:5555/callback"]
|
||||
},
|
||||
{
|
||||
"id": "example-app",
|
||||
"secret": "example-app-secret",
|
||||
"secret": "ZXhhbXBsZS1hcHAtc2VjcmV0",
|
||||
"redirectURLs": ["http://127.0.0.1:5555/callback"]
|
||||
},
|
||||
{
|
||||
"id": "example-cli",
|
||||
"secret": "example-cli-secret",
|
||||
"secret": "ZXhhbXBsZS1jbGktc2VjcmV0",
|
||||
"redirectURLs": ["http://127.0.0.1:8000/admin/v1/oauth/login"]
|
||||
},
|
||||
{
|
||||
"id": "oauth2_proxy",
|
||||
"secret": "proxy",
|
||||
"secret": "cHJveHk=",
|
||||
"redirectURLs": ["http://127.0.0.1:4180/oauth2/callback"]
|
||||
}
|
||||
]
|
||||
|
|
8
test
8
test
|
@ -12,9 +12,13 @@
|
|||
# Invoke ./cover for HTML output
|
||||
COVER=${COVER:-"-cover"}
|
||||
|
||||
source ./build
|
||||
source ./env
|
||||
|
||||
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api user/manager email admin"
|
||||
if [ ! -d $GOPATH/pkg ]; then
|
||||
echo "WARNING: No cached builds detected. Please run the ./build script to speed up future tests."
|
||||
fi
|
||||
|
||||
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session session/manager user user/api user/manager user/email email admin"
|
||||
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
|
||||
|
||||
# user has not provided PKG override
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
#!/bin/bash -e
|
||||
source ./build
|
||||
|
||||
source ./env
|
||||
|
||||
go test $@ github.com/coreos/dex/functional
|
||||
go test $@ github.com/coreos/dex/functional/repo
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -11,7 +12,7 @@ import (
|
|||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/db"
|
||||
schema "github.com/coreos/dex/schema/workerschema"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
|
@ -86,7 +87,9 @@ var (
|
|||
)
|
||||
|
||||
func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
||||
ur := user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{
|
||||
dbMap := db.NewMemDB()
|
||||
ur := func() user.UserRepo {
|
||||
repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{
|
||||
{
|
||||
User: user.User{
|
||||
ID: "ID-1",
|
||||
|
@ -115,7 +118,14 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
|||
},
|
||||
},
|
||||
})
|
||||
pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
|
||||
if err != nil {
|
||||
panic("Failed to create user repo: " + err.Error())
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
pwr := func() user.PasswordInfoRepo {
|
||||
repo, err := db.NewPasswordInfoRepoFromPasswordInfos(dbMap, []user.PasswordInfo{
|
||||
{
|
||||
UserID: "ID-1",
|
||||
Password: []byte("password-1"),
|
||||
|
@ -125,15 +135,29 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
|||
Password: []byte("password-2"),
|
||||
},
|
||||
})
|
||||
ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
|
||||
if err != nil {
|
||||
panic("Failed to create user repo: " + err.Error())
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
ccr := func() connector.ConnectorConfigRepo {
|
||||
repo := db.NewConnectorConfigRepo(dbMap)
|
||||
c := []connector.ConnectorConfig{
|
||||
&connector.LocalConnectorConfig{ID: "local"},
|
||||
})
|
||||
mgr := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||
}
|
||||
if err := repo.Set(c); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
mgr := manager.NewUserManager(ur, pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
|
||||
mgr.Clock = clock
|
||||
ci := oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
Secret: "secrete",
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
|
@ -141,7 +165,13 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
|||
},
|
||||
},
|
||||
}
|
||||
cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci})
|
||||
cir := func() client.ClientIdentityRepo {
|
||||
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
|
||||
if err != nil {
|
||||
panic("Failed to create client identity repo: " + err.Error())
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
emailer := &testEmailer{}
|
||||
api := NewUsersAPI(mgr, cir, emailer, "local")
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/coreos/go-oidc/key"
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/email"
|
||||
"github.com/coreos/dex/user"
|
||||
)
|
||||
|
@ -45,7 +46,9 @@ func (t *testEmailer) SendMail(from, subject, text, html string, to ...string) e
|
|||
}
|
||||
|
||||
func makeTestFixtures() (*UserEmailer, *testEmailer, *key.PublicKey) {
|
||||
ur := user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{
|
||||
dbMap := db.NewMemDB()
|
||||
ur := func() user.UserRepo {
|
||||
repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{
|
||||
{
|
||||
User: user.User{
|
||||
ID: "ID-1",
|
||||
|
@ -64,7 +67,14 @@ func makeTestFixtures() (*UserEmailer, *testEmailer, *key.PublicKey) {
|
|||
},
|
||||
},
|
||||
})
|
||||
pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
|
||||
if err != nil {
|
||||
panic("Failed to create user repo: " + err.Error())
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
pwr := func() user.PasswordInfoRepo {
|
||||
repo, err := db.NewPasswordInfoRepoFromPasswordInfos(dbMap, []user.PasswordInfo{
|
||||
{
|
||||
UserID: "ID-1",
|
||||
Password: []byte("password-1"),
|
||||
|
@ -74,6 +84,11 @@ func makeTestFixtures() (*UserEmailer, *testEmailer, *key.PublicKey) {
|
|||
Password: []byte("password-2"),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
panic("Failed to create user repo: " + err.Error())
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
privKey, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
"github.com/kylelemons/godebug/pretty"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/user"
|
||||
)
|
||||
|
||||
|
@ -26,7 +26,9 @@ func makeTestFixtures() *testFixtures {
|
|||
f := &testFixtures{}
|
||||
f.clock = clockwork.NewFakeClock()
|
||||
|
||||
f.ur = user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{
|
||||
dbMap := db.NewMemDB()
|
||||
f.ur = func() user.UserRepo {
|
||||
repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{
|
||||
{
|
||||
User: user.User{
|
||||
ID: "ID-1",
|
||||
|
@ -52,7 +54,14 @@ func makeTestFixtures() *testFixtures {
|
|||
},
|
||||
},
|
||||
})
|
||||
f.pwr = user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
|
||||
if err != nil {
|
||||
panic("Failed to create user repo: " + err.Error())
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
f.pwr = func() user.PasswordInfoRepo {
|
||||
repo, err := db.NewPasswordInfoRepoFromPasswordInfos(dbMap, []user.PasswordInfo{
|
||||
{
|
||||
UserID: "ID-1",
|
||||
Password: []byte("password-1"),
|
||||
|
@ -62,10 +71,24 @@ func makeTestFixtures() *testFixtures {
|
|||
Password: []byte("password-2"),
|
||||
},
|
||||
})
|
||||
f.ccr = connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
|
||||
if err != nil {
|
||||
panic("Failed to create user repo: " + err.Error())
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
f.ccr = func() connector.ConnectorConfigRepo {
|
||||
repo := db.NewConnectorConfigRepo(dbMap)
|
||||
c := []connector.ConnectorConfig{
|
||||
&connector.LocalConnectorConfig{ID: "local"},
|
||||
})
|
||||
f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, repo.InMemTransactionFactory, ManagerOptions{})
|
||||
}
|
||||
if err := repo.Set(c); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, db.TransactionFactory(dbMap), ManagerOptions{})
|
||||
f.mgr.Clock = f.clock
|
||||
return f
|
||||
}
|
||||
|
@ -207,18 +230,22 @@ func TestRegisterWithPassword(t *testing.T) {
|
|||
}
|
||||
if diff := pretty.Compare(usr, ridUSR); diff != "" {
|
||||
t.Errorf("case %d: Compare(want, got) = %v", i, diff)
|
||||
continue
|
||||
}
|
||||
|
||||
pwi, err := f.pwr.Get(nil, userID)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: err != nil: %q", i, err)
|
||||
continue
|
||||
}
|
||||
ident, err := pwi.Authenticate(tt.plaintext)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: err != nil: %q", i, err)
|
||||
continue
|
||||
}
|
||||
if ident.ID != userID {
|
||||
t.Errorf("case %d: ident.ID: want=%q, got=%q", i, userID, ident.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = pwi.Authenticate(tt.plaintext + "WRONG")
|
||||
|
@ -274,7 +301,7 @@ func TestVerifyEmail(t *testing.T) {
|
|||
|
||||
for i, tt := range tests {
|
||||
f := makeTestFixtures()
|
||||
cb, err := f.mgr.VerifyEmail(user.EmailVerification{tt.evClaims})
|
||||
cb, err := f.mgr.VerifyEmail(user.EmailVerification{Claims: tt.evClaims})
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want non-nil err", i)
|
||||
|
@ -344,7 +371,7 @@ func TestChangePassword(t *testing.T) {
|
|||
|
||||
for i, tt := range tests {
|
||||
f := makeTestFixtures()
|
||||
cb, err := f.mgr.ChangePassword(user.PasswordReset{tt.pwrClaims}, tt.newPassword)
|
||||
cb, err := f.mgr.ChangePassword(user.PasswordReset{Claims: tt.pwrClaims}, tt.newPassword)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want non-nil err", i)
|
||||
|
|
|
@ -4,9 +4,7 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
@ -85,60 +83,6 @@ type PasswordInfoRepo interface {
|
|||
Create(repo.Transaction, PasswordInfo) error
|
||||
}
|
||||
|
||||
func NewPasswordInfoRepo() PasswordInfoRepo {
|
||||
return &memPasswordInfoRepo{
|
||||
pws: make(map[string]PasswordInfo),
|
||||
}
|
||||
}
|
||||
|
||||
type memPasswordInfoRepo struct {
|
||||
pws map[string]PasswordInfo
|
||||
}
|
||||
|
||||
func (m *memPasswordInfoRepo) Get(_ repo.Transaction, id string) (PasswordInfo, error) {
|
||||
pw, ok := m.pws[id]
|
||||
if !ok {
|
||||
return PasswordInfo{}, ErrorNotFound
|
||||
}
|
||||
return pw, nil
|
||||
}
|
||||
|
||||
func (m *memPasswordInfoRepo) Create(_ repo.Transaction, pw PasswordInfo) error {
|
||||
_, ok := m.pws[pw.UserID]
|
||||
if ok {
|
||||
return ErrorDuplicateID
|
||||
}
|
||||
|
||||
if pw.UserID == "" {
|
||||
return ErrorInvalidID
|
||||
}
|
||||
|
||||
if len(pw.Password) == 0 {
|
||||
return ErrorInvalidPassword
|
||||
}
|
||||
|
||||
m.pws[pw.UserID] = pw
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *memPasswordInfoRepo) Update(_ repo.Transaction, pw PasswordInfo) error {
|
||||
if pw.UserID == "" {
|
||||
return ErrorInvalidID
|
||||
}
|
||||
|
||||
_, ok := m.pws[pw.UserID]
|
||||
if !ok {
|
||||
return ErrorNotFound
|
||||
}
|
||||
|
||||
if len(pw.Password) == 0 {
|
||||
return ErrorInvalidPassword
|
||||
}
|
||||
|
||||
m.pws[pw.UserID] = pw
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *PasswordInfo) UnmarshalJSON(data []byte) error {
|
||||
var dec struct {
|
||||
UserID string `json:"userId"`
|
||||
|
@ -172,21 +116,6 @@ func (u *PasswordInfo) UnmarshalJSON(data []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func newPasswordInfosFromReader(r io.Reader) ([]PasswordInfo, error) {
|
||||
var pws []PasswordInfo
|
||||
err := json.NewDecoder(r).Decode(&pws)
|
||||
return pws, err
|
||||
}
|
||||
|
||||
func readPasswordInfosFromFile(loc string) ([]PasswordInfo, error) {
|
||||
pwf, err := os.Open(loc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read password info from file %q: %v", loc, err)
|
||||
}
|
||||
|
||||
return newPasswordInfosFromReader(pwf)
|
||||
}
|
||||
|
||||
func LoadPasswordInfos(repo PasswordInfoRepo, pws []PasswordInfo) error {
|
||||
for i, pw := range pws {
|
||||
err := repo.Create(nil, pw)
|
||||
|
@ -197,23 +126,6 @@ func LoadPasswordInfos(repo PasswordInfoRepo, pws []PasswordInfo) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func NewPasswordInfoRepoFromPasswordInfos(pws []PasswordInfo) PasswordInfoRepo {
|
||||
memRepo := NewPasswordInfoRepo().(*memPasswordInfoRepo)
|
||||
for _, pw := range pws {
|
||||
memRepo.pws[pw.UserID] = pw
|
||||
}
|
||||
return memRepo
|
||||
}
|
||||
|
||||
func NewPasswordInfoRepoFromFile(loc string) (PasswordInfoRepo, error) {
|
||||
pws, err := readPasswordInfosFromFile(loc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewPasswordInfoRepoFromPasswordInfos(pws), nil
|
||||
}
|
||||
|
||||
func NewPasswordReset(userID string, password Password, issuer url.URL, clientID string, callback url.URL, expires time.Duration) PasswordReset {
|
||||
claims := oidc.NewClaims(issuer.String(), userID, clientID, clock.Now(), clock.Now().Add(expires))
|
||||
claims.Add(ClaimPasswordResetPassword, string(password))
|
||||
|
|
|
@ -2,7 +2,6 @@ package user
|
|||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -14,48 +13,6 @@ import (
|
|||
"github.com/coreos/go-oidc/key"
|
||||
)
|
||||
|
||||
func TestNewPasswordInfosFromReader(t *testing.T) {
|
||||
PasswordHasher = func(plaintext string) ([]byte, error) {
|
||||
return []byte(strings.ToUpper(plaintext)), nil
|
||||
}
|
||||
defer func() {
|
||||
PasswordHasher = DefaultPasswordHasher
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
json string
|
||||
want []PasswordInfo
|
||||
}{
|
||||
{
|
||||
json: `[{"userId":"12345","passwordPlaintext":"password"},{"userId":"78901","passwordHash":"WORDPASS", "passwordExpires":"2006-01-01T15:04:05Z"}]`,
|
||||
want: []PasswordInfo{
|
||||
{
|
||||
UserID: "12345",
|
||||
Password: []byte("PASSWORD"),
|
||||
},
|
||||
{
|
||||
UserID: "78901",
|
||||
Password: []byte("WORDPASS"),
|
||||
PasswordExpires: time.Date(2006,
|
||||
1, 1, 15, 4, 5, 0, time.UTC),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
r := strings.NewReader(tt.json)
|
||||
us, err := newPasswordInfosFromReader(r)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: want nil err: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if diff := pretty.Compare(tt.want, us); diff != "" {
|
||||
t.Errorf("case %d: Compare(want, got): %v", i, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPasswordFromHash(t *testing.T) {
|
||||
tests := []string{
|
||||
"test",
|
||||
|
|
254
user/user.go
254
user/user.go
|
@ -4,13 +4,10 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"net/mail"
|
||||
"net/url"
|
||||
"os"
|
||||
"sort"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/pborman/uuid"
|
||||
|
@ -172,262 +169,11 @@ func ValidPassword(plaintext string) bool {
|
|||
return len(plaintext) > 5
|
||||
}
|
||||
|
||||
// NewUserRepo returns an in-memory UserRepo useful for development.
|
||||
func NewUserRepo() UserRepo {
|
||||
return &memUserRepo{
|
||||
usersByID: make(map[string]User),
|
||||
userIDsByEmail: make(map[string]string),
|
||||
userIDsByRemoteID: make(map[RemoteIdentity]string),
|
||||
remoteIDsByUserID: make(map[string]map[RemoteIdentity]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
type memUserRepo struct {
|
||||
usersByID map[string]User
|
||||
userIDsByEmail map[string]string
|
||||
userIDsByRemoteID map[RemoteIdentity]string
|
||||
remoteIDsByUserID map[string]map[RemoteIdentity]struct{}
|
||||
}
|
||||
|
||||
func (r *memUserRepo) Get(_ repo.Transaction, id string) (User, error) {
|
||||
user, ok := r.usersByID[id]
|
||||
if !ok {
|
||||
return User{}, ErrorNotFound
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
type usersByEmail []User
|
||||
|
||||
func (s usersByEmail) Len() int { return len(s) }
|
||||
func (s usersByEmail) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
func (s usersByEmail) Less(i, j int) bool { return s[i].Email < s[j].Email }
|
||||
|
||||
func (r *memUserRepo) List(tx repo.Transaction, filter UserFilter, maxResults int, nextPageToken string) ([]User, string, error) {
|
||||
var offset int
|
||||
var err error
|
||||
if nextPageToken != "" {
|
||||
filter, maxResults, offset, err = DecodeNextPageToken(nextPageToken)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
users := []User{}
|
||||
for _, usr := range r.usersByID {
|
||||
users = append(users, usr)
|
||||
}
|
||||
|
||||
sort.Sort(usersByEmail(users))
|
||||
|
||||
high := offset + maxResults
|
||||
|
||||
var tok string
|
||||
if high >= len(users) {
|
||||
high = len(users)
|
||||
} else {
|
||||
tok, err = EncodeNextPageToken(filter, maxResults, high)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
if len(users[offset:high]) == 0 {
|
||||
return nil, "", ErrorNotFound
|
||||
}
|
||||
return users[offset:high], tok, nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) GetByEmail(tx repo.Transaction, email string) (User, error) {
|
||||
userID, ok := r.userIDsByEmail[email]
|
||||
if !ok {
|
||||
return User{}, ErrorNotFound
|
||||
}
|
||||
return r.Get(tx, userID)
|
||||
}
|
||||
|
||||
func (r *memUserRepo) Create(_ repo.Transaction, user User) error {
|
||||
if user.ID == "" {
|
||||
return ErrorInvalidID
|
||||
}
|
||||
|
||||
if !ValidEmail(user.Email) {
|
||||
return ErrorInvalidEmail
|
||||
}
|
||||
|
||||
// make sure no one has the same ID; if using UUID the chances of this
|
||||
// happening are astronomically small.
|
||||
_, ok := r.usersByID[user.ID]
|
||||
if ok {
|
||||
return ErrorDuplicateID
|
||||
}
|
||||
|
||||
// make sure there's no other user with the same Email
|
||||
_, ok = r.userIDsByEmail[user.Email]
|
||||
if ok {
|
||||
return ErrorDuplicateEmail
|
||||
}
|
||||
|
||||
r.set(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) Update(_ repo.Transaction, user User) error {
|
||||
if user.ID == "" {
|
||||
return ErrorInvalidID
|
||||
}
|
||||
|
||||
if !ValidEmail(user.Email) {
|
||||
return ErrorInvalidEmail
|
||||
}
|
||||
|
||||
// make sure this user exists already
|
||||
_, ok := r.usersByID[user.ID]
|
||||
if !ok {
|
||||
return ErrorNotFound
|
||||
}
|
||||
|
||||
// make sure there's no other user with the same Email
|
||||
otherID, ok := r.userIDsByEmail[user.Email]
|
||||
if ok && otherID != user.ID {
|
||||
return ErrorDuplicateEmail
|
||||
}
|
||||
|
||||
r.set(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) Disable(_ repo.Transaction, id string, disable bool) error {
|
||||
if id == "" {
|
||||
return ErrorInvalidID
|
||||
}
|
||||
user, ok := r.usersByID[id]
|
||||
if !ok {
|
||||
return ErrorNotFound
|
||||
}
|
||||
user.Disabled = disable
|
||||
r.set(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) AddRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error {
|
||||
_, ok := r.usersByID[userID]
|
||||
if !ok {
|
||||
return ErrorNotFound
|
||||
}
|
||||
_, ok = r.userIDsByRemoteID[ri]
|
||||
if ok {
|
||||
return ErrorDuplicateRemoteIdentity
|
||||
}
|
||||
|
||||
r.userIDsByRemoteID[ri] = userID
|
||||
rIDs, ok := r.remoteIDsByUserID[userID]
|
||||
if !ok {
|
||||
rIDs = make(map[RemoteIdentity]struct{})
|
||||
r.remoteIDsByUserID[userID] = rIDs
|
||||
}
|
||||
|
||||
rIDs[ri] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) RemoveRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error {
|
||||
otherID, ok := r.userIDsByRemoteID[ri]
|
||||
if !ok {
|
||||
return ErrorNotFound
|
||||
}
|
||||
if otherID != userID {
|
||||
return ErrorNotFound
|
||||
}
|
||||
delete(r.userIDsByRemoteID, ri)
|
||||
delete(r.remoteIDsByUserID[userID], ri)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) GetByRemoteIdentity(_ repo.Transaction, ri RemoteIdentity) (User, error) {
|
||||
userID, ok := r.userIDsByRemoteID[ri]
|
||||
if !ok {
|
||||
return User{}, ErrorNotFound
|
||||
}
|
||||
|
||||
user, ok := r.usersByID[userID]
|
||||
if !ok {
|
||||
return User{}, ErrorNotFound
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) GetRemoteIdentities(_ repo.Transaction, userID string) ([]RemoteIdentity, error) {
|
||||
ids := []RemoteIdentity{}
|
||||
for id := range r.remoteIDsByUserID[userID] {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) GetAdminCount(_ repo.Transaction) (int, error) {
|
||||
var i int
|
||||
for _, usr := range r.usersByID {
|
||||
if usr.Admin {
|
||||
i++
|
||||
}
|
||||
}
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) set(user User) error {
|
||||
r.usersByID[user.ID] = user
|
||||
r.userIDsByEmail[user.Email] = user.ID
|
||||
return nil
|
||||
}
|
||||
|
||||
type UserWithRemoteIdentities struct {
|
||||
User User `json:"user"`
|
||||
RemoteIdentities []RemoteIdentity `json:"remoteIdentities"`
|
||||
}
|
||||
|
||||
// NewUserRepoFromFile returns an in-memory UserRepo useful for development given a JSON serialized file of Users.
|
||||
func NewUserRepoFromFile(loc string) (UserRepo, error) {
|
||||
us, err := readUsersFromFile(loc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewUserRepoFromUsers(us), nil
|
||||
}
|
||||
|
||||
func NewUserRepoFromUsers(us []UserWithRemoteIdentities) UserRepo {
|
||||
memUserRepo := NewUserRepo().(*memUserRepo)
|
||||
for _, u := range us {
|
||||
memUserRepo.set(u.User)
|
||||
for _, ri := range u.RemoteIdentities {
|
||||
memUserRepo.AddRemoteIdentity(nil, u.User.ID, ri)
|
||||
}
|
||||
}
|
||||
return memUserRepo
|
||||
}
|
||||
|
||||
func newUsersFromReader(r io.Reader) ([]UserWithRemoteIdentities, error) {
|
||||
var us []UserWithRemoteIdentities
|
||||
err := json.NewDecoder(r).Decode(&us)
|
||||
return us, err
|
||||
}
|
||||
|
||||
func readUsersFromFile(loc string) ([]UserWithRemoteIdentities, error) {
|
||||
uf, err := os.Open(loc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read users from file %q: %v", loc, err)
|
||||
}
|
||||
defer uf.Close()
|
||||
|
||||
us, err := newUsersFromReader(uf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return us, err
|
||||
}
|
||||
|
||||
func (u *User) UnmarshalJSON(data []byte) error {
|
||||
var dec struct {
|
||||
ID string `json:"id"`
|
||||
|
|
|
@ -2,7 +2,6 @@ package user
|
|||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
|
@ -10,44 +9,6 @@ import (
|
|||
"github.com/coreos/go-oidc/jose"
|
||||
)
|
||||
|
||||
func TestNewUsersFromReader(t *testing.T) {
|
||||
tests := []struct {
|
||||
json string
|
||||
want []UserWithRemoteIdentities
|
||||
}{
|
||||
{
|
||||
json: `[{"user":{"id":"12345", "displayName": "Elroy Canis", "email":"elroy23@example.com"}, "remoteIdentities":[{"connectorID":"google", "id":"elroy@example.com"}] }]`,
|
||||
want: []UserWithRemoteIdentities{
|
||||
{
|
||||
User: User{
|
||||
ID: "12345",
|
||||
DisplayName: "Elroy Canis",
|
||||
Email: "elroy23@example.com",
|
||||
},
|
||||
RemoteIdentities: []RemoteIdentity{
|
||||
{
|
||||
ConnectorID: "google",
|
||||
ID: "elroy@example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
r := strings.NewReader(tt.json)
|
||||
us, err := newUsersFromReader(r)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: want nil err: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if diff := pretty.Compare(tt.want, us); diff != "" {
|
||||
t.Errorf("case %d: Compare(want, got): %v", i, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddToClaims(t *testing.T) {
|
||||
tests := []struct {
|
||||
user User
|
||||
|
|
Loading…
Reference in a new issue