Merge pull request #304 from ericchiang/sqlite3

move to sqlite3 for --no-db mode and tests
This commit is contained in:
bobbyrullo 2016-02-19 16:03:15 -08:00
commit f51125f555
83 changed files with 159246 additions and 1796 deletions

View file

@ -22,6 +22,7 @@ install:
script: script:
- docker run -d -p 127.0.0.1:15432:5432 quay.io/coreos/postgres - docker run -d -p 127.0.0.1:15432:5432 quay.io/coreos/postgres
- LDAPCONTAINER=`docker run -e LDAP_TLS_PROTOCOL_MIN=3.0 -e LDAP_TLS_CIPHER_SUITE=NORMAL -d -p 127.0.0.1:1389:389 -p 127.0.0.1:1636:636 -h tlstest.local osixia/openldap` - LDAPCONTAINER=`docker run -e LDAP_TLS_PROTOCOL_MIN=3.0 -e LDAP_TLS_CIPHER_SUITE=NORMAL -d -p 127.0.0.1:1389:389 -p 127.0.0.1:1636:636 -h tlstest.local osixia/openldap`
- ./build
- ./test - ./test
- docker cp ${LDAPCONTAINER}:container/service/:cfssl/assets/default-ca/default-ca.pem /tmp/openldap-ca.pem - docker cp ${LDAPCONTAINER}:container/service/:cfssl/assets/default-ca/default-ca.pem /tmp/openldap-ca.pem
- docker cp ${LDAPCONTAINER}:container/service/slapd/assets/certs/ldap.key /tmp/ldap.key - docker cp ${LDAPCONTAINER}:container/service/slapd/assets/certs/ldap.key /tmp/ldap.key
@ -29,6 +30,7 @@ script:
- docker cp ${LDAPCONTAINER}:container/service/slapd/assets/certs/ldap.crt /tmp/ldap.crt - docker cp ${LDAPCONTAINER}:container/service/slapd/assets/certs/ldap.crt /tmp/ldap.crt
- sudo sh -c 'echo "127.0.0.1 tlstest.local" >> /etc/hosts' - sudo sh -c 'echo "127.0.0.1 tlstest.local" >> /etc/hosts'
- ./test-functional - ./test-functional
- DEX_TEST_DSN="sqlite3://:memory:" ./test-functional
deploy: deploy:
provider: script provider: script

7
Godeps/Godeps.json generated
View file

@ -1,6 +1,6 @@
{ {
"ImportPath": "github.com/coreos/dex", "ImportPath": "github.com/coreos/dex",
"GoVersion": "go1.4.2", "GoVersion": "go1.5",
"Packages": [ "Packages": [
"./..." "./..."
], ],
@ -91,6 +91,11 @@
"ImportPath": "github.com/mailgun/mailgun-go", "ImportPath": "github.com/mailgun/mailgun-go",
"Rev": "9578dc67692294bb7e2a6f4b15dd18c97af19440" "Rev": "9578dc67692294bb7e2a6f4b15dd18c97af19440"
}, },
{
"ImportPath": "github.com/mattn/go-sqlite3",
"Comment": "v1.1.0-25-g2513631",
"Rev": "2513631704612107a1c8b1803fb8e6b3dda2230e"
},
{ {
"ImportPath": "github.com/mbanzon/simplehttp", "ImportPath": "github.com/mbanzon/simplehttp",
"Rev": "04c542e7ac706a25820090f274ea6a4f39a63326" "Rev": "04c542e7ac706a25820090f274ea6a4f39a63326"

View file

@ -0,0 +1,3 @@
*.db
*.exe
*.dll

View file

@ -0,0 +1,9 @@
language: go
go:
- tip
before_install:
- go get github.com/axw/gocov/gocov
- go get github.com/mattn/goveralls
- go get golang.org/x/tools/cmd/cover
script:
- $HOME/gopath/bin/goveralls -repotoken 3qJVUE0iQwqnCbmNcDsjYu1nh4J4KIFXx

View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2014 Yasuhiro Matsumoto
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,74 @@
go-sqlite3
==========
[![Build Status](https://travis-ci.org/mattn/go-sqlite3.png?branch=master)](https://travis-ci.org/mattn/go-sqlite3)
[![Coverage Status](https://coveralls.io/repos/mattn/go-sqlite3/badge.png?branch=master)](https://coveralls.io/r/mattn/go-sqlite3?branch=master)
Description
-----------
sqlite3 driver conforming to the built-in database/sql interface
Installation
------------
This package can be installed with the go get command:
go get github.com/mattn/go-sqlite3
_go-sqlite3_ is *cgo* package.
If you want to build your app using go-sqlite3, you need gcc.
However, if you install _go-sqlite3_ with `go install github.com/mattn/go-sqlite`, you don't need gcc to build your app anymore.
Documentation
-------------
API documentation can be found here: http://godoc.org/github.com/mattn/go-sqlite3
Examples can be found under the `./_example` directory
FAQ
---
* Want to build go-sqlite3 with libsqlite3 on my linux.
Use `go build --tags "libsqlite3 linux"`
* Want to build go-sqlite3 with icu extension.
Use `go build --tags "icu"`
* Can't build go-sqlite3 on windows 64bit.
> Probably, you are using go 1.0, go1.0 has a problem when it comes to compiling/linking on windows 64bit.
> See: https://github.com/mattn/go-sqlite3/issues/27
* Getting insert error while query is opened.
> You can pass some arguments into the connection string, for example, a URI.
> See: https://github.com/mattn/go-sqlite3/issues/39
* Do you want cross compiling? mingw on Linux or Mac?
> See: https://github.com/mattn/go-sqlite3/issues/106
> See also: http://www.limitlessfx.com/cross-compile-golang-app-for-windows-from-linux.html
* Want to get time.Time with current locale
Use `loc=auto` in SQLite3 filename schema like `file:foo.db?loc=auto`.
License
-------
MIT: http://mattn.mit-license.org/2012
sqlite3-binding.c, sqlite3-binding.h, sqlite3ext.h
The -binding suffix was added to avoid build failures under gccgo.
In this repository, those files are amalgamation code that copied from SQLite3. The license of those codes are depend on the license of SQLite3.
Author
------
Yasuhiro Matsumoto (a.k.a mattn)

View file

@ -0,0 +1,70 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package sqlite3
/*
#include <sqlite3-binding.h>
#include <stdlib.h>
*/
import "C"
import (
"runtime"
"unsafe"
)
type SQLiteBackup struct {
b *C.sqlite3_backup
}
func (c *SQLiteConn) Backup(dest string, conn *SQLiteConn, src string) (*SQLiteBackup, error) {
destptr := C.CString(dest)
defer C.free(unsafe.Pointer(destptr))
srcptr := C.CString(src)
defer C.free(unsafe.Pointer(srcptr))
if b := C.sqlite3_backup_init(c.db, destptr, conn.db, srcptr); b != nil {
bb := &SQLiteBackup{b: b}
runtime.SetFinalizer(bb, (*SQLiteBackup).Finish)
return bb, nil
}
return nil, c.lastError()
}
// Backs up for one step. Calls the underlying `sqlite3_backup_step` function.
// This function returns a boolean indicating if the backup is done and
// an error signalling any other error. Done is returned if the underlying C
// function returns SQLITE_DONE (Code 101)
func (b *SQLiteBackup) Step(p int) (bool, error) {
ret := C.sqlite3_backup_step(b.b, C.int(p))
if ret == C.SQLITE_DONE {
return true, nil
} else if ret != 0 && ret != C.SQLITE_LOCKED && ret != C.SQLITE_BUSY {
return false, Error{Code: ErrNo(ret)}
}
return false, nil
}
func (b *SQLiteBackup) Remaining() int {
return int(C.sqlite3_backup_remaining(b.b))
}
func (b *SQLiteBackup) PageCount() int {
return int(C.sqlite3_backup_pagecount(b.b))
}
func (b *SQLiteBackup) Finish() error {
return b.Close()
}
func (b *SQLiteBackup) Close() error {
ret := C.sqlite3_backup_finish(b.b)
if ret != 0 {
return Error{Code: ErrNo(ret)}
}
b.b = nil
runtime.SetFinalizer(b, nil)
return nil
}

View file

@ -0,0 +1,290 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package sqlite3
// You can't export a Go function to C and have definitions in the C
// preamble in the same file, so we have to have callbackTrampoline in
// its own file. Because we need a separate file anyway, the support
// code for SQLite custom functions is in here.
/*
#include <sqlite3-binding.h>
#include <stdlib.h>
void _sqlite3_result_text(sqlite3_context* ctx, const char* s);
void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l);
*/
import "C"
import (
"errors"
"fmt"
"math"
"reflect"
"unsafe"
)
//export callbackTrampoline
func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
fi := (*functionInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
fi.Call(ctx, args)
}
//export stepTrampoline
func stepTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
ai := (*aggInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
ai.Step(ctx, args)
}
//export doneTrampoline
func doneTrampoline(ctx *C.sqlite3_context) {
ai := (*aggInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
ai.Done(ctx)
}
// This is only here so that tests can refer to it.
type callbackArgRaw C.sqlite3_value
type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error)
type callbackArgCast struct {
f callbackArgConverter
typ reflect.Type
}
func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) {
val, err := c.f(v)
if err != nil {
return reflect.Value{}, err
}
if !val.Type().ConvertibleTo(c.typ) {
return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ)
}
return val.Convert(c.typ), nil
}
func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) {
if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
}
return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil
}
func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) {
if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
}
i := int64(C.sqlite3_value_int64(v))
val := false
if i != 0 {
val = true
}
return reflect.ValueOf(val), nil
}
func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) {
if C.sqlite3_value_type(v) != C.SQLITE_FLOAT {
return reflect.Value{}, fmt.Errorf("argument must be a FLOAT")
}
return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil
}
func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) {
switch C.sqlite3_value_type(v) {
case C.SQLITE_BLOB:
l := C.sqlite3_value_bytes(v)
p := C.sqlite3_value_blob(v)
return reflect.ValueOf(C.GoBytes(p, l)), nil
case C.SQLITE_TEXT:
l := C.sqlite3_value_bytes(v)
c := unsafe.Pointer(C.sqlite3_value_text(v))
return reflect.ValueOf(C.GoBytes(c, l)), nil
default:
return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
}
}
func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) {
switch C.sqlite3_value_type(v) {
case C.SQLITE_BLOB:
l := C.sqlite3_value_bytes(v)
p := (*C.char)(C.sqlite3_value_blob(v))
return reflect.ValueOf(C.GoStringN(p, l)), nil
case C.SQLITE_TEXT:
c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v)))
return reflect.ValueOf(C.GoString(c)), nil
default:
return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
}
}
func callbackArgGeneric(v *C.sqlite3_value) (reflect.Value, error) {
switch C.sqlite3_value_type(v) {
case C.SQLITE_INTEGER:
return callbackArgInt64(v)
case C.SQLITE_FLOAT:
return callbackArgFloat64(v)
case C.SQLITE_TEXT:
return callbackArgString(v)
case C.SQLITE_BLOB:
return callbackArgBytes(v)
case C.SQLITE_NULL:
// Interpret NULL as a nil byte slice.
var ret []byte
return reflect.ValueOf(ret), nil
default:
panic("unreachable")
}
}
func callbackArg(typ reflect.Type) (callbackArgConverter, error) {
switch typ.Kind() {
case reflect.Interface:
if typ.NumMethod() != 0 {
return nil, errors.New("the only supported interface type is interface{}")
}
return callbackArgGeneric, nil
case reflect.Slice:
if typ.Elem().Kind() != reflect.Uint8 {
return nil, errors.New("the only supported slice type is []byte")
}
return callbackArgBytes, nil
case reflect.String:
return callbackArgString, nil
case reflect.Bool:
return callbackArgBool, nil
case reflect.Int64:
return callbackArgInt64, nil
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
c := callbackArgCast{callbackArgInt64, typ}
return c.Run, nil
case reflect.Float64:
return callbackArgFloat64, nil
case reflect.Float32:
c := callbackArgCast{callbackArgFloat64, typ}
return c.Run, nil
default:
return nil, fmt.Errorf("don't know how to convert to %s", typ)
}
}
func callbackConvertArgs(argv []*C.sqlite3_value, converters []callbackArgConverter, variadic callbackArgConverter) ([]reflect.Value, error) {
var args []reflect.Value
if len(argv) < len(converters) {
return nil, fmt.Errorf("function requires at least %d arguments", len(converters))
}
for i, arg := range argv[:len(converters)] {
v, err := converters[i](arg)
if err != nil {
return nil, err
}
args = append(args, v)
}
if variadic != nil {
for _, arg := range argv[len(converters):] {
v, err := variadic(arg)
if err != nil {
return nil, err
}
args = append(args, v)
}
}
return args, nil
}
type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error
func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error {
switch v.Type().Kind() {
case reflect.Int64:
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
v = v.Convert(reflect.TypeOf(int64(0)))
case reflect.Bool:
b := v.Interface().(bool)
if b {
v = reflect.ValueOf(int64(1))
} else {
v = reflect.ValueOf(int64(0))
}
default:
return fmt.Errorf("cannot convert %s to INTEGER", v.Type())
}
C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64)))
return nil
}
func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error {
switch v.Type().Kind() {
case reflect.Float64:
case reflect.Float32:
v = v.Convert(reflect.TypeOf(float64(0)))
default:
return fmt.Errorf("cannot convert %s to FLOAT", v.Type())
}
C.sqlite3_result_double(ctx, C.double(v.Interface().(float64)))
return nil
}
func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error {
if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 {
return fmt.Errorf("cannot convert %s to BLOB", v.Type())
}
i := v.Interface()
if i == nil || len(i.([]byte)) == 0 {
C.sqlite3_result_null(ctx)
} else {
bs := i.([]byte)
C._sqlite3_result_blob(ctx, unsafe.Pointer(&bs[0]), C.int(len(bs)))
}
return nil
}
func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error {
if v.Type().Kind() != reflect.String {
return fmt.Errorf("cannot convert %s to TEXT", v.Type())
}
C._sqlite3_result_text(ctx, C.CString(v.Interface().(string)))
return nil
}
func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
switch typ.Kind() {
case reflect.Slice:
if typ.Elem().Kind() != reflect.Uint8 {
return nil, errors.New("the only supported slice type is []byte")
}
return callbackRetBlob, nil
case reflect.String:
return callbackRetText, nil
case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
return callbackRetInteger, nil
case reflect.Float32, reflect.Float64:
return callbackRetFloat, nil
default:
return nil, fmt.Errorf("don't know how to convert to %s", typ)
}
}
func callbackError(ctx *C.sqlite3_context, err error) {
cstr := C.CString(err.Error())
defer C.free(unsafe.Pointer(cstr))
C.sqlite3_result_error(ctx, cstr, -1)
}
// Test support code. Tests are not allowed to import "C", so we can't
// declare any functions that use C.sqlite3_value.
func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter {
return func(*C.sqlite3_value) (reflect.Value, error) {
return v, err
}
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,487 @@
/*
** 2006 June 7
**
** The author disclaims copyright to this source code. In place of
** a legal notice, here is a blessing:
**
** May you do good and not evil.
** May you find forgiveness for yourself and forgive others.
** May you share freely, never taking more than you give.
**
*************************************************************************
** This header file defines the SQLite interface for use by
** shared libraries that want to be imported as extensions into
** an SQLite instance. Shared libraries that intend to be loaded
** as extensions by SQLite should #include this file instead of
** sqlite3.h.
*/
#ifndef _SQLITE3EXT_H_
#define _SQLITE3EXT_H_
#include "sqlite3-binding.h"
typedef struct sqlite3_api_routines sqlite3_api_routines;
/*
** The following structure holds pointers to all of the SQLite API
** routines.
**
** WARNING: In order to maintain backwards compatibility, add new
** interfaces to the end of this structure only. If you insert new
** interfaces in the middle of this structure, then older different
** versions of SQLite will not be able to load each others' shared
** libraries!
*/
struct sqlite3_api_routines {
void * (*aggregate_context)(sqlite3_context*,int nBytes);
int (*aggregate_count)(sqlite3_context*);
int (*bind_blob)(sqlite3_stmt*,int,const void*,int n,void(*)(void*));
int (*bind_double)(sqlite3_stmt*,int,double);
int (*bind_int)(sqlite3_stmt*,int,int);
int (*bind_int64)(sqlite3_stmt*,int,sqlite_int64);
int (*bind_null)(sqlite3_stmt*,int);
int (*bind_parameter_count)(sqlite3_stmt*);
int (*bind_parameter_index)(sqlite3_stmt*,const char*zName);
const char * (*bind_parameter_name)(sqlite3_stmt*,int);
int (*bind_text)(sqlite3_stmt*,int,const char*,int n,void(*)(void*));
int (*bind_text16)(sqlite3_stmt*,int,const void*,int,void(*)(void*));
int (*bind_value)(sqlite3_stmt*,int,const sqlite3_value*);
int (*busy_handler)(sqlite3*,int(*)(void*,int),void*);
int (*busy_timeout)(sqlite3*,int ms);
int (*changes)(sqlite3*);
int (*close)(sqlite3*);
int (*collation_needed)(sqlite3*,void*,void(*)(void*,sqlite3*,
int eTextRep,const char*));
int (*collation_needed16)(sqlite3*,void*,void(*)(void*,sqlite3*,
int eTextRep,const void*));
const void * (*column_blob)(sqlite3_stmt*,int iCol);
int (*column_bytes)(sqlite3_stmt*,int iCol);
int (*column_bytes16)(sqlite3_stmt*,int iCol);
int (*column_count)(sqlite3_stmt*pStmt);
const char * (*column_database_name)(sqlite3_stmt*,int);
const void * (*column_database_name16)(sqlite3_stmt*,int);
const char * (*column_decltype)(sqlite3_stmt*,int i);
const void * (*column_decltype16)(sqlite3_stmt*,int);
double (*column_double)(sqlite3_stmt*,int iCol);
int (*column_int)(sqlite3_stmt*,int iCol);
sqlite_int64 (*column_int64)(sqlite3_stmt*,int iCol);
const char * (*column_name)(sqlite3_stmt*,int);
const void * (*column_name16)(sqlite3_stmt*,int);
const char * (*column_origin_name)(sqlite3_stmt*,int);
const void * (*column_origin_name16)(sqlite3_stmt*,int);
const char * (*column_table_name)(sqlite3_stmt*,int);
const void * (*column_table_name16)(sqlite3_stmt*,int);
const unsigned char * (*column_text)(sqlite3_stmt*,int iCol);
const void * (*column_text16)(sqlite3_stmt*,int iCol);
int (*column_type)(sqlite3_stmt*,int iCol);
sqlite3_value* (*column_value)(sqlite3_stmt*,int iCol);
void * (*commit_hook)(sqlite3*,int(*)(void*),void*);
int (*complete)(const char*sql);
int (*complete16)(const void*sql);
int (*create_collation)(sqlite3*,const char*,int,void*,
int(*)(void*,int,const void*,int,const void*));
int (*create_collation16)(sqlite3*,const void*,int,void*,
int(*)(void*,int,const void*,int,const void*));
int (*create_function)(sqlite3*,const char*,int,int,void*,
void (*xFunc)(sqlite3_context*,int,sqlite3_value**),
void (*xStep)(sqlite3_context*,int,sqlite3_value**),
void (*xFinal)(sqlite3_context*));
int (*create_function16)(sqlite3*,const void*,int,int,void*,
void (*xFunc)(sqlite3_context*,int,sqlite3_value**),
void (*xStep)(sqlite3_context*,int,sqlite3_value**),
void (*xFinal)(sqlite3_context*));
int (*create_module)(sqlite3*,const char*,const sqlite3_module*,void*);
int (*data_count)(sqlite3_stmt*pStmt);
sqlite3 * (*db_handle)(sqlite3_stmt*);
int (*declare_vtab)(sqlite3*,const char*);
int (*enable_shared_cache)(int);
int (*errcode)(sqlite3*db);
const char * (*errmsg)(sqlite3*);
const void * (*errmsg16)(sqlite3*);
int (*exec)(sqlite3*,const char*,sqlite3_callback,void*,char**);
int (*expired)(sqlite3_stmt*);
int (*finalize)(sqlite3_stmt*pStmt);
void (*free)(void*);
void (*free_table)(char**result);
int (*get_autocommit)(sqlite3*);
void * (*get_auxdata)(sqlite3_context*,int);
int (*get_table)(sqlite3*,const char*,char***,int*,int*,char**);
int (*global_recover)(void);
void (*interruptx)(sqlite3*);
sqlite_int64 (*last_insert_rowid)(sqlite3*);
const char * (*libversion)(void);
int (*libversion_number)(void);
void *(*malloc)(int);
char * (*mprintf)(const char*,...);
int (*open)(const char*,sqlite3**);
int (*open16)(const void*,sqlite3**);
int (*prepare)(sqlite3*,const char*,int,sqlite3_stmt**,const char**);
int (*prepare16)(sqlite3*,const void*,int,sqlite3_stmt**,const void**);
void * (*profile)(sqlite3*,void(*)(void*,const char*,sqlite_uint64),void*);
void (*progress_handler)(sqlite3*,int,int(*)(void*),void*);
void *(*realloc)(void*,int);
int (*reset)(sqlite3_stmt*pStmt);
void (*result_blob)(sqlite3_context*,const void*,int,void(*)(void*));
void (*result_double)(sqlite3_context*,double);
void (*result_error)(sqlite3_context*,const char*,int);
void (*result_error16)(sqlite3_context*,const void*,int);
void (*result_int)(sqlite3_context*,int);
void (*result_int64)(sqlite3_context*,sqlite_int64);
void (*result_null)(sqlite3_context*);
void (*result_text)(sqlite3_context*,const char*,int,void(*)(void*));
void (*result_text16)(sqlite3_context*,const void*,int,void(*)(void*));
void (*result_text16be)(sqlite3_context*,const void*,int,void(*)(void*));
void (*result_text16le)(sqlite3_context*,const void*,int,void(*)(void*));
void (*result_value)(sqlite3_context*,sqlite3_value*);
void * (*rollback_hook)(sqlite3*,void(*)(void*),void*);
int (*set_authorizer)(sqlite3*,int(*)(void*,int,const char*,const char*,
const char*,const char*),void*);
void (*set_auxdata)(sqlite3_context*,int,void*,void (*)(void*));
char * (*snprintf)(int,char*,const char*,...);
int (*step)(sqlite3_stmt*);
int (*table_column_metadata)(sqlite3*,const char*,const char*,const char*,
char const**,char const**,int*,int*,int*);
void (*thread_cleanup)(void);
int (*total_changes)(sqlite3*);
void * (*trace)(sqlite3*,void(*xTrace)(void*,const char*),void*);
int (*transfer_bindings)(sqlite3_stmt*,sqlite3_stmt*);
void * (*update_hook)(sqlite3*,void(*)(void*,int ,char const*,char const*,
sqlite_int64),void*);
void * (*user_data)(sqlite3_context*);
const void * (*value_blob)(sqlite3_value*);
int (*value_bytes)(sqlite3_value*);
int (*value_bytes16)(sqlite3_value*);
double (*value_double)(sqlite3_value*);
int (*value_int)(sqlite3_value*);
sqlite_int64 (*value_int64)(sqlite3_value*);
int (*value_numeric_type)(sqlite3_value*);
const unsigned char * (*value_text)(sqlite3_value*);
const void * (*value_text16)(sqlite3_value*);
const void * (*value_text16be)(sqlite3_value*);
const void * (*value_text16le)(sqlite3_value*);
int (*value_type)(sqlite3_value*);
char *(*vmprintf)(const char*,va_list);
/* Added ??? */
int (*overload_function)(sqlite3*, const char *zFuncName, int nArg);
/* Added by 3.3.13 */
int (*prepare_v2)(sqlite3*,const char*,int,sqlite3_stmt**,const char**);
int (*prepare16_v2)(sqlite3*,const void*,int,sqlite3_stmt**,const void**);
int (*clear_bindings)(sqlite3_stmt*);
/* Added by 3.4.1 */
int (*create_module_v2)(sqlite3*,const char*,const sqlite3_module*,void*,
void (*xDestroy)(void *));
/* Added by 3.5.0 */
int (*bind_zeroblob)(sqlite3_stmt*,int,int);
int (*blob_bytes)(sqlite3_blob*);
int (*blob_close)(sqlite3_blob*);
int (*blob_open)(sqlite3*,const char*,const char*,const char*,sqlite3_int64,
int,sqlite3_blob**);
int (*blob_read)(sqlite3_blob*,void*,int,int);
int (*blob_write)(sqlite3_blob*,const void*,int,int);
int (*create_collation_v2)(sqlite3*,const char*,int,void*,
int(*)(void*,int,const void*,int,const void*),
void(*)(void*));
int (*file_control)(sqlite3*,const char*,int,void*);
sqlite3_int64 (*memory_highwater)(int);
sqlite3_int64 (*memory_used)(void);
sqlite3_mutex *(*mutex_alloc)(int);
void (*mutex_enter)(sqlite3_mutex*);
void (*mutex_free)(sqlite3_mutex*);
void (*mutex_leave)(sqlite3_mutex*);
int (*mutex_try)(sqlite3_mutex*);
int (*open_v2)(const char*,sqlite3**,int,const char*);
int (*release_memory)(int);
void (*result_error_nomem)(sqlite3_context*);
void (*result_error_toobig)(sqlite3_context*);
int (*sleep)(int);
void (*soft_heap_limit)(int);
sqlite3_vfs *(*vfs_find)(const char*);
int (*vfs_register)(sqlite3_vfs*,int);
int (*vfs_unregister)(sqlite3_vfs*);
int (*xthreadsafe)(void);
void (*result_zeroblob)(sqlite3_context*,int);
void (*result_error_code)(sqlite3_context*,int);
int (*test_control)(int, ...);
void (*randomness)(int,void*);
sqlite3 *(*context_db_handle)(sqlite3_context*);
int (*extended_result_codes)(sqlite3*,int);
int (*limit)(sqlite3*,int,int);
sqlite3_stmt *(*next_stmt)(sqlite3*,sqlite3_stmt*);
const char *(*sql)(sqlite3_stmt*);
int (*status)(int,int*,int*,int);
int (*backup_finish)(sqlite3_backup*);
sqlite3_backup *(*backup_init)(sqlite3*,const char*,sqlite3*,const char*);
int (*backup_pagecount)(sqlite3_backup*);
int (*backup_remaining)(sqlite3_backup*);
int (*backup_step)(sqlite3_backup*,int);
const char *(*compileoption_get)(int);
int (*compileoption_used)(const char*);
int (*create_function_v2)(sqlite3*,const char*,int,int,void*,
void (*xFunc)(sqlite3_context*,int,sqlite3_value**),
void (*xStep)(sqlite3_context*,int,sqlite3_value**),
void (*xFinal)(sqlite3_context*),
void(*xDestroy)(void*));
int (*db_config)(sqlite3*,int,...);
sqlite3_mutex *(*db_mutex)(sqlite3*);
int (*db_status)(sqlite3*,int,int*,int*,int);
int (*extended_errcode)(sqlite3*);
void (*log)(int,const char*,...);
sqlite3_int64 (*soft_heap_limit64)(sqlite3_int64);
const char *(*sourceid)(void);
int (*stmt_status)(sqlite3_stmt*,int,int);
int (*strnicmp)(const char*,const char*,int);
int (*unlock_notify)(sqlite3*,void(*)(void**,int),void*);
int (*wal_autocheckpoint)(sqlite3*,int);
int (*wal_checkpoint)(sqlite3*,const char*);
void *(*wal_hook)(sqlite3*,int(*)(void*,sqlite3*,const char*,int),void*);
int (*blob_reopen)(sqlite3_blob*,sqlite3_int64);
int (*vtab_config)(sqlite3*,int op,...);
int (*vtab_on_conflict)(sqlite3*);
/* Version 3.7.16 and later */
int (*close_v2)(sqlite3*);
const char *(*db_filename)(sqlite3*,const char*);
int (*db_readonly)(sqlite3*,const char*);
int (*db_release_memory)(sqlite3*);
const char *(*errstr)(int);
int (*stmt_busy)(sqlite3_stmt*);
int (*stmt_readonly)(sqlite3_stmt*);
int (*stricmp)(const char*,const char*);
int (*uri_boolean)(const char*,const char*,int);
sqlite3_int64 (*uri_int64)(const char*,const char*,sqlite3_int64);
const char *(*uri_parameter)(const char*,const char*);
char *(*vsnprintf)(int,char*,const char*,va_list);
int (*wal_checkpoint_v2)(sqlite3*,const char*,int,int*,int*);
};
/*
** The following macros redefine the API routines so that they are
** redirected throught the global sqlite3_api structure.
**
** This header file is also used by the loadext.c source file
** (part of the main SQLite library - not an extension) so that
** it can get access to the sqlite3_api_routines structure
** definition. But the main library does not want to redefine
** the API. So the redefinition macros are only valid if the
** SQLITE_CORE macros is undefined.
*/
#ifndef SQLITE_CORE
#define sqlite3_aggregate_context sqlite3_api->aggregate_context
#ifndef SQLITE_OMIT_DEPRECATED
#define sqlite3_aggregate_count sqlite3_api->aggregate_count
#endif
#define sqlite3_bind_blob sqlite3_api->bind_blob
#define sqlite3_bind_double sqlite3_api->bind_double
#define sqlite3_bind_int sqlite3_api->bind_int
#define sqlite3_bind_int64 sqlite3_api->bind_int64
#define sqlite3_bind_null sqlite3_api->bind_null
#define sqlite3_bind_parameter_count sqlite3_api->bind_parameter_count
#define sqlite3_bind_parameter_index sqlite3_api->bind_parameter_index
#define sqlite3_bind_parameter_name sqlite3_api->bind_parameter_name
#define sqlite3_bind_text sqlite3_api->bind_text
#define sqlite3_bind_text16 sqlite3_api->bind_text16
#define sqlite3_bind_value sqlite3_api->bind_value
#define sqlite3_busy_handler sqlite3_api->busy_handler
#define sqlite3_busy_timeout sqlite3_api->busy_timeout
#define sqlite3_changes sqlite3_api->changes
#define sqlite3_close sqlite3_api->close
#define sqlite3_collation_needed sqlite3_api->collation_needed
#define sqlite3_collation_needed16 sqlite3_api->collation_needed16
#define sqlite3_column_blob sqlite3_api->column_blob
#define sqlite3_column_bytes sqlite3_api->column_bytes
#define sqlite3_column_bytes16 sqlite3_api->column_bytes16
#define sqlite3_column_count sqlite3_api->column_count
#define sqlite3_column_database_name sqlite3_api->column_database_name
#define sqlite3_column_database_name16 sqlite3_api->column_database_name16
#define sqlite3_column_decltype sqlite3_api->column_decltype
#define sqlite3_column_decltype16 sqlite3_api->column_decltype16
#define sqlite3_column_double sqlite3_api->column_double
#define sqlite3_column_int sqlite3_api->column_int
#define sqlite3_column_int64 sqlite3_api->column_int64
#define sqlite3_column_name sqlite3_api->column_name
#define sqlite3_column_name16 sqlite3_api->column_name16
#define sqlite3_column_origin_name sqlite3_api->column_origin_name
#define sqlite3_column_origin_name16 sqlite3_api->column_origin_name16
#define sqlite3_column_table_name sqlite3_api->column_table_name
#define sqlite3_column_table_name16 sqlite3_api->column_table_name16
#define sqlite3_column_text sqlite3_api->column_text
#define sqlite3_column_text16 sqlite3_api->column_text16
#define sqlite3_column_type sqlite3_api->column_type
#define sqlite3_column_value sqlite3_api->column_value
#define sqlite3_commit_hook sqlite3_api->commit_hook
#define sqlite3_complete sqlite3_api->complete
#define sqlite3_complete16 sqlite3_api->complete16
#define sqlite3_create_collation sqlite3_api->create_collation
#define sqlite3_create_collation16 sqlite3_api->create_collation16
#define sqlite3_create_function sqlite3_api->create_function
#define sqlite3_create_function16 sqlite3_api->create_function16
#define sqlite3_create_module sqlite3_api->create_module
#define sqlite3_create_module_v2 sqlite3_api->create_module_v2
#define sqlite3_data_count sqlite3_api->data_count
#define sqlite3_db_handle sqlite3_api->db_handle
#define sqlite3_declare_vtab sqlite3_api->declare_vtab
#define sqlite3_enable_shared_cache sqlite3_api->enable_shared_cache
#define sqlite3_errcode sqlite3_api->errcode
#define sqlite3_errmsg sqlite3_api->errmsg
#define sqlite3_errmsg16 sqlite3_api->errmsg16
#define sqlite3_exec sqlite3_api->exec
#ifndef SQLITE_OMIT_DEPRECATED
#define sqlite3_expired sqlite3_api->expired
#endif
#define sqlite3_finalize sqlite3_api->finalize
#define sqlite3_free sqlite3_api->free
#define sqlite3_free_table sqlite3_api->free_table
#define sqlite3_get_autocommit sqlite3_api->get_autocommit
#define sqlite3_get_auxdata sqlite3_api->get_auxdata
#define sqlite3_get_table sqlite3_api->get_table
#ifndef SQLITE_OMIT_DEPRECATED
#define sqlite3_global_recover sqlite3_api->global_recover
#endif
#define sqlite3_interrupt sqlite3_api->interruptx
#define sqlite3_last_insert_rowid sqlite3_api->last_insert_rowid
#define sqlite3_libversion sqlite3_api->libversion
#define sqlite3_libversion_number sqlite3_api->libversion_number
#define sqlite3_malloc sqlite3_api->malloc
#define sqlite3_mprintf sqlite3_api->mprintf
#define sqlite3_open sqlite3_api->open
#define sqlite3_open16 sqlite3_api->open16
#define sqlite3_prepare sqlite3_api->prepare
#define sqlite3_prepare16 sqlite3_api->prepare16
#define sqlite3_prepare_v2 sqlite3_api->prepare_v2
#define sqlite3_prepare16_v2 sqlite3_api->prepare16_v2
#define sqlite3_profile sqlite3_api->profile
#define sqlite3_progress_handler sqlite3_api->progress_handler
#define sqlite3_realloc sqlite3_api->realloc
#define sqlite3_reset sqlite3_api->reset
#define sqlite3_result_blob sqlite3_api->result_blob
#define sqlite3_result_double sqlite3_api->result_double
#define sqlite3_result_error sqlite3_api->result_error
#define sqlite3_result_error16 sqlite3_api->result_error16
#define sqlite3_result_int sqlite3_api->result_int
#define sqlite3_result_int64 sqlite3_api->result_int64
#define sqlite3_result_null sqlite3_api->result_null
#define sqlite3_result_text sqlite3_api->result_text
#define sqlite3_result_text16 sqlite3_api->result_text16
#define sqlite3_result_text16be sqlite3_api->result_text16be
#define sqlite3_result_text16le sqlite3_api->result_text16le
#define sqlite3_result_value sqlite3_api->result_value
#define sqlite3_rollback_hook sqlite3_api->rollback_hook
#define sqlite3_set_authorizer sqlite3_api->set_authorizer
#define sqlite3_set_auxdata sqlite3_api->set_auxdata
#define sqlite3_snprintf sqlite3_api->snprintf
#define sqlite3_step sqlite3_api->step
#define sqlite3_table_column_metadata sqlite3_api->table_column_metadata
#define sqlite3_thread_cleanup sqlite3_api->thread_cleanup
#define sqlite3_total_changes sqlite3_api->total_changes
#define sqlite3_trace sqlite3_api->trace
#ifndef SQLITE_OMIT_DEPRECATED
#define sqlite3_transfer_bindings sqlite3_api->transfer_bindings
#endif
#define sqlite3_update_hook sqlite3_api->update_hook
#define sqlite3_user_data sqlite3_api->user_data
#define sqlite3_value_blob sqlite3_api->value_blob
#define sqlite3_value_bytes sqlite3_api->value_bytes
#define sqlite3_value_bytes16 sqlite3_api->value_bytes16
#define sqlite3_value_double sqlite3_api->value_double
#define sqlite3_value_int sqlite3_api->value_int
#define sqlite3_value_int64 sqlite3_api->value_int64
#define sqlite3_value_numeric_type sqlite3_api->value_numeric_type
#define sqlite3_value_text sqlite3_api->value_text
#define sqlite3_value_text16 sqlite3_api->value_text16
#define sqlite3_value_text16be sqlite3_api->value_text16be
#define sqlite3_value_text16le sqlite3_api->value_text16le
#define sqlite3_value_type sqlite3_api->value_type
#define sqlite3_vmprintf sqlite3_api->vmprintf
#define sqlite3_overload_function sqlite3_api->overload_function
#define sqlite3_prepare_v2 sqlite3_api->prepare_v2
#define sqlite3_prepare16_v2 sqlite3_api->prepare16_v2
#define sqlite3_clear_bindings sqlite3_api->clear_bindings
#define sqlite3_bind_zeroblob sqlite3_api->bind_zeroblob
#define sqlite3_blob_bytes sqlite3_api->blob_bytes
#define sqlite3_blob_close sqlite3_api->blob_close
#define sqlite3_blob_open sqlite3_api->blob_open
#define sqlite3_blob_read sqlite3_api->blob_read
#define sqlite3_blob_write sqlite3_api->blob_write
#define sqlite3_create_collation_v2 sqlite3_api->create_collation_v2
#define sqlite3_file_control sqlite3_api->file_control
#define sqlite3_memory_highwater sqlite3_api->memory_highwater
#define sqlite3_memory_used sqlite3_api->memory_used
#define sqlite3_mutex_alloc sqlite3_api->mutex_alloc
#define sqlite3_mutex_enter sqlite3_api->mutex_enter
#define sqlite3_mutex_free sqlite3_api->mutex_free
#define sqlite3_mutex_leave sqlite3_api->mutex_leave
#define sqlite3_mutex_try sqlite3_api->mutex_try
#define sqlite3_open_v2 sqlite3_api->open_v2
#define sqlite3_release_memory sqlite3_api->release_memory
#define sqlite3_result_error_nomem sqlite3_api->result_error_nomem
#define sqlite3_result_error_toobig sqlite3_api->result_error_toobig
#define sqlite3_sleep sqlite3_api->sleep
#define sqlite3_soft_heap_limit sqlite3_api->soft_heap_limit
#define sqlite3_vfs_find sqlite3_api->vfs_find
#define sqlite3_vfs_register sqlite3_api->vfs_register
#define sqlite3_vfs_unregister sqlite3_api->vfs_unregister
#define sqlite3_threadsafe sqlite3_api->xthreadsafe
#define sqlite3_result_zeroblob sqlite3_api->result_zeroblob
#define sqlite3_result_error_code sqlite3_api->result_error_code
#define sqlite3_test_control sqlite3_api->test_control
#define sqlite3_randomness sqlite3_api->randomness
#define sqlite3_context_db_handle sqlite3_api->context_db_handle
#define sqlite3_extended_result_codes sqlite3_api->extended_result_codes
#define sqlite3_limit sqlite3_api->limit
#define sqlite3_next_stmt sqlite3_api->next_stmt
#define sqlite3_sql sqlite3_api->sql
#define sqlite3_status sqlite3_api->status
#define sqlite3_backup_finish sqlite3_api->backup_finish
#define sqlite3_backup_init sqlite3_api->backup_init
#define sqlite3_backup_pagecount sqlite3_api->backup_pagecount
#define sqlite3_backup_remaining sqlite3_api->backup_remaining
#define sqlite3_backup_step sqlite3_api->backup_step
#define sqlite3_compileoption_get sqlite3_api->compileoption_get
#define sqlite3_compileoption_used sqlite3_api->compileoption_used
#define sqlite3_create_function_v2 sqlite3_api->create_function_v2
#define sqlite3_db_config sqlite3_api->db_config
#define sqlite3_db_mutex sqlite3_api->db_mutex
#define sqlite3_db_status sqlite3_api->db_status
#define sqlite3_extended_errcode sqlite3_api->extended_errcode
#define sqlite3_log sqlite3_api->log
#define sqlite3_soft_heap_limit64 sqlite3_api->soft_heap_limit64
#define sqlite3_sourceid sqlite3_api->sourceid
#define sqlite3_stmt_status sqlite3_api->stmt_status
#define sqlite3_strnicmp sqlite3_api->strnicmp
#define sqlite3_unlock_notify sqlite3_api->unlock_notify
#define sqlite3_wal_autocheckpoint sqlite3_api->wal_autocheckpoint
#define sqlite3_wal_checkpoint sqlite3_api->wal_checkpoint
#define sqlite3_wal_hook sqlite3_api->wal_hook
#define sqlite3_blob_reopen sqlite3_api->blob_reopen
#define sqlite3_vtab_config sqlite3_api->vtab_config
#define sqlite3_vtab_on_conflict sqlite3_api->vtab_on_conflict
/* Version 3.7.16 and later */
#define sqlite3_close_v2 sqlite3_api->close_v2
#define sqlite3_db_filename sqlite3_api->db_filename
#define sqlite3_db_readonly sqlite3_api->db_readonly
#define sqlite3_db_release_memory sqlite3_api->db_release_memory
#define sqlite3_errstr sqlite3_api->errstr
#define sqlite3_stmt_busy sqlite3_api->stmt_busy
#define sqlite3_stmt_readonly sqlite3_api->stmt_readonly
#define sqlite3_stricmp sqlite3_api->stricmp
#define sqlite3_uri_boolean sqlite3_api->uri_boolean
#define sqlite3_uri_int64 sqlite3_api->uri_int64
#define sqlite3_uri_parameter sqlite3_api->uri_parameter
#define sqlite3_uri_vsnprintf sqlite3_api->vsnprintf
#define sqlite3_wal_checkpoint_v2 sqlite3_api->wal_checkpoint_v2
#endif /* SQLITE_CORE */
#ifndef SQLITE_CORE
/* This case when the file really is being compiled as a loadable
** extension */
# define SQLITE_EXTENSION_INIT1 const sqlite3_api_routines *sqlite3_api=0;
# define SQLITE_EXTENSION_INIT2(v) sqlite3_api=v;
# define SQLITE_EXTENSION_INIT3 \
extern const sqlite3_api_routines *sqlite3_api;
#else
/* This case when the file is being statically linked into the
** application */
# define SQLITE_EXTENSION_INIT1 /*no-op*/
# define SQLITE_EXTENSION_INIT2(v) (void)v; /* unused parameter */
# define SQLITE_EXTENSION_INIT3 /*no-op*/
#endif
#endif /* _SQLITE3EXT_H_ */

View file

@ -0,0 +1,112 @@
/*
Package sqlite3 provides interface to SQLite3 databases.
This works as driver for database/sql.
Installation
go get github.com/mattn/go-sqlite3
Supported Types
Currently, go-sqlite3 support following data types.
+------------------------------+
|go | sqlite3 |
|----------|-------------------|
|nil | null |
|int | integer |
|int64 | integer |
|float64 | float |
|bool | integer |
|[]byte | blob |
|string | text |
|time.Time | timestamp/datetime|
+------------------------------+
SQLite3 Extension
You can write your own extension module for sqlite3. For example, below is a
extension for Regexp matcher operation.
#include <pcre.h>
#include <string.h>
#include <stdio.h>
#include <sqlite3ext.h>
SQLITE_EXTENSION_INIT1
static void regexp_func(sqlite3_context *context, int argc, sqlite3_value **argv) {
if (argc >= 2) {
const char *target = (const char *)sqlite3_value_text(argv[1]);
const char *pattern = (const char *)sqlite3_value_text(argv[0]);
const char* errstr = NULL;
int erroff = 0;
int vec[500];
int n, rc;
pcre* re = pcre_compile(pattern, 0, &errstr, &erroff, NULL);
rc = pcre_exec(re, NULL, target, strlen(target), 0, 0, vec, 500);
if (rc <= 0) {
sqlite3_result_error(context, errstr, 0);
return;
}
sqlite3_result_int(context, 1);
}
}
#ifdef _WIN32
__declspec(dllexport)
#endif
int sqlite3_extension_init(sqlite3 *db, char **errmsg,
const sqlite3_api_routines *api) {
SQLITE_EXTENSION_INIT2(api);
return sqlite3_create_function(db, "regexp", 2, SQLITE_UTF8,
(void*)db, regexp_func, NULL, NULL);
}
It need to build as so/dll shared library. And you need to register
extension module like below.
sql.Register("sqlite3_with_extensions",
&sqlite3.SQLiteDriver{
Extensions: []string{
"sqlite3_mod_regexp",
},
})
Then, you can use this extension.
rows, err := db.Query("select text from mytable where name regexp '^golang'")
Connection Hook
You can hook and inject your codes when connection established. database/sql
doesn't provide the way to get native go-sqlite3 interfaces. So if you want,
you need to hook ConnectHook and get the SQLiteConn.
sql.Register("sqlite3_with_hook_example",
&sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
sqlite3conn = append(sqlite3conn, conn)
return nil
},
})
Go SQlite3 Extensions
If you want to register Go functions as SQLite extension functions,
call RegisterFunction from ConnectHook.
regex = func(re, s string) (bool, error) {
return regexp.MatchString(re, s)
}
sql.Register("sqlite3_with_go_func",
&sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
return conn.RegisterFunc("regex", regex, true)
},
})
See the documentation of RegisterFunc for more details.
*/
package sqlite3

View file

@ -0,0 +1,128 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package sqlite3
import "C"
type ErrNo int
const ErrNoMask C.int = 0xff
type ErrNoExtended int
type Error struct {
Code ErrNo /* The error code returned by SQLite */
ExtendedCode ErrNoExtended /* The extended error code returned by SQLite */
err string /* The error string returned by sqlite3_errmsg(),
this usually contains more specific details. */
}
// result codes from http://www.sqlite.org/c3ref/c_abort.html
var (
ErrError = ErrNo(1) /* SQL error or missing database */
ErrInternal = ErrNo(2) /* Internal logic error in SQLite */
ErrPerm = ErrNo(3) /* Access permission denied */
ErrAbort = ErrNo(4) /* Callback routine requested an abort */
ErrBusy = ErrNo(5) /* The database file is locked */
ErrLocked = ErrNo(6) /* A table in the database is locked */
ErrNomem = ErrNo(7) /* A malloc() failed */
ErrReadonly = ErrNo(8) /* Attempt to write a readonly database */
ErrInterrupt = ErrNo(9) /* Operation terminated by sqlite3_interrupt() */
ErrIoErr = ErrNo(10) /* Some kind of disk I/O error occurred */
ErrCorrupt = ErrNo(11) /* The database disk image is malformed */
ErrNotFound = ErrNo(12) /* Unknown opcode in sqlite3_file_control() */
ErrFull = ErrNo(13) /* Insertion failed because database is full */
ErrCantOpen = ErrNo(14) /* Unable to open the database file */
ErrProtocol = ErrNo(15) /* Database lock protocol error */
ErrEmpty = ErrNo(16) /* Database is empty */
ErrSchema = ErrNo(17) /* The database schema changed */
ErrTooBig = ErrNo(18) /* String or BLOB exceeds size limit */
ErrConstraint = ErrNo(19) /* Abort due to constraint violation */
ErrMismatch = ErrNo(20) /* Data type mismatch */
ErrMisuse = ErrNo(21) /* Library used incorrectly */
ErrNoLFS = ErrNo(22) /* Uses OS features not supported on host */
ErrAuth = ErrNo(23) /* Authorization denied */
ErrFormat = ErrNo(24) /* Auxiliary database format error */
ErrRange = ErrNo(25) /* 2nd parameter to sqlite3_bind out of range */
ErrNotADB = ErrNo(26) /* File opened that is not a database file */
ErrNotice = ErrNo(27) /* Notifications from sqlite3_log() */
ErrWarning = ErrNo(28) /* Warnings from sqlite3_log() */
)
func (err ErrNo) Error() string {
return Error{Code: err}.Error()
}
func (err ErrNo) Extend(by int) ErrNoExtended {
return ErrNoExtended(int(err) | (by << 8))
}
func (err ErrNoExtended) Error() string {
return Error{Code: ErrNo(C.int(err) & ErrNoMask), ExtendedCode: err}.Error()
}
func (err Error) Error() string {
if err.err != "" {
return err.err
}
return errorString(err)
}
// result codes from http://www.sqlite.org/c3ref/c_abort_rollback.html
var (
ErrIoErrRead = ErrIoErr.Extend(1)
ErrIoErrShortRead = ErrIoErr.Extend(2)
ErrIoErrWrite = ErrIoErr.Extend(3)
ErrIoErrFsync = ErrIoErr.Extend(4)
ErrIoErrDirFsync = ErrIoErr.Extend(5)
ErrIoErrTruncate = ErrIoErr.Extend(6)
ErrIoErrFstat = ErrIoErr.Extend(7)
ErrIoErrUnlock = ErrIoErr.Extend(8)
ErrIoErrRDlock = ErrIoErr.Extend(9)
ErrIoErrDelete = ErrIoErr.Extend(10)
ErrIoErrBlocked = ErrIoErr.Extend(11)
ErrIoErrNoMem = ErrIoErr.Extend(12)
ErrIoErrAccess = ErrIoErr.Extend(13)
ErrIoErrCheckReservedLock = ErrIoErr.Extend(14)
ErrIoErrLock = ErrIoErr.Extend(15)
ErrIoErrClose = ErrIoErr.Extend(16)
ErrIoErrDirClose = ErrIoErr.Extend(17)
ErrIoErrSHMOpen = ErrIoErr.Extend(18)
ErrIoErrSHMSize = ErrIoErr.Extend(19)
ErrIoErrSHMLock = ErrIoErr.Extend(20)
ErrIoErrSHMMap = ErrIoErr.Extend(21)
ErrIoErrSeek = ErrIoErr.Extend(22)
ErrIoErrDeleteNoent = ErrIoErr.Extend(23)
ErrIoErrMMap = ErrIoErr.Extend(24)
ErrIoErrGetTempPath = ErrIoErr.Extend(25)
ErrIoErrConvPath = ErrIoErr.Extend(26)
ErrLockedSharedCache = ErrLocked.Extend(1)
ErrBusyRecovery = ErrBusy.Extend(1)
ErrBusySnapshot = ErrBusy.Extend(2)
ErrCantOpenNoTempDir = ErrCantOpen.Extend(1)
ErrCantOpenIsDir = ErrCantOpen.Extend(2)
ErrCantOpenFullPath = ErrCantOpen.Extend(3)
ErrCantOpenConvPath = ErrCantOpen.Extend(4)
ErrCorruptVTab = ErrCorrupt.Extend(1)
ErrReadonlyRecovery = ErrReadonly.Extend(1)
ErrReadonlyCantLock = ErrReadonly.Extend(2)
ErrReadonlyRollback = ErrReadonly.Extend(3)
ErrReadonlyDbMoved = ErrReadonly.Extend(4)
ErrAbortRollback = ErrAbort.Extend(2)
ErrConstraintCheck = ErrConstraint.Extend(1)
ErrConstraintCommitHook = ErrConstraint.Extend(2)
ErrConstraintForeignKey = ErrConstraint.Extend(3)
ErrConstraintFunction = ErrConstraint.Extend(4)
ErrConstraintNotNull = ErrConstraint.Extend(5)
ErrConstraintPrimaryKey = ErrConstraint.Extend(6)
ErrConstraintTrigger = ErrConstraint.Extend(7)
ErrConstraintUnique = ErrConstraint.Extend(8)
ErrConstraintVTab = ErrConstraint.Extend(9)
ErrConstraintRowId = ErrConstraint.Extend(10)
ErrNoticeRecoverWAL = ErrNotice.Extend(1)
ErrNoticeRecoverRollback = ErrNotice.Extend(2)
ErrWarningAutoIndex = ErrWarning.Extend(1)
)

View file

@ -0,0 +1,4 @@
#ifndef USE_LIBSQLITE3
# include "code/sqlite3-binding.c"
#endif

View file

@ -0,0 +1,5 @@
#ifndef USE_LIBSQLITE3
#include "code/sqlite3-binding.h"
#else
#include <sqlite3.h>
#endif

View file

@ -0,0 +1,977 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package sqlite3
/*
#cgo CFLAGS: -std=gnu99
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE
#cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_FTS4_UNICODE61
#include <sqlite3-binding.h>
#include <stdlib.h>
#include <string.h>
#ifdef __CYGWIN__
# include <errno.h>
#endif
#ifndef SQLITE_OPEN_READWRITE
# define SQLITE_OPEN_READWRITE 0
#endif
#ifndef SQLITE_OPEN_FULLMUTEX
# define SQLITE_OPEN_FULLMUTEX 0
#endif
static int
_sqlite3_open_v2(const char *filename, sqlite3 **ppDb, int flags, const char *zVfs) {
#ifdef SQLITE_OPEN_URI
return sqlite3_open_v2(filename, ppDb, flags | SQLITE_OPEN_URI, zVfs);
#else
return sqlite3_open_v2(filename, ppDb, flags, zVfs);
#endif
}
static int
_sqlite3_bind_text(sqlite3_stmt *stmt, int n, char *p, int np) {
return sqlite3_bind_text(stmt, n, p, np, SQLITE_TRANSIENT);
}
static int
_sqlite3_bind_blob(sqlite3_stmt *stmt, int n, void *p, int np) {
return sqlite3_bind_blob(stmt, n, p, np, SQLITE_TRANSIENT);
}
#include <stdio.h>
#include <stdint.h>
static int
_sqlite3_exec(sqlite3* db, const char* pcmd, long long* rowid, long long* changes)
{
int rv = sqlite3_exec(db, pcmd, 0, 0, 0);
*rowid = (long long) sqlite3_last_insert_rowid(db);
*changes = (long long) sqlite3_changes(db);
return rv;
}
static int
_sqlite3_step(sqlite3_stmt* stmt, long long* rowid, long long* changes)
{
int rv = sqlite3_step(stmt);
sqlite3* db = sqlite3_db_handle(stmt);
*rowid = (long long) sqlite3_last_insert_rowid(db);
*changes = (long long) sqlite3_changes(db);
return rv;
}
void _sqlite3_result_text(sqlite3_context* ctx, const char* s) {
sqlite3_result_text(ctx, s, -1, &free);
}
void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l) {
sqlite3_result_blob(ctx, b, l, SQLITE_TRANSIENT);
}
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
void doneTrampoline(sqlite3_context*);
*/
import "C"
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io"
"net/url"
"reflect"
"runtime"
"strconv"
"strings"
"time"
"unsafe"
)
// Timestamp formats understood by both this module and SQLite.
// The first format in the slice will be used when saving time values
// into the database. When parsing a string from a timestamp or
// datetime column, the formats are tried in order.
var SQLiteTimestampFormats = []string{
// By default, store timestamps with whatever timezone they come with.
// When parsed, they will be returned with the same timezone.
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02T15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999",
"2006-01-02 15:04:05",
"2006-01-02T15:04:05",
"2006-01-02 15:04",
"2006-01-02T15:04",
"2006-01-02",
}
func init() {
sql.Register("sqlite3", &SQLiteDriver{})
}
// Return SQLite library Version information.
func Version() (libVersion string, libVersionNumber int, sourceId string) {
libVersion = C.GoString(C.sqlite3_libversion())
libVersionNumber = int(C.sqlite3_libversion_number())
sourceId = C.GoString(C.sqlite3_sourceid())
return libVersion, libVersionNumber, sourceId
}
// Driver struct.
type SQLiteDriver struct {
Extensions []string
ConnectHook func(*SQLiteConn) error
}
// Conn struct.
type SQLiteConn struct {
db *C.sqlite3
loc *time.Location
txlock string
funcs []*functionInfo
aggregators []*aggInfo
}
// Tx struct.
type SQLiteTx struct {
c *SQLiteConn
}
// Stmt struct.
type SQLiteStmt struct {
c *SQLiteConn
s *C.sqlite3_stmt
nv int
nn []string
t string
closed bool
cls bool
}
// Result struct.
type SQLiteResult struct {
id int64
changes int64
}
// Rows struct.
type SQLiteRows struct {
s *SQLiteStmt
nc int
cols []string
decltype []string
cls bool
}
type functionInfo struct {
f reflect.Value
argConverters []callbackArgConverter
variadicConverter callbackArgConverter
retConverter callbackRetConverter
}
func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
args, err := callbackConvertArgs(argv, fi.argConverters, fi.variadicConverter)
if err != nil {
callbackError(ctx, err)
return
}
ret := fi.f.Call(args)
if len(ret) == 2 && ret[1].Interface() != nil {
callbackError(ctx, ret[1].Interface().(error))
return
}
err = fi.retConverter(ctx, ret[0])
if err != nil {
callbackError(ctx, err)
return
}
}
type aggInfo struct {
constructor reflect.Value
// Active aggregator objects for aggregations in flight. The
// aggregators are indexed by a counter stored in the aggregation
// user data space provided by sqlite.
active map[int64]reflect.Value
next int64
stepArgConverters []callbackArgConverter
stepVariadicConverter callbackArgConverter
doneRetConverter callbackRetConverter
}
func (ai *aggInfo) agg(ctx *C.sqlite3_context) (int64, reflect.Value, error) {
aggIdx := (*int64)(C.sqlite3_aggregate_context(ctx, C.int(8)))
if *aggIdx == 0 {
*aggIdx = ai.next
ret := ai.constructor.Call(nil)
if len(ret) == 2 && ret[1].Interface() != nil {
return 0, reflect.Value{}, ret[1].Interface().(error)
}
if ret[0].IsNil() {
return 0, reflect.Value{}, errors.New("aggregator constructor returned nil state")
}
ai.next++
ai.active[*aggIdx] = ret[0]
}
return *aggIdx, ai.active[*aggIdx], nil
}
func (ai *aggInfo) Step(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
_, agg, err := ai.agg(ctx)
if err != nil {
callbackError(ctx, err)
return
}
args, err := callbackConvertArgs(argv, ai.stepArgConverters, ai.stepVariadicConverter)
if err != nil {
callbackError(ctx, err)
return
}
ret := agg.MethodByName("Step").Call(args)
if len(ret) == 1 && ret[0].Interface() != nil {
callbackError(ctx, ret[0].Interface().(error))
return
}
}
func (ai *aggInfo) Done(ctx *C.sqlite3_context) {
idx, agg, err := ai.agg(ctx)
if err != nil {
callbackError(ctx, err)
return
}
defer func() { delete(ai.active, idx) }()
ret := agg.MethodByName("Done").Call(nil)
if len(ret) == 2 && ret[1].Interface() != nil {
callbackError(ctx, ret[1].Interface().(error))
return
}
err = ai.doneRetConverter(ctx, ret[0])
if err != nil {
callbackError(ctx, err)
return
}
}
// Commit transaction.
func (tx *SQLiteTx) Commit() error {
_, err := tx.c.exec("COMMIT")
return err
}
// Rollback transaction.
func (tx *SQLiteTx) Rollback() error {
_, err := tx.c.exec("ROLLBACK")
return err
}
// RegisterFunc makes a Go function available as a SQLite function.
//
// The Go function can have arguments of the following types: any
// numeric type except complex, bool, []byte, string and
// interface{}. interface{} arguments are given the direct translation
// of the SQLite data type: int64 for INTEGER, float64 for FLOAT,
// []byte for BLOB, string for TEXT.
//
// The function can additionally be variadic, as long as the type of
// the variadic argument is one of the above.
//
// If pure is true. SQLite will assume that the function's return
// value depends only on its inputs, and make more aggressive
// optimizations in its queries.
//
// See _example/go_custom_funcs for a detailed example.
func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) error {
var fi functionInfo
fi.f = reflect.ValueOf(impl)
t := fi.f.Type()
if t.Kind() != reflect.Func {
return errors.New("Non-function passed to RegisterFunc")
}
if t.NumOut() != 1 && t.NumOut() != 2 {
return errors.New("SQLite functions must return 1 or 2 values")
}
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("Second return value of SQLite function must be error")
}
numArgs := t.NumIn()
if t.IsVariadic() {
numArgs--
}
for i := 0; i < numArgs; i++ {
conv, err := callbackArg(t.In(i))
if err != nil {
return err
}
fi.argConverters = append(fi.argConverters, conv)
}
if t.IsVariadic() {
conv, err := callbackArg(t.In(numArgs).Elem())
if err != nil {
return err
}
fi.variadicConverter = conv
// Pass -1 to sqlite so that it allows any number of
// arguments. The call helper verifies that the minimum number
// of arguments is present for variadic functions.
numArgs = -1
}
conv, err := callbackRet(t.Out(0))
if err != nil {
return err
}
fi.retConverter = conv
// fi must outlast the database connection, or we'll have dangling pointers.
c.funcs = append(c.funcs, &fi)
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
opts := C.SQLITE_UTF8
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
rv := C.sqlite3_create_function(c.db, cname, C.int(numArgs), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil)
if rv != C.SQLITE_OK {
return c.lastError()
}
return nil
}
// RegisterAggregator makes a Go type available as a SQLite aggregation function.
//
// Because aggregation is incremental, it's implemented in Go with a
// type that has 2 methods: func Step(values) accumulates one row of
// data into the accumulator, and func Done() ret finalizes and
// returns the aggregate value. "values" and "ret" may be any type
// supported by RegisterFunc.
//
// RegisterAggregator takes as implementation a constructor function
// that constructs an instance of the aggregator type each time an
// aggregation begins. The constructor must return a pointer to a
// type, or an interface that implements Step() and Done().
//
// The constructor function and the Step/Done methods may optionally
// return an error in addition to their other return values.
//
// See _example/go_custom_funcs for a detailed example.
func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error {
var ai aggInfo
ai.constructor = reflect.ValueOf(impl)
t := ai.constructor.Type()
if t.Kind() != reflect.Func {
return errors.New("non-function passed to RegisterAggregator")
}
if t.NumOut() != 1 && t.NumOut() != 2 {
return errors.New("SQLite aggregator constructors must return 1 or 2 values")
}
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("Second return value of SQLite function must be error")
}
if t.NumIn() != 0 {
return errors.New("SQLite aggregator constructors must not have arguments")
}
agg := t.Out(0)
switch agg.Kind() {
case reflect.Ptr, reflect.Interface:
default:
return errors.New("SQlite aggregator constructor must return a pointer object")
}
stepFn, found := agg.MethodByName("Step")
if !found {
return errors.New("SQlite aggregator doesn't have a Step() function")
}
step := stepFn.Type
if step.NumOut() != 0 && step.NumOut() != 1 {
return errors.New("SQlite aggregator Step() function must return 0 or 1 values")
}
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("type of SQlite aggregator Step() return value must be error")
}
stepNArgs := step.NumIn()
start := 0
if agg.Kind() == reflect.Ptr {
// Skip over the method receiver
stepNArgs--
start++
}
if step.IsVariadic() {
stepNArgs--
}
for i := start; i < start+stepNArgs; i++ {
conv, err := callbackArg(step.In(i))
if err != nil {
return err
}
ai.stepArgConverters = append(ai.stepArgConverters, conv)
}
if step.IsVariadic() {
conv, err := callbackArg(t.In(start + stepNArgs).Elem())
if err != nil {
return err
}
ai.stepVariadicConverter = conv
// Pass -1 to sqlite so that it allows any number of
// arguments. The call helper verifies that the minimum number
// of arguments is present for variadic functions.
stepNArgs = -1
}
doneFn, found := agg.MethodByName("Done")
if !found {
return errors.New("SQlite aggregator doesn't have a Done() function")
}
done := doneFn.Type
doneNArgs := done.NumIn()
if agg.Kind() == reflect.Ptr {
// Skip over the method receiver
doneNArgs--
}
if doneNArgs != 0 {
return errors.New("SQlite aggregator Done() function must have no arguments")
}
if done.NumOut() != 1 && done.NumOut() != 2 {
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
}
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("second return value of SQLite aggregator Done() function must be error")
}
conv, err := callbackRet(done.Out(0))
if err != nil {
return err
}
ai.doneRetConverter = conv
ai.active = make(map[int64]reflect.Value)
ai.next = 1
// ai must outlast the database connection, or we'll have dangling pointers.
c.aggregators = append(c.aggregators, &ai)
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
opts := C.SQLITE_UTF8
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
rv := C.sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), unsafe.Pointer(&ai), nil, (*[0]byte)(unsafe.Pointer(C.stepTrampoline)), (*[0]byte)(unsafe.Pointer(C.doneTrampoline)))
if rv != C.SQLITE_OK {
return c.lastError()
}
return nil
}
// AutoCommit return which currently auto commit or not.
func (c *SQLiteConn) AutoCommit() bool {
return int(C.sqlite3_get_autocommit(c.db)) != 0
}
func (c *SQLiteConn) lastError() Error {
return Error{
Code: ErrNo(C.sqlite3_errcode(c.db)),
ExtendedCode: ErrNoExtended(C.sqlite3_extended_errcode(c.db)),
err: C.GoString(C.sqlite3_errmsg(c.db)),
}
}
// Implements Execer
func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if len(args) == 0 {
return c.exec(query)
}
for {
s, err := c.Prepare(query)
if err != nil {
return nil, err
}
var res driver.Result
if s.(*SQLiteStmt).s != nil {
na := s.NumInput()
if len(args) < na {
return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args))
}
res, err = s.Exec(args[:na])
if err != nil && err != driver.ErrSkip {
s.Close()
return nil, err
}
args = args[na:]
}
tail := s.(*SQLiteStmt).t
s.Close()
if tail == "" {
return res, nil
}
query = tail
}
}
// Implements Queryer
func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) {
for {
s, err := c.Prepare(query)
if err != nil {
return nil, err
}
s.(*SQLiteStmt).cls = true
na := s.NumInput()
if len(args) < na {
return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args))
}
rows, err := s.Query(args[:na])
if err != nil && err != driver.ErrSkip {
s.Close()
return nil, err
}
args = args[na:]
tail := s.(*SQLiteStmt).t
if tail == "" {
return rows, nil
}
rows.Close()
s.Close()
query = tail
}
}
func (c *SQLiteConn) exec(cmd string) (driver.Result, error) {
pcmd := C.CString(cmd)
defer C.free(unsafe.Pointer(pcmd))
var rowid, changes C.longlong
rv := C._sqlite3_exec(c.db, pcmd, &rowid, &changes)
if rv != C.SQLITE_OK {
return nil, c.lastError()
}
return &SQLiteResult{int64(rowid), int64(changes)}, nil
}
// Begin transaction.
func (c *SQLiteConn) Begin() (driver.Tx, error) {
if _, err := c.exec(c.txlock); err != nil {
return nil, err
}
return &SQLiteTx{c}, nil
}
func errorString(err Error) string {
return C.GoString(C.sqlite3_errstr(C.int(err.Code)))
}
// Open database and return a new connection.
// You can specify DSN string with URI filename.
// test.db
// file:test.db?cache=shared&mode=memory
// :memory:
// file::memory:
// go-sqlite3 adds the following query parameters to those used by SQLite:
// _loc=XXX
// Specify location of time format. It's possible to specify "auto".
// _busy_timeout=XXX
// Specify value for sqlite3_busy_timeout.
// _txlock=XXX
// Specify locking behavior for transactions. XXX can be "immediate",
// "deferred", "exclusive".
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if C.sqlite3_threadsafe() == 0 {
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
}
var loc *time.Location
txlock := "BEGIN"
busy_timeout := 5000
pos := strings.IndexRune(dsn, '?')
if pos >= 1 {
params, err := url.ParseQuery(dsn[pos+1:])
if err != nil {
return nil, err
}
// _loc
if val := params.Get("_loc"); val != "" {
if val == "auto" {
loc = time.Local
} else {
loc, err = time.LoadLocation(val)
if err != nil {
return nil, fmt.Errorf("Invalid _loc: %v: %v", val, err)
}
}
}
// _busy_timeout
if val := params.Get("_busy_timeout"); val != "" {
iv, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("Invalid _busy_timeout: %v: %v", val, err)
}
busy_timeout = int(iv)
}
// _txlock
if val := params.Get("_txlock"); val != "" {
switch val {
case "immediate":
txlock = "BEGIN IMMEDIATE"
case "exclusive":
txlock = "BEGIN EXCLUSIVE"
case "deferred":
txlock = "BEGIN"
default:
return nil, fmt.Errorf("Invalid _txlock: %v", val)
}
}
if !strings.HasPrefix(dsn, "file:") {
dsn = dsn[:pos]
}
}
var db *C.sqlite3
name := C.CString(dsn)
defer C.free(unsafe.Pointer(name))
rv := C._sqlite3_open_v2(name, &db,
C.SQLITE_OPEN_FULLMUTEX|
C.SQLITE_OPEN_READWRITE|
C.SQLITE_OPEN_CREATE,
nil)
if rv != 0 {
return nil, Error{Code: ErrNo(rv)}
}
if db == nil {
return nil, errors.New("sqlite succeeded without returning a database")
}
rv = C.sqlite3_busy_timeout(db, C.int(busy_timeout))
if rv != C.SQLITE_OK {
return nil, Error{Code: ErrNo(rv)}
}
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
if len(d.Extensions) > 0 {
if err := conn.loadExtensions(d.Extensions); err != nil {
return nil, err
}
}
if d.ConnectHook != nil {
if err := d.ConnectHook(conn); err != nil {
return nil, err
}
}
runtime.SetFinalizer(conn, (*SQLiteConn).Close)
return conn, nil
}
// Close the connection.
func (c *SQLiteConn) Close() error {
rv := C.sqlite3_close_v2(c.db)
if rv != C.SQLITE_OK {
return c.lastError()
}
c.db = nil
runtime.SetFinalizer(c, nil)
return nil
}
// Prepare query string. Return a new statement.
func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
pquery := C.CString(query)
defer C.free(unsafe.Pointer(pquery))
var s *C.sqlite3_stmt
var tail *C.char
rv := C.sqlite3_prepare_v2(c.db, pquery, -1, &s, &tail)
if rv != C.SQLITE_OK {
return nil, c.lastError()
}
var t string
if tail != nil && *tail != '\000' {
t = strings.TrimSpace(C.GoString(tail))
}
nv := int(C.sqlite3_bind_parameter_count(s))
var nn []string
for i := 0; i < nv; i++ {
pn := C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1)))
if len(pn) > 1 && pn[0] == '$' && 48 <= pn[1] && pn[1] <= 57 {
nn = append(nn, C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1))))
}
}
ss := &SQLiteStmt{c: c, s: s, nv: nv, nn: nn, t: t}
runtime.SetFinalizer(ss, (*SQLiteStmt).Close)
return ss, nil
}
// Close the statement.
func (s *SQLiteStmt) Close() error {
if s.closed {
return nil
}
s.closed = true
if s.c == nil || s.c.db == nil {
return errors.New("sqlite statement with already closed database connection")
}
rv := C.sqlite3_finalize(s.s)
if rv != C.SQLITE_OK {
return s.c.lastError()
}
runtime.SetFinalizer(s, nil)
return nil
}
// Return a number of parameters.
func (s *SQLiteStmt) NumInput() int {
return s.nv
}
type bindArg struct {
n int
v driver.Value
}
func (s *SQLiteStmt) bind(args []driver.Value) error {
rv := C.sqlite3_reset(s.s)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
return s.c.lastError()
}
var vargs []bindArg
narg := len(args)
vargs = make([]bindArg, narg)
if len(s.nn) > 0 {
for i, v := range s.nn {
if pi, err := strconv.Atoi(v[1:]); err == nil {
vargs[i] = bindArg{pi, args[i]}
}
}
} else {
for i, v := range args {
vargs[i] = bindArg{i + 1, v}
}
}
for _, varg := range vargs {
n := C.int(varg.n)
v := varg.v
switch v := v.(type) {
case nil:
rv = C.sqlite3_bind_null(s.s, n)
case string:
if len(v) == 0 {
b := []byte{0}
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(0))
} else {
b := []byte(v)
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
}
case int64:
rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
case bool:
if bool(v) {
rv = C.sqlite3_bind_int(s.s, n, 1)
} else {
rv = C.sqlite3_bind_int(s.s, n, 0)
}
case float64:
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
case []byte:
var p *byte
if len(v) > 0 {
p = &v[0]
}
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v)))
case time.Time:
b := []byte(v.Format(SQLiteTimestampFormats[0]))
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
}
if rv != C.SQLITE_OK {
return s.c.lastError()
}
}
return nil
}
// Query the statement with arguments. Return records.
func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
if err := s.bind(args); err != nil {
return nil, err
}
return &SQLiteRows{s, int(C.sqlite3_column_count(s.s)), nil, nil, s.cls}, nil
}
// Return last inserted ID.
func (r *SQLiteResult) LastInsertId() (int64, error) {
return r.id, nil
}
// Return how many rows affected.
func (r *SQLiteResult) RowsAffected() (int64, error) {
return r.changes, nil
}
// Execute the statement with arguments. Return result object.
func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
if err := s.bind(args); err != nil {
C.sqlite3_reset(s.s)
C.sqlite3_clear_bindings(s.s)
return nil, err
}
var rowid, changes C.longlong
rv := C._sqlite3_step(s.s, &rowid, &changes)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
err := s.c.lastError()
C.sqlite3_reset(s.s)
C.sqlite3_clear_bindings(s.s)
return nil, err
}
return &SQLiteResult{int64(rowid), int64(changes)}, nil
}
// Close the rows.
func (rc *SQLiteRows) Close() error {
if rc.s.closed {
return nil
}
if rc.cls {
return rc.s.Close()
}
rv := C.sqlite3_reset(rc.s.s)
if rv != C.SQLITE_OK {
return rc.s.c.lastError()
}
return nil
}
// Return column names.
func (rc *SQLiteRows) Columns() []string {
if rc.nc != len(rc.cols) {
rc.cols = make([]string, rc.nc)
for i := 0; i < rc.nc; i++ {
rc.cols[i] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(i)))
}
}
return rc.cols
}
// Move cursor to next.
func (rc *SQLiteRows) Next(dest []driver.Value) error {
rv := C.sqlite3_step(rc.s.s)
if rv == C.SQLITE_DONE {
return io.EOF
}
if rv != C.SQLITE_ROW {
rv = C.sqlite3_reset(rc.s.s)
if rv != C.SQLITE_OK {
return rc.s.c.lastError()
}
return nil
}
if rc.decltype == nil {
rc.decltype = make([]string, rc.nc)
for i := 0; i < rc.nc; i++ {
rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
}
}
for i := range dest {
switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
case C.SQLITE_INTEGER:
val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i)))
switch rc.decltype[i] {
case "timestamp", "datetime", "date":
var t time.Time
// Assume a millisecond unix timestamp if it's 13 digits -- too
// large to be a reasonable timestamp in seconds.
if val > 1e12 || val < -1e12 {
val *= int64(time.Millisecond) // convert ms to nsec
} else {
val *= int64(time.Second) // convert sec to nsec
}
t = time.Unix(0, val).UTC()
if rc.s.c.loc != nil {
t = t.In(rc.s.c.loc)
}
dest[i] = t
case "boolean":
dest[i] = val > 0
default:
dest[i] = val
}
case C.SQLITE_FLOAT:
dest[i] = float64(C.sqlite3_column_double(rc.s.s, C.int(i)))
case C.SQLITE_BLOB:
p := C.sqlite3_column_blob(rc.s.s, C.int(i))
if p == nil {
dest[i] = nil
continue
}
n := int(C.sqlite3_column_bytes(rc.s.s, C.int(i)))
switch dest[i].(type) {
case sql.RawBytes:
dest[i] = (*[1 << 30]byte)(unsafe.Pointer(p))[0:n]
default:
slice := make([]byte, n)
copy(slice[:], (*[1 << 30]byte)(unsafe.Pointer(p))[0:n])
dest[i] = slice
}
case C.SQLITE_NULL:
dest[i] = nil
case C.SQLITE_TEXT:
var err error
var timeVal time.Time
n := int(C.sqlite3_column_bytes(rc.s.s, C.int(i)))
s := C.GoStringN((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i)))), C.int(n))
switch rc.decltype[i] {
case "timestamp", "datetime", "date":
var t time.Time
s = strings.TrimSuffix(s, "Z")
for _, format := range SQLiteTimestampFormats {
if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil {
t = timeVal
break
}
}
if err != nil {
// The column is a time value, so return the zero time on parse failure.
t = time.Time{}
}
if rc.s.c.loc != nil {
t = t.In(rc.s.c.loc)
}
dest[i] = t
default:
dest[i] = []byte(s)
}
}
}
return nil
}

View file

@ -0,0 +1,13 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build icu
package sqlite3
/*
#cgo LDFLAGS: -licuuc -licui18n
#cgo CFLAGS: -DSQLITE_ENABLE_ICU
*/
import "C"

View file

@ -0,0 +1,13 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build libsqlite3
package sqlite3
/*
#cgo CFLAGS: -DUSE_LIBSQLITE3
#cgo LDFLAGS: -lsqlite3
*/
import "C"

View file

@ -0,0 +1,39 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build !sqlite_omit_load_extension
package sqlite3
/*
#include <sqlite3-binding.h>
#include <stdlib.h>
*/
import "C"
import (
"errors"
"unsafe"
)
func (c *SQLiteConn) loadExtensions(extensions []string) error {
rv := C.sqlite3_enable_load_extension(c.db, 1)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
}
for _, extension := range extensions {
cext := C.CString(extension)
defer C.free(unsafe.Pointer(cext))
rv = C.sqlite3_load_extension(c.db, cext, nil, nil)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
}
}
rv = C.sqlite3_enable_load_extension(c.db, 0)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
}
return nil
}

View file

@ -0,0 +1,19 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build sqlite_omit_load_extension
package sqlite3
/*
#cgo CFLAGS: -DSQLITE_OMIT_LOAD_EXTENSION
*/
import "C"
import (
"errors"
)
func (c *SQLiteConn) loadExtensions(extensions []string) error {
return errors.New("Extensions have been disabled for static builds")
}

View file

@ -0,0 +1,13 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build !windows
package sqlite3
/*
#cgo CFLAGS: -I.
#cgo linux LDFLAGS: -ldl
*/
import "C"

View file

@ -0,0 +1,409 @@
package sqlite3_test
import (
"database/sql"
"fmt"
"math/rand"
"regexp"
"strconv"
"sync"
"testing"
"time"
)
type Dialect int
const (
SQLITE Dialect = iota
POSTGRESQL
MYSQL
)
type DB struct {
*testing.T
*sql.DB
dialect Dialect
once sync.Once
}
var db *DB
// the following tables will be created and dropped during the test
var testTables = []string{"foo", "bar", "t", "bench"}
var tests = []testing.InternalTest{
{"TestBlobs", TestBlobs},
{"TestManyQueryRow", TestManyQueryRow},
{"TestTxQuery", TestTxQuery},
{"TestPreparedStmt", TestPreparedStmt},
}
var benchmarks = []testing.InternalBenchmark{
{"BenchmarkExec", BenchmarkExec},
{"BenchmarkQuery", BenchmarkQuery},
{"BenchmarkParams", BenchmarkParams},
{"BenchmarkStmt", BenchmarkStmt},
{"BenchmarkRows", BenchmarkRows},
{"BenchmarkStmtRows", BenchmarkStmtRows},
}
// RunTests runs the SQL test suite
func RunTests(t *testing.T, d *sql.DB, dialect Dialect) {
db = &DB{t, d, dialect, sync.Once{}}
testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests)
if !testing.Short() {
for _, b := range benchmarks {
fmt.Printf("%-20s", b.Name)
r := testing.Benchmark(b.F)
fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds())
}
}
db.tearDown()
}
func (db *DB) mustExec(sql string, args ...interface{}) sql.Result {
res, err := db.Exec(sql, args...)
if err != nil {
db.Fatalf("Error running %q: %v", sql, err)
}
return res
}
func (db *DB) tearDown() {
for _, tbl := range testTables {
switch db.dialect {
case SQLITE:
db.mustExec("drop table if exists " + tbl)
case MYSQL, POSTGRESQL:
db.mustExec("drop table if exists " + tbl)
default:
db.Fatal("unkown dialect")
}
}
}
// q replaces ? parameters if needed
func (db *DB) q(sql string) string {
switch db.dialect {
case POSTGRESQL: // repace with $1, $2, ..
qrx := regexp.MustCompile(`\?`)
n := 0
return qrx.ReplaceAllStringFunc(sql, func(string) string {
n++
return "$" + strconv.Itoa(n)
})
}
return sql
}
func (db *DB) blobType(size int) string {
switch db.dialect {
case SQLITE:
return fmt.Sprintf("blob[%d]", size)
case POSTGRESQL:
return "bytea"
case MYSQL:
return fmt.Sprintf("VARBINARY(%d)", size)
}
panic("unkown dialect")
}
func (db *DB) serialPK() string {
switch db.dialect {
case SQLITE:
return "integer primary key autoincrement"
case POSTGRESQL:
return "serial primary key"
case MYSQL:
return "integer primary key auto_increment"
}
panic("unkown dialect")
}
func (db *DB) now() string {
switch db.dialect {
case SQLITE:
return "datetime('now')"
case POSTGRESQL:
return "now()"
case MYSQL:
return "now()"
}
panic("unkown dialect")
}
func makeBench() {
if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil {
panic(err)
}
st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)")
if err != nil {
panic(err)
}
defer st.Close()
for i := 0; i < 100; i++ {
if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil {
panic(err)
}
}
}
func TestResult(t *testing.T) {
db.tearDown()
db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))")
for i := 1; i < 3; i++ {
r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i))
n, err := r.RowsAffected()
if err != nil {
t.Fatal(err)
}
if n != 1 {
t.Errorf("got %v, want %v", n, 1)
}
n, err = r.LastInsertId()
if err != nil {
t.Fatal(err)
}
if n != int64(i) {
t.Errorf("got %v, want %v", n, i)
}
}
if _, err := db.Exec("error!"); err == nil {
t.Fatalf("expected error")
}
}
func TestBlobs(t *testing.T) {
db.tearDown()
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob)
want := fmt.Sprintf("%x", blob)
b := make([]byte, 16)
err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b)
got := fmt.Sprintf("%x", b)
if err != nil {
t.Errorf("[]byte scan: %v", err)
} else if got != want {
t.Errorf("for []byte, got %q; want %q", got, want)
}
err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got)
want = string(blob)
if err != nil {
t.Errorf("string scan: %v", err)
} else if got != want {
t.Errorf("for string, got %q; want %q", got, want)
}
}
func TestManyQueryRow(t *testing.T) {
if testing.Short() {
t.Log("skipping in short mode")
return
}
db.tearDown()
db.mustExec("create table foo (id integer primary key, name varchar(50))")
db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
var name string
for i := 0; i < 10000; i++ {
err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name)
if err != nil || name != "bob" {
t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
}
}
}
func TestTxQuery(t *testing.T) {
db.tearDown()
tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Rollback()
_, err = tx.Exec("create table foo (id integer primary key, name varchar(50))")
if err != nil {
t.Fatal(err)
}
_, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
if err != nil {
t.Fatal(err)
}
r, err := tx.Query(db.q("select name from foo where id = ?"), 1)
if err != nil {
t.Fatal(err)
}
defer r.Close()
if !r.Next() {
if r.Err() != nil {
t.Fatal(err)
}
t.Fatal("expected one rows")
}
var name string
err = r.Scan(&name)
if err != nil {
t.Fatal(err)
}
}
func TestPreparedStmt(t *testing.T) {
db.tearDown()
db.mustExec("CREATE TABLE t (count INT)")
sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC")
if err != nil {
t.Fatalf("prepare 1: %v", err)
}
ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)"))
if err != nil {
t.Fatalf("prepare 2: %v", err)
}
for n := 1; n <= 3; n++ {
if _, err := ins.Exec(n); err != nil {
t.Fatalf("insert(%d) = %v", n, err)
}
}
const nRuns = 10
var wg sync.WaitGroup
for i := 0; i < nRuns; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 10; j++ {
count := 0
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
t.Errorf("Query: %v", err)
return
}
if _, err := ins.Exec(rand.Intn(100)); err != nil {
t.Errorf("Insert: %v", err)
return
}
}
}()
}
wg.Wait()
}
// Benchmarks need to use panic() since b.Error errors are lost when
// running via testing.Benchmark() I would like to run these via go
// test -bench but calling Benchmark() from a benchmark test
// currently hangs go.
func BenchmarkExec(b *testing.B) {
for i := 0; i < b.N; i++ {
if _, err := db.Exec("select 1"); err != nil {
panic(err)
}
}
}
func BenchmarkQuery(b *testing.B) {
for i := 0; i < b.N; i++ {
var n sql.NullString
var i int
var f float64
var s string
// var t time.Time
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
panic(err)
}
}
}
func BenchmarkParams(b *testing.B) {
for i := 0; i < b.N; i++ {
var n sql.NullString
var i int
var f float64
var s string
// var t time.Time
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
panic(err)
}
}
}
func BenchmarkStmt(b *testing.B) {
st, err := db.Prepare("select ?, ?, ?, ?")
if err != nil {
panic(err)
}
defer st.Close()
for n := 0; n < b.N; n++ {
var n sql.NullString
var i int
var f float64
var s string
// var t time.Time
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
panic(err)
}
}
}
func BenchmarkRows(b *testing.B) {
db.once.Do(makeBench)
for n := 0; n < b.N; n++ {
var n sql.NullString
var i int
var f float64
var s string
var t time.Time
r, err := db.Query("select * from bench")
if err != nil {
panic(err)
}
for r.Next() {
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
panic(err)
}
}
if err = r.Err(); err != nil {
panic(err)
}
}
}
func BenchmarkStmtRows(b *testing.B) {
db.once.Do(makeBench)
st, err := db.Prepare("select * from bench")
if err != nil {
panic(err)
}
defer st.Close()
for n := 0; n < b.N; n++ {
var n sql.NullString
var i int
var f float64
var s string
var t time.Time
r, err := st.Query()
if err != nil {
panic(err)
}
for r.Next() {
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
panic(err)
}
}
if err = r.Err(); err != nil {
panic(err)
}
}
}

View file

@ -0,0 +1,14 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build windows
package sqlite3
/*
#cgo CFLAGS: -I. -fno-stack-check -fno-stack-protector -mno-stack-arg-probe
#cgo windows,386 CFLAGS: -D_localtime32=localtime
#cgo LDFLAGS: -lmingwex -lmingw32
*/
import "C"

View file

@ -4,7 +4,7 @@ import (
"testing" "testing"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/repo" "github.com/coreos/dex/db"
"github.com/coreos/dex/schema/adminschema" "github.com/coreos/dex/schema/adminschema"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager" "github.com/coreos/dex/user/manager"
@ -22,32 +22,53 @@ type testFixtures struct {
func makeTestFixtures() *testFixtures { func makeTestFixtures() *testFixtures {
f := &testFixtures{} f := &testFixtures{}
f.ur = user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{ dbMap := db.NewMemDB()
{ f.ur = func() user.UserRepo {
User: user.User{ repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{
ID: "ID-1", {
Email: "email-1@example.com", User: user.User{
DisplayName: "Name-1", ID: "ID-1",
Email: "email-1@example.com",
DisplayName: "Name-1",
},
}, },
}, {
{ User: user.User{
User: user.User{ ID: "ID-2",
ID: "ID-2", Email: "email-2@example.com",
Email: "email-2@example.com", DisplayName: "Name-2",
DisplayName: "Name-2", },
}, },
}, })
}) if err != nil {
f.pwr = user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{ panic("Failed to create user repo: " + err.Error())
{ }
UserID: "ID-1", return repo
Password: []byte("hi."), }()
},
}) f.pwr = func() user.PasswordInfoRepo {
ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{ repo, err := db.NewPasswordInfoRepoFromPasswordInfos(dbMap, []user.PasswordInfo{
&connector.LocalConnectorConfig{ID: "local"}, {
}) UserID: "ID-1",
f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{}) Password: []byte("hi."),
},
})
if err != nil {
panic("Failed to create user repo: " + err.Error())
}
return repo
}()
ccr := func() connector.ConnectorConfigRepo {
c := []connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}}
repo := db.NewConnectorConfigRepo(dbMap)
if err := repo.Set(c); err != nil {
panic(err)
}
return repo
}()
f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
f.adAPI = NewAdminAPI(f.mgr, f.ur, f.pwr, "local") f.adAPI = NewAdminAPI(f.mgr, f.ur, f.pwr, "local")
return f return f

18
build
View file

@ -1,18 +1,10 @@
#!/bin/bash -e #!/bin/bash -e
export GOPATH=${PWD}/Godeps/_workspace source ./env
export GOBIN=${PWD}/bin
rm -rf $GOPATH/src/github.com/coreos/dex go install -ldflags="$LD_FLAGS" github.com/coreos/dex/cmd/dex-worker
mkdir -p $GOPATH/src/github.com/coreos/ go install -ldflags="$LD_FLAGS" github.com/coreos/dex/cmd/dexctl
go install -ldflags="$LD_FLAGS" github.com/coreos/dex/cmd/dex-overlord
# Only attempt to link dex into godeps if it isn't already there
[ -d $GOPATH/src/github.com/coreos/dex ] || ln -s ${PWD} $GOPATH/src/github.com/coreos/dex
LD_FLAGS="-X main.version=$(./git-version)"
go build -o bin/dex-worker -ldflags="$LD_FLAGS" github.com/coreos/dex/cmd/dex-worker
go build -o bin/dexctl -ldflags="$LD_FLAGS" github.com/coreos/dex/cmd/dexctl
go build -o bin/dex-overlord -ldflags="$LD_FLAGS" github.com/coreos/dex/cmd/dex-overlord
go build -o bin/example-app github.com/coreos/dex/examples/app go build -o bin/example-app github.com/coreos/dex/examples/app
go build -o bin/example-cli github.com/coreos/dex/examples/cli go build -o bin/example-cli github.com/coreos/dex/examples/cli
go build -o bin/gendoc github.com/coreos/dex/cmd/gendoc go install github.com/coreos/dex/cmd/gendoc

View file

@ -1,16 +1,10 @@
package client package client
import ( import (
"encoding/base64"
"encoding/json"
"errors" "errors"
"io"
"io/ioutil"
"net/url" "net/url"
"reflect" "reflect"
"sort"
pcrypto "github.com/coreos/dex/pkg/crypto"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
) )
@ -46,146 +40,6 @@ type ClientIdentityRepo interface {
IsDexAdmin(clientID string) (bool, error) IsDexAdmin(clientID string) (bool, error)
} }
func NewClientIdentityRepo(cs []oidc.ClientIdentity) ClientIdentityRepo {
cr := memClientIdentityRepo{
idents: make(map[string]oidc.ClientIdentity, len(cs)),
admins: make(map[string]bool),
}
for _, c := range cs {
c := c
cr.idents[c.Credentials.ID] = c
}
return &cr
}
type memClientIdentityRepo struct {
idents map[string]oidc.ClientIdentity
admins map[string]bool
}
func (cr *memClientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.ClientCredentials, error) {
if _, ok := cr.idents[id]; ok {
return nil, errors.New("client ID already exists")
}
secret, err := pcrypto.RandBytes(32)
if err != nil {
return nil, err
}
cc := oidc.ClientCredentials{
ID: id,
Secret: base64.URLEncoding.EncodeToString(secret),
}
cr.idents[id] = oidc.ClientIdentity{
Metadata: meta,
Credentials: cc,
}
return &cc, nil
}
func (cr *memClientIdentityRepo) Metadata(clientID string) (*oidc.ClientMetadata, error) {
ci, ok := cr.idents[clientID]
if !ok {
return nil, ErrorNotFound
}
return &ci.Metadata, nil
}
func (cr *memClientIdentityRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) {
ci, ok := cr.idents[creds.ID]
ok = ok && ci.Credentials.Secret == creds.Secret
return ok, nil
}
func (cr *memClientIdentityRepo) All() ([]oidc.ClientIdentity, error) {
cs := make(sortableClientIdentities, 0, len(cr.idents))
for _, ci := range cr.idents {
ci := ci
cs = append(cs, ci)
}
sort.Sort(cs)
return cs, nil
}
func (cr *memClientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error {
cr.admins[clientID] = isAdmin
return nil
}
func (cr *memClientIdentityRepo) IsDexAdmin(clientID string) (bool, error) {
return cr.admins[clientID], nil
}
type sortableClientIdentities []oidc.ClientIdentity
func (s sortableClientIdentities) Len() int {
return len([]oidc.ClientIdentity(s))
}
func (s sortableClientIdentities) Less(i, j int) bool {
return s[i].Credentials.ID < s[j].Credentials.ID
}
func (s sortableClientIdentities) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}
func NewClientIdentityRepoFromReader(r io.Reader) (ClientIdentityRepo, error) {
b, err := ioutil.ReadAll(r)
if err != nil {
return nil, err
}
var cs []clientIdentity
if err = json.Unmarshal(b, &cs); err != nil {
return nil, err
}
ocs := make([]oidc.ClientIdentity, len(cs))
for i, c := range cs {
ocs[i] = oidc.ClientIdentity(c)
}
return NewClientIdentityRepo(ocs), nil
}
type clientIdentity oidc.ClientIdentity
func (ci *clientIdentity) UnmarshalJSON(data []byte) error {
c := struct {
ID string `json:"id"`
Secret string `json:"secret"`
RedirectURLs []string `json:"redirectURLs"`
}{}
if err := json.Unmarshal(data, &c); err != nil {
return err
}
ci.Credentials = oidc.ClientCredentials{
ID: c.ID,
Secret: c.Secret,
}
ci.Metadata = oidc.ClientMetadata{
RedirectURIs: make([]url.URL, len(c.RedirectURLs)),
}
for i, us := range c.RedirectURLs {
up, err := url.Parse(us)
if err != nil {
return err
}
ci.Metadata.RedirectURIs[i] = *up
}
return nil
}
// ValidRedirectURL returns the passed in URL if it is present in the redirectURLs list, and returns an error otherwise. // ValidRedirectURL returns the passed in URL if it is present in the redirectURLs list, and returns an error otherwise.
// If nil is passed in as the rURL and there is only one URL in redirectURLs, // If nil is passed in as the rURL and there is only one URL in redirectURLs,
// that URL will be returned. If nil is passed but theres >1 URL in the slice, // that URL will be returned. If nil is passed but theres >1 URL in the slice,

View file

@ -1,190 +0,0 @@
package client
import (
"encoding/json"
"net/url"
"reflect"
"sort"
"testing"
"github.com/coreos/go-oidc/oidc"
)
func TestMemClientIdentityRepoNew(t *testing.T) {
tests := []struct {
id string
meta oidc.ClientMetadata
}{
{
id: "foo",
meta: oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{
Scheme: "https",
Host: "example.com",
},
},
},
},
{
id: "bar",
meta: oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{Scheme: "https", Host: "example.com/foo"},
url.URL{Scheme: "https", Host: "example.com/bar"},
},
},
},
}
for i, tt := range tests {
cr := NewClientIdentityRepo(nil)
creds, err := cr.New(tt.id, tt.meta)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
}
if creds.ID != tt.id {
t.Errorf("case %d: expected non-empty Client ID", i)
}
if creds.Secret == "" {
t.Errorf("case %d: expected non-empty Client Secret", i)
}
all, err := cr.All()
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
}
if len(all) != 1 {
t.Errorf("case %d: expected repo to contain newly created Client", i)
}
wantURLs := tt.meta.RedirectURIs
gotURLs := all[0].Metadata.RedirectURIs
if !reflect.DeepEqual(wantURLs, gotURLs) {
t.Errorf("case %d: redirect url mismatch, want=%v, got=%v", i, wantURLs, gotURLs)
}
}
}
func TestMemClientIdentityRepoNewDuplicate(t *testing.T) {
cr := NewClientIdentityRepo(nil)
meta1 := oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{Scheme: "https", Host: "foo.example.com"},
},
}
if _, err := cr.New("foo", meta1); err != nil {
t.Errorf("unexpected error: %v", err)
}
meta2 := oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{Scheme: "https", Host: "bar.example.com"},
},
}
if _, err := cr.New("foo", meta2); err == nil {
t.Errorf("expected non-nil error")
}
}
func TestMemClientIdentityRepoAll(t *testing.T) {
tests := []struct {
ids []string
}{
{
ids: nil,
},
{
ids: []string{"foo"},
},
{
ids: []string{"foo", "bar"},
},
}
for i, tt := range tests {
cs := make([]oidc.ClientIdentity, len(tt.ids))
for i, s := range tt.ids {
cs[i] = oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{
ID: s,
},
}
}
cr := NewClientIdentityRepo(cs)
all, err := cr.All()
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
}
want := sortableClientIdentities(cs)
sort.Sort(want)
got := sortableClientIdentities(all)
sort.Sort(got)
if len(got) != len(want) {
t.Errorf("case %d: wrong length: %d", i, len(got))
}
if !reflect.DeepEqual(want, got) {
t.Errorf("case %d: want=%#v, got=%#v", i, want, got)
}
}
}
func TestClientIdentityUnmarshalJSON(t *testing.T) {
for i, test := range []struct {
json string
expectedID string
expectedSecret string
expectedURLs []string
}{
{
json: `{"id":"12345","secret":"rosebud","redirectURLs":["https://redirectone.com", "https://redirecttwo.com"]}`,
expectedID: "12345",
expectedSecret: "rosebud",
expectedURLs: []string{
"https://redirectone.com",
"https://redirecttwo.com",
},
},
} {
var actual clientIdentity
err := json.Unmarshal([]byte(test.json), &actual)
if err != nil {
t.Errorf("case %d: error unmarshalling: %v", i, err)
continue
}
if actual.Credentials.ID != test.expectedID {
t.Errorf("case %d: actual.Credentials.ID == %v, want %v", i, actual.Credentials.ID, test.expectedID)
}
if actual.Credentials.Secret != test.expectedSecret {
t.Errorf("case %d: actual.Credentials.Secret == %v, want %v", i, actual.Credentials.Secret, test.expectedSecret)
}
expectedURLs := test.expectedURLs
sort.Strings(expectedURLs)
actualURLs := make([]string, 0)
for _, u := range actual.Metadata.RedirectURIs {
actualURLs = append(actualURLs, u.String())
}
sort.Strings(actualURLs)
if len(actualURLs) != len(expectedURLs) {
t.Errorf("case %d: len(actualURLs) == %v, want %v", i, len(actualURLs), len(expectedURLs))
}
for ui, actualURL := range actualURLs {
if actualURL != expectedURLs[ui] {
t.Errorf("case %d: actualURLs[%d] == %q, want %q", i, ui, actualURL, expectedURLs[ui])
}
}
}
}

View file

@ -12,6 +12,7 @@ import (
"time" "time"
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
"github.com/go-gorp/gorp"
"github.com/coreos/dex/admin" "github.com/coreos/dex/admin"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
@ -94,6 +95,9 @@ func main() {
if err != nil { if err != nil {
log.Fatalf(err.Error()) log.Fatalf(err.Error())
} }
if _, ok := dbc.Dialect.(gorp.PostgresDialect); !ok {
log.Fatal("only postgres backend supported for multi server configurations")
}
if *dbMigrate { if *dbMigrate {
var sleep time.Duration var sleep time.Duration

View file

@ -3,8 +3,6 @@ package connector
import ( import (
"encoding/json" "encoding/json"
"io" "io"
"github.com/coreos/dex/repo"
) )
func ReadConfigs(r io.Reader) ([]ConnectorConfig, error) { func ReadConfigs(r io.Reader) ([]ConnectorConfig, error) {
@ -22,24 +20,3 @@ func ReadConfigs(r io.Reader) ([]ConnectorConfig, error) {
} }
return cfgs, nil return cfgs, nil
} }
type memConnectorConfigRepo struct {
configs []ConnectorConfig
}
func NewConnectorConfigRepoFromConfigs(cfgs []ConnectorConfig) ConnectorConfigRepo {
return &memConnectorConfigRepo{configs: cfgs}
}
func (r *memConnectorConfigRepo) All() ([]ConnectorConfig, error) {
return r.configs, nil
}
func (r *memConnectorConfigRepo) GetConnectorByID(_ repo.Transaction, id string) (ConnectorConfig, error) {
for _, cfg := range r.configs {
if cfg.ConnectorID() == id {
return cfg, nil
}
}
return nil, ErrorNotFound
}

View file

@ -11,6 +11,7 @@ import (
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/mattn/go-sqlite3"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
@ -85,35 +86,47 @@ func (m *clientIdentityModel) ClientIdentity() (*oidc.ClientIdentity, error) {
} }
func NewClientIdentityRepo(dbm *gorp.DbMap) client.ClientIdentityRepo { func NewClientIdentityRepo(dbm *gorp.DbMap) client.ClientIdentityRepo {
return &clientIdentityRepo{dbMap: dbm} return newClientIdentityRepo(dbm)
}
func newClientIdentityRepo(dbm *gorp.DbMap) *clientIdentityRepo {
return &clientIdentityRepo{db: &db{dbm}}
} }
func NewClientIdentityRepoFromClients(dbm *gorp.DbMap, clients []oidc.ClientIdentity) (client.ClientIdentityRepo, error) { func NewClientIdentityRepoFromClients(dbm *gorp.DbMap, clients []oidc.ClientIdentity) (client.ClientIdentityRepo, error) {
repo := NewClientIdentityRepo(dbm).(*clientIdentityRepo) repo := newClientIdentityRepo(dbm)
tx, err := repo.begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
exec := repo.executor(tx)
for _, c := range clients { for _, c := range clients {
dec, err := base64.URLEncoding.DecodeString(c.Credentials.Secret) dec, err := base64.URLEncoding.DecodeString(c.Credentials.Secret)
if err != nil { if err != nil {
return nil, err return nil, err
} }
cm, err := newClientIdentityModel(c.Credentials.ID, dec, &c.Metadata) cm, err := newClientIdentityModel(c.Credentials.ID, dec, &c.Metadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = repo.dbMap.Insert(cm) err = exec.Insert(cm)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
if err := tx.Commit(); err != nil {
return nil, err
}
return repo, nil return repo, nil
} }
type clientIdentityRepo struct { type clientIdentityRepo struct {
dbMap *gorp.DbMap *db
} }
func (r *clientIdentityRepo) Metadata(clientID string) (*oidc.ClientMetadata, error) { func (r *clientIdentityRepo) Metadata(clientID string) (*oidc.ClientMetadata, error) {
m, err := r.dbMap.Get(clientIdentityModel{}, clientID) m, err := r.executor(nil).Get(clientIdentityModel{}, clientID)
if err == sql.ErrNoRows || m == nil { if err == sql.ErrNoRows || m == nil {
return nil, client.ErrorNotFound return nil, client.ErrorNotFound
} }
@ -136,7 +149,7 @@ func (r *clientIdentityRepo) Metadata(clientID string) (*oidc.ClientMetadata, er
} }
func (r *clientIdentityRepo) IsDexAdmin(clientID string) (bool, error) { func (r *clientIdentityRepo) IsDexAdmin(clientID string) (bool, error) {
m, err := r.dbMap.Get(clientIdentityModel{}, clientID) m, err := r.executor(nil).Get(clientIdentityModel{}, clientID)
if m == nil || err != nil { if m == nil || err != nil {
return false, err return false, err
} }
@ -151,42 +164,35 @@ func (r *clientIdentityRepo) IsDexAdmin(clientID string) (bool, error) {
} }
func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error { func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error {
tx, err := r.dbMap.Begin() tx, err := r.begin()
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback()
exec := r.executor(tx)
m, err := r.dbMap.Get(clientIdentityModel{}, clientID) m, err := exec.Get(clientIdentityModel{}, clientID)
if m == nil || err != nil { if m == nil || err != nil {
rollback(tx)
return err return err
} }
cim, ok := m.(*clientIdentityModel) cim, ok := m.(*clientIdentityModel)
if !ok { if !ok {
rollback(tx)
log.Errorf("expected clientIdentityModel but found %v", reflect.TypeOf(m)) log.Errorf("expected clientIdentityModel but found %v", reflect.TypeOf(m))
return errors.New("unrecognized model") return errors.New("unrecognized model")
} }
cim.DexAdmin = isAdmin cim.DexAdmin = isAdmin
_, err = r.dbMap.Update(cim) _, err = exec.Update(cim)
if err != nil { if err != nil {
rollback(tx)
return err return err
} }
err = tx.Commit() return tx.Commit()
if err != nil {
rollback(tx)
return err
}
return nil
} }
func (r *clientIdentityRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) { func (r *clientIdentityRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) {
m, err := r.dbMap.Get(clientIdentityModel{}, creds.ID) m, err := r.executor(nil).Get(clientIdentityModel{}, creds.ID)
if m == nil || err != nil { if m == nil || err != nil {
return false, err return false, err
} }
@ -222,9 +228,16 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli
return nil, err return nil, err
} }
if err := r.dbMap.Insert(cim); err != nil { if err := r.executor(nil).Insert(cim); err != nil {
if perr, ok := err.(*pq.Error); ok && perr.Code == pgErrorCodeUniqueViolation { switch sqlErr := err.(type) {
err = errors.New("client ID already exists") case *pq.Error:
if sqlErr.Code == pgErrorCodeUniqueViolation {
err = errors.New("client ID already exists")
}
case *sqlite3.Error:
if sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique {
err = errors.New("client ID already exists")
}
} }
return nil, err return nil, err
@ -239,9 +252,9 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli
} }
func (r *clientIdentityRepo) All() ([]oidc.ClientIdentity, error) { func (r *clientIdentityRepo) All() ([]oidc.ClientIdentity, error) {
qt := pq.QuoteIdentifier(clientIdentityTableName) qt := r.quote(clientIdentityTableName)
q := fmt.Sprintf("SELECT * FROM %s", qt) q := fmt.Sprintf("SELECT * FROM %s", qt)
objs, err := r.dbMap.Select(&clientIdentityModel{}, q) objs, err := r.executor(nil).Select(&clientIdentityModel{}, q)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -4,13 +4,15 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"strings" "net/url"
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
_ "github.com/lib/pq"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
// Import database drivers
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
) )
type table struct { type table struct {
@ -43,22 +45,35 @@ type Config struct {
} }
func NewConnection(cfg Config) (*gorp.DbMap, error) { func NewConnection(cfg Config) (*gorp.DbMap, error) {
if !strings.HasPrefix(cfg.DSN, "postgres://") { u, err := url.Parse(cfg.DSN)
if err != nil {
return nil, fmt.Errorf("parse DSN: %v", err)
}
var (
db *sql.DB
dialect gorp.Dialect
)
switch u.Scheme {
case "postgres":
db, err = sql.Open("postgres", cfg.DSN)
if err != nil {
return nil, err
}
db.SetMaxIdleConns(cfg.MaxIdleConnections)
db.SetMaxOpenConns(cfg.MaxOpenConnections)
dialect = gorp.PostgresDialect{}
case "sqlite3":
db, err = sql.Open("sqlite3", u.Host)
if err != nil {
return nil, err
}
// NOTE(ericchiang): sqlite does NOT work with SetMaxIdleConns.
dialect = gorp.SqliteDialect{}
default:
return nil, errors.New("unrecognized database driver") return nil, errors.New("unrecognized database driver")
} }
db, err := sql.Open("postgres", cfg.DSN) dbm := gorp.DbMap{Db: db, Dialect: dialect}
if err != nil {
return nil, err
}
db.SetMaxIdleConns(cfg.MaxIdleConnections)
db.SetMaxOpenConns(cfg.MaxOpenConnections)
dbm := gorp.DbMap{
Db: db,
Dialect: gorp.PostgresDialect{},
}
for _, t := range tables { for _, t := range tables {
tm := dbm.AddTableWithName(t.model, t.name).SetKeys(t.autoinc, t.pkey...) tm := dbm.AddTableWithName(t.model, t.name).SetKeys(t.autoinc, t.pkey...)
@ -70,7 +85,6 @@ func NewConnection(cfg Config) (*gorp.DbMap, error) {
cm.SetUnique(true) cm.SetUnique(true)
} }
} }
return &dbm, nil return &dbm, nil
} }
@ -80,9 +94,14 @@ func TransactionFactory(conn *gorp.DbMap) repo.TransactionFactory {
} }
} }
func rollback(tx *gorp.Transaction) { // NewMemDB creates a new in memory sqlite3 database.
err := tx.Rollback() func NewMemDB() *gorp.DbMap {
dbMap, err := NewConnection(Config{DSN: "sqlite3://:memory:"})
if err != nil { if err != nil {
log.Errorf("unable to rollback: %v", err) panic("Failed to create in memory database: " + err.Error())
} }
if _, err := MigrateToLatest(dbMap); err != nil {
panic("In memory database migration failed: " + err.Error())
}
return dbMap
} }

View file

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"github.com/lib/pq"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
@ -61,17 +60,17 @@ func (m *connectorConfigModel) ConnectorConfig() (connector.ConnectorConfig, err
} }
func NewConnectorConfigRepo(dbm *gorp.DbMap) *ConnectorConfigRepo { func NewConnectorConfigRepo(dbm *gorp.DbMap) *ConnectorConfigRepo {
return &ConnectorConfigRepo{dbMap: dbm} return &ConnectorConfigRepo{&db{dbm}}
} }
type ConnectorConfigRepo struct { type ConnectorConfigRepo struct {
dbMap *gorp.DbMap *db
} }
func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) { func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
qt := pq.QuoteIdentifier(connectorConfigTableName) qt := r.quote(connectorConfigTableName)
q := fmt.Sprintf("SELECT * FROM %s", qt) q := fmt.Sprintf("SELECT * FROM %s", qt)
objs, err := r.dbMap.Select(&connectorConfigModel{}, q) objs, err := r.executor(nil).Select(&connectorConfigModel{}, q)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -94,7 +93,7 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
} }
func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) { func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) {
qt := pq.QuoteIdentifier(connectorConfigTableName) qt := r.quote(connectorConfigTableName)
q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt) q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt)
var c connectorConfigModel var c connectorConfigModel
if err := r.executor(tx).SelectOne(&c, q, id); err != nil { if err := r.executor(tx).SelectOne(&c, q, id); err != nil {
@ -117,32 +116,22 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
insert[i] = m insert[i] = m
} }
tx, err := r.dbMap.Begin() tx, err := r.begin()
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback()
exec := r.executor(tx)
qt := pq.QuoteIdentifier(connectorConfigTableName) qt := r.quote(connectorConfigTableName)
q := fmt.Sprintf("DELETE FROM %s", qt) q := fmt.Sprintf("DELETE FROM %s", qt)
if _, err = r.dbMap.Exec(q); err != nil { if _, err = exec.Exec(q); err != nil {
return err return err
} }
if err = r.dbMap.Insert(insert...); err != nil { if err = exec.Insert(insert...); err != nil {
return fmt.Errorf("DB insert failed %#v: %v", insert, err) return fmt.Errorf("DB insert failed %#v: %v", insert, err)
} }
return tx.Commit() return tx.Commit()
} }
func (r *ConnectorConfigRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
if tx == nil {
return r.dbMap
}
gorpTx, ok := tx.(*gorp.Transaction)
if !ok {
panic("wrong kind of transaction passed to a DB repo")
}
return gorpTx
}

60
db/db.go Normal file
View file

@ -0,0 +1,60 @@
// Package db provides SQL implementations of dex's storage interfaces.
package db
import (
"github.com/go-gorp/gorp"
"github.com/coreos/dex/db/translate"
"github.com/coreos/dex/repo"
)
// db is the connection type passed to repos.
//
// TODO(ericchiang): Eventually just return this instead of gorp.DbMap during Conn.
// All actions should go through this type instead of dbMap.
type db struct {
dbMap *gorp.DbMap
}
// executor returns a driver agnostic SQL executor.
//
// The expected flavor of all queries is the flavor used by github.com/lib/pq. All bind
// parameters must be unique and in sequential order (e.g. $1, $2, ...).
//
// See github.com/coreos/dex/db/translate for details on the translation.
//
// If tx is nil, a non-transaction context is provided.
func (conn *db) executor(tx repo.Transaction) gorp.SqlExecutor {
var exec gorp.SqlExecutor
if tx == nil {
exec = conn.dbMap
} else {
gorpTx, ok := tx.(*gorp.Transaction)
if !ok {
panic("wrong kind of transaction passed to a DB repo")
}
// Check if the underlying value of the pointer is nil.
// This is not caught by the initial comparison (tx == nil).
if gorpTx == nil {
exec = conn.dbMap
} else {
exec = gorpTx
}
}
if _, ok := conn.dbMap.Dialect.(gorp.SqliteDialect); ok {
exec = translate.NewTranslatingExecutor(exec, translate.PostgresToSQLite)
}
return exec
}
// quote escapes a table name for a driver specific SQL query. quote uses the
// gorp's package underlying quote logic and should NOT be used on untrusted input.
func (conn *db) quote(tableName string) string {
return conn.dbMap.Dialect.QuotedTableForQuery("", tableName)
}
func (conn *db) begin() (repo.Transaction, error) {
return conn.dbMap.Begin()
}

View file

@ -8,7 +8,6 @@ import (
"time" "time"
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"github.com/lib/pq"
pcrypto "github.com/coreos/dex/pkg/crypto" pcrypto "github.com/coreos/dex/pkg/crypto"
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
@ -99,7 +98,7 @@ func NewPrivateKeySetRepo(dbm *gorp.DbMap, useOldFormat bool, secrets ...[]byte)
} }
r := &PrivateKeySetRepo{ r := &PrivateKeySetRepo{
dbMap: dbm, db: &db{dbm},
useOldFormat: useOldFormat, useOldFormat: useOldFormat,
secrets: secrets, secrets: secrets,
} }
@ -108,17 +107,22 @@ func NewPrivateKeySetRepo(dbm *gorp.DbMap, useOldFormat bool, secrets ...[]byte)
} }
type PrivateKeySetRepo struct { type PrivateKeySetRepo struct {
dbMap *gorp.DbMap *db
useOldFormat bool useOldFormat bool
secrets [][]byte secrets [][]byte
} }
func (r *PrivateKeySetRepo) Set(ks key.KeySet) error { func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
qt := pq.QuoteIdentifier(keyTableName) qt := r.quote(keyTableName)
_, err := r.dbMap.Exec(fmt.Sprintf("DELETE FROM %s", qt)) tx, err := r.begin()
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback()
exec := r.executor(tx)
if _, err := exec.Exec(fmt.Sprintf("DELETE FROM %s", qt)); err != nil {
return err
}
pks, ok := ks.(*key.PrivateKeySet) pks, ok := ks.(*key.PrivateKeySet)
if !ok { if !ok {
@ -148,12 +152,15 @@ func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
} }
b := &privateKeySetBlob{Value: v} b := &privateKeySetBlob{Value: v}
return r.dbMap.Insert(b) if err := exec.Insert(b); err != nil {
return err
}
return tx.Commit()
} }
func (r *PrivateKeySetRepo) Get() (key.KeySet, error) { func (r *PrivateKeySetRepo) Get() (key.KeySet, error) {
qt := pq.QuoteIdentifier(keyTableName) qt := r.quote(keyTableName)
objs, err := r.dbMap.Select(&privateKeySetBlob{}, fmt.Sprintf("SELECT * FROM %s", qt)) objs, err := r.executor(nil).Select(&privateKeySetBlob{}, fmt.Sprintf("SELECT * FROM %s", qt))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1,19 +1,18 @@
package db package db
import ( import (
"errors"
"fmt" "fmt"
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"github.com/lib/pq" "github.com/rubenv/sql-migrate"
migrate "github.com/rubenv/sql-migrate"
"github.com/coreos/dex/db/migrations" "github.com/coreos/dex/db/migrations"
) )
const ( const (
migrationDialect = "postgres" migrationTable = "dex_migrations"
migrationTable = "dex_migrations" migrationDir = "db/migrations"
migrationDir = "db/migrations"
) )
func init() { func init() {
@ -21,32 +20,57 @@ func init() {
} }
func MigrateToLatest(dbMap *gorp.DbMap) (int, error) { func MigrateToLatest(dbMap *gorp.DbMap) (int, error) {
source := getSource() source, dialect, err := migrationSource(dbMap)
if err != nil {
return migrate.Exec(dbMap.Db, migrationDialect, source, migrate.Up) return 0, err
}
return migrate.Exec(dbMap.Db, dialect, source, migrate.Up)
} }
func MigrateMaxMigrations(dbMap *gorp.DbMap, max int) (int, error) { func MigrateMaxMigrations(dbMap *gorp.DbMap, max int) (int, error) {
source := getSource() source, dialect, err := migrationSource(dbMap)
if err != nil {
return migrate.ExecMax(dbMap.Db, migrationDialect, source, migrate.Up, max) return 0, err
}
return migrate.ExecMax(dbMap.Db, dialect, source, migrate.Up, max)
} }
func GetPlannedMigrations(dbMap *gorp.DbMap) ([]*migrate.PlannedMigration, error) { func GetPlannedMigrations(dbMap *gorp.DbMap) ([]*migrate.PlannedMigration, error) {
migrations, _, err := migrate.PlanMigration(dbMap.Db, migrationDialect, getSource(), migrate.Up, 0) source, dialect, err := migrationSource(dbMap)
if err != nil {
return nil, err
}
migrations, _, err := migrate.PlanMigration(dbMap.Db, dialect, source, migrate.Up, 0)
return migrations, err return migrations, err
} }
func DropMigrationsTable(dbMap *gorp.DbMap) error { func DropMigrationsTable(dbMap *gorp.DbMap) error {
qt := pq.QuoteIdentifier(migrationTable) qt := fmt.Sprintf("DROP TABLE IF EXISTS %s;", dbMap.Dialect.QuotedTableForQuery("", migrationTable))
_, err := dbMap.Exec(fmt.Sprintf("drop table if exists %s ;", qt)) _, err := dbMap.Exec(qt)
return err return err
} }
func getSource() migrate.MigrationSource { func migrationSource(dbMap *gorp.DbMap) (src migrate.MigrationSource, dialect string, err error) {
return &migrate.AssetMigrationSource{ switch dbMap.Dialect.(type) {
Dir: migrationDir, case gorp.PostgresDialect:
Asset: migrations.Asset, src = &migrate.AssetMigrationSource{
AssetDir: migrations.AssetDir, Dir: migrationDir,
Asset: migrations.Asset,
AssetDir: migrations.AssetDir,
}
return src, "postgres", nil
case gorp.SqliteDialect:
src = &migrate.MemoryMigrationSource{
Migrations: []*migrate.Migration{
{
Id: "dex.sql",
Up: []string{sqlite3Migration},
},
},
}
return src, "sqlite3", nil
default:
return nil, "", errors.New("unsupported migration driver")
} }
} }

73
db/migrate_sqlite3.go Normal file
View file

@ -0,0 +1,73 @@
package db
// SQLite3 is a test only database. There is only one migration because we do not support migrations.
const sqlite3Migration = `
CREATE TABLE authd_user (
id text NOT NULL UNIQUE,
email text,
email_verified integer,
display_name text,
admin integer,
created_at bigint,
disabled integer
);
CREATE TABLE client_identity (
id text NOT NULL UNIQUE,
secret blob,
metadata text,
dex_admin integer
);
CREATE TABLE connector_config (
id text NOT NULL UNIQUE,
type text,
config text
);
CREATE TABLE key (
value blob
);
CREATE TABLE password_info (
user_id text NOT NULL UNIQUE,
password text,
password_expires bigint
);
CREATE TABLE refresh_token (
id integer PRIMARY KEY,
payload_hash blob,
user_id text,
client_id text
);
CREATE TABLE remote_identity_mapping (
connector_id text NOT NULL,
user_id text,
remote_id text NOT NULL
);
CREATE TABLE session (
id text NOT NULL UNIQUE,
state text,
created_at bigint,
expires_at bigint,
client_id text,
client_state text,
redirect_url text,
identity text,
connector_id text,
user_id text,
register integer,
nonce text,
scope text
);
CREATE TABLE session_key (
key text NOT NULL,
session_id text,
expires_at bigint,
stale integer
);
`

View file

@ -105,7 +105,7 @@ func TestMigrateClientMetadata(t *testing.T) {
id := strconv.Itoa(i) id := strconv.Itoa(i)
m, err := dbMap.Get(clientIdentityModel{}, id) m, err := dbMap.Get(clientIdentityModel{}, id)
if err != nil { if err != nil {
t.Errorf("case %d: failed to get model: %err", i, err) t.Errorf("case %d: failed to get model: %v", i, err)
continue continue
} }
cim, ok := m.(*clientIdentityModel) cim, ok := m.(*clientIdentityModel)

View file

@ -5,10 +5,11 @@ import (
"reflect" "reflect"
"time" "time"
"github.com/go-gorp/gorp"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/go-gorp/gorp"
) )
const ( const (
@ -33,12 +34,22 @@ type passwordInfoModel struct {
func NewPasswordInfoRepo(dbm *gorp.DbMap) user.PasswordInfoRepo { func NewPasswordInfoRepo(dbm *gorp.DbMap) user.PasswordInfoRepo {
return &passwordInfoRepo{ return &passwordInfoRepo{
dbMap: dbm, db: &db{dbm},
} }
} }
func NewPasswordInfoRepoFromPasswordInfos(dbm *gorp.DbMap, infos []user.PasswordInfo) (user.PasswordInfoRepo, error) {
repo := NewPasswordInfoRepo(dbm)
for _, info := range infos {
if err := repo.Create(nil, info); err != nil {
return nil, err
}
}
return repo, nil
}
type passwordInfoRepo struct { type passwordInfoRepo struct {
dbMap *gorp.DbMap *db
} }
func (r *passwordInfoRepo) Get(tx repo.Transaction, userID string) (user.PasswordInfo, error) { func (r *passwordInfoRepo) Get(tx repo.Transaction, userID string) (user.PasswordInfo, error) {
@ -89,18 +100,6 @@ func (r *passwordInfoRepo) Update(tx repo.Transaction, pw user.PasswordInfo) err
return nil return nil
} }
func (r *passwordInfoRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
if tx == nil {
return r.dbMap
}
gorpTx, ok := tx.(*gorp.Transaction)
if !ok {
panic("wrong kind of transaction passed to a DB repo")
}
return gorpTx
}
func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInfo, error) { func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInfo, error) {
ex := r.executor(tx) ex := r.executor(tx)

View file

@ -8,10 +8,12 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh"
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh"
"github.com/coreos/dex/repo"
) )
const ( const (
@ -28,7 +30,7 @@ func init() {
} }
type refreshTokenRepo struct { type refreshTokenRepo struct {
dbMap *gorp.DbMap *db
tokenGenerator refresh.RefreshTokenGenerator tokenGenerator refresh.RefreshTokenGenerator
} }
@ -76,9 +78,13 @@ func checkTokenPayload(payloadHash, payload []byte) error {
} }
func NewRefreshTokenRepo(dbm *gorp.DbMap) refresh.RefreshTokenRepo { func NewRefreshTokenRepo(dbm *gorp.DbMap) refresh.RefreshTokenRepo {
return NewRefreshTokenRepoWithGenerator(dbm, refresh.DefaultRefreshTokenGenerator)
}
func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenGenerator) refresh.RefreshTokenRepo {
return &refreshTokenRepo{ return &refreshTokenRepo{
dbMap: dbm, db: &db{dbm},
tokenGenerator: refresh.DefaultRefreshTokenGenerator, tokenGenerator: gen,
} }
} }
@ -107,7 +113,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
ClientID: clientID, ClientID: clientID,
} }
if err := r.dbMap.Insert(record); err != nil { if err := r.executor(nil).Insert(record); err != nil {
return "", err return "", err
} }
@ -143,7 +149,13 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error {
return err return err
} }
record, err := r.get(nil, tokenID) tx, err := r.begin()
if err != nil {
return err
}
defer tx.Rollback()
exec := r.executor(tx)
record, err := r.get(tx, tokenID)
if err != nil { if err != nil {
return err return err
} }
@ -156,7 +168,7 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error {
return err return err
} }
deleted, err := r.dbMap.Delete(record) deleted, err := exec.Delete(record)
if err != nil { if err != nil {
return err return err
} }
@ -164,17 +176,10 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error {
return refresh.ErrorInvalidToken return refresh.ErrorInvalidToken
} }
return nil return tx.Commit()
} }
func (r *refreshTokenRepo) executor(tx *gorp.Transaction) gorp.SqlExecutor { func (r *refreshTokenRepo) get(tx repo.Transaction, tokenID int64) (*refreshTokenModel, error) {
if tx == nil {
return r.dbMap
}
return tx
}
func (r *refreshTokenRepo) get(tx *gorp.Transaction, tokenID int64) (*refreshTokenModel, error) {
ex := r.executor(tx) ex := r.executor(tx)
result, err := ex.Get(refreshTokenModel{}, tokenID) result, err := ex.Get(refreshTokenModel{}, tokenID)
if err != nil { if err != nil {

View file

@ -11,7 +11,6 @@ import (
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/lib/pq"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/session" "github.com/coreos/dex/session"
@ -124,16 +123,16 @@ func NewSessionRepo(dbm *gorp.DbMap) *SessionRepo {
} }
func NewSessionRepoWithClock(dbm *gorp.DbMap, clock clockwork.Clock) *SessionRepo { func NewSessionRepoWithClock(dbm *gorp.DbMap, clock clockwork.Clock) *SessionRepo {
return &SessionRepo{dbMap: dbm, clock: clock} return &SessionRepo{db: &db{dbm}, clock: clock}
} }
type SessionRepo struct { type SessionRepo struct {
dbMap *gorp.DbMap *db
clock clockwork.Clock clock clockwork.Clock
} }
func (r *SessionRepo) Get(sessionID string) (*session.Session, error) { func (r *SessionRepo) Get(sessionID string) (*session.Session, error) {
m, err := r.dbMap.Get(sessionModel{}, sessionID) m, err := r.executor(nil).Get(sessionModel{}, sessionID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -164,7 +163,7 @@ func (r *SessionRepo) Create(s session.Session) error {
if err != nil { if err != nil {
return err return err
} }
return r.dbMap.Insert(sm) return r.executor(nil).Insert(sm)
} }
func (r *SessionRepo) Update(s session.Session) error { func (r *SessionRepo) Update(s session.Session) error {
@ -172,7 +171,7 @@ func (r *SessionRepo) Update(s session.Session) error {
if err != nil { if err != nil {
return err return err
} }
n, err := r.dbMap.Update(sm) n, err := r.executor(nil).Update(sm)
if err != nil { if err != nil {
return err return err
} }
@ -183,9 +182,9 @@ func (r *SessionRepo) Update(s session.Session) error {
} }
func (r *SessionRepo) purge() error { func (r *SessionRepo) purge() error {
qt := pq.QuoteIdentifier(sessionTableName) qt := r.quote(sessionTableName)
q := fmt.Sprintf("DELETE FROM %s WHERE expires_at < $1 OR state = $2", qt) q := fmt.Sprintf("DELETE FROM %s WHERE expires_at < $1 OR state = $2", qt)
res, err := r.dbMap.Exec(q, r.clock.Now().Unix(), string(session.SessionStateDead)) res, err := r.executor(nil).Exec(q, r.clock.Now().Unix(), string(session.SessionStateDead))
if err != nil { if err != nil {
return err return err
} }

View file

@ -8,7 +8,6 @@ import (
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/lib/pq"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/session" "github.com/coreos/dex/session"
@ -39,11 +38,11 @@ func NewSessionKeyRepo(dbm *gorp.DbMap) *SessionKeyRepo {
} }
func NewSessionKeyRepoWithClock(dbm *gorp.DbMap, clock clockwork.Clock) *SessionKeyRepo { func NewSessionKeyRepoWithClock(dbm *gorp.DbMap, clock clockwork.Clock) *SessionKeyRepo {
return &SessionKeyRepo{dbMap: dbm, clock: clock} return &SessionKeyRepo{db: &db{dbm}, clock: clock}
} }
type SessionKeyRepo struct { type SessionKeyRepo struct {
dbMap *gorp.DbMap *db
clock clockwork.Clock clock clockwork.Clock
} }
@ -54,11 +53,11 @@ func (r *SessionKeyRepo) Push(sk session.SessionKey, exp time.Duration) error {
ExpiresAt: r.clock.Now().Unix() + int64(exp.Seconds()), ExpiresAt: r.clock.Now().Unix() + int64(exp.Seconds()),
Stale: false, Stale: false,
} }
return r.dbMap.Insert(skm) return r.executor(nil).Insert(skm)
} }
func (r *SessionKeyRepo) Pop(key string) (string, error) { func (r *SessionKeyRepo) Pop(key string) (string, error) {
m, err := r.dbMap.Get(sessionKeyModel{}, key) m, err := r.executor(nil).Get(sessionKeyModel{}, key)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -77,9 +76,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) {
return "", errors.New("invalid session key") return "", errors.New("invalid session key")
} }
qt := pq.QuoteIdentifier(sessionKeyTableName) qt := r.quote(sessionKeyTableName)
q := fmt.Sprintf("UPDATE %s SET stale=$1 WHERE key=$2 AND stale=$3", qt) q := fmt.Sprintf("UPDATE %s SET stale=$1 WHERE key=$2 AND stale=$3", qt)
res, err := r.dbMap.Exec(q, true, key, false) res, err := r.executor(nil).Exec(q, true, key, false)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -95,9 +94,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) {
} }
func (r *SessionKeyRepo) purge() error { func (r *SessionKeyRepo) purge() error {
qt := pq.QuoteIdentifier(sessionKeyTableName) qt := r.quote(sessionKeyTableName)
q := fmt.Sprintf("DELETE FROM %s WHERE stale = $1 OR expires_at < $2", qt) q := fmt.Sprintf("DELETE FROM %s WHERE stale = $1 OR expires_at < $2", qt)
res, err := r.dbMap.Exec(q, true, r.clock.Now().Unix()) res, err := r.executor(nil).Exec(q, true, r.clock.Now().Unix())
if err != nil { if err != nil {
return err return err
} }

72
db/translate/translate.go Normal file
View file

@ -0,0 +1,72 @@
/*
Package translate implements translation of driver specific SQL queries.
*/
package translate
import (
"database/sql"
"regexp"
"github.com/go-gorp/gorp"
)
var (
bindRegexp = regexp.MustCompile(`\$\d+`)
trueRegexp = regexp.MustCompile(`\btrue\b`)
)
// PostgresToSQLite translates github.com/lib/pq flavored SQL quries to github.com/mattn/go-sqlite3's flavor.
//
// It assumes that possitional bind arguements ($1, $2, etc.) are unqiue and in sequential order.
func PostgresToSQLite(query string) string {
query = bindRegexp.ReplaceAllString(query, "?")
query = trueRegexp.ReplaceAllString(query, "1")
return query
}
// NewTranslatingExecutor returns an executor wrapping the existing executor. All query strings passed to
// the executor will be run through the translate function before begin passed to the driver.
func NewTranslatingExecutor(exec gorp.SqlExecutor, translate func(string) string) gorp.SqlExecutor {
return &executor{exec, translate}
}
type executor struct {
gorp.SqlExecutor
Translate func(string) string
}
func (e *executor) Exec(query string, args ...interface{}) (sql.Result, error) {
return e.SqlExecutor.Exec(e.Translate(query), args...)
}
func (e *executor) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
return e.SqlExecutor.Select(i, e.Translate(query), args...)
}
func (e *executor) SelectInt(query string, args ...interface{}) (int64, error) {
return e.SqlExecutor.SelectInt(e.Translate(query), args...)
}
func (e *executor) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) {
return e.SqlExecutor.SelectNullInt(e.Translate(query), args...)
}
func (e *executor) SelectFloat(query string, args ...interface{}) (float64, error) {
return e.SqlExecutor.SelectFloat(e.Translate(query), args...)
}
func (e *executor) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) {
return e.SqlExecutor.SelectNullFloat(e.Translate(query), args...)
}
func (e *executor) SelectStr(query string, args ...interface{}) (string, error) {
return e.SqlExecutor.SelectStr(e.Translate(query), args...)
}
func (e *executor) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) {
return e.SqlExecutor.SelectNullStr(e.Translate(query), args...)
}
func (e *executor) SelectOne(holder interface{}, query string, args ...interface{}) error {
return e.SqlExecutor.SelectOne(holder, e.Translate(query), args...)
}

View file

@ -0,0 +1,28 @@
package translate
import "testing"
func TestPostgresToSQLite(t *testing.T) {
tests := []struct {
query string
want string
}{
{"SELECT * FROM foo", "SELECT * FROM foo"},
{"SELECT * FROM %s", "SELECT * FROM %s"},
{"SELECT * FROM foo WHERE is_admin=true", "SELECT * FROM foo WHERE is_admin=1"},
{"SELECT * FROM foo WHERE is_admin=true;", "SELECT * FROM foo WHERE is_admin=1;"},
{"SELECT * FROM foo WHERE is_admin=$10", "SELECT * FROM foo WHERE is_admin=?"},
{"SELECT * FROM foo WHERE is_admin=$10;", "SELECT * FROM foo WHERE is_admin=?;"},
{"SELECT * FROM foo WHERE name=$1 AND is_admin=$2;", "SELECT * FROM foo WHERE name=? AND is_admin=?;"},
{"$1", "?"},
{"$", "$"},
}
for _, tt := range tests {
got := PostgresToSQLite(tt.query)
if got != tt.want {
t.Errorf("PostgresToSQLite(%q): want=%q, got=%q", tt.query, tt.want, got)
}
}
}

View file

@ -8,7 +8,6 @@ import (
"time" "time"
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"github.com/lib/pq"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
@ -42,7 +41,7 @@ func init() {
func NewUserRepo(dbm *gorp.DbMap) user.UserRepo { func NewUserRepo(dbm *gorp.DbMap) user.UserRepo {
return &userRepo{ return &userRepo{
dbMap: dbm, db: &db{dbm},
} }
} }
@ -53,7 +52,7 @@ func NewUserRepoFromUsers(dbm *gorp.DbMap, us []user.UserWithRemoteIdentities) (
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = repo.dbMap.Insert(um) err = repo.executor(nil).Insert(um)
for _, ri := range u.RemoteIdentities { for _, ri := range u.RemoteIdentities {
err = repo.AddRemoteIdentity(nil, u.User.ID, ri) err = repo.AddRemoteIdentity(nil, u.User.ID, ri)
if err != nil { if err != nil {
@ -65,7 +64,7 @@ func NewUserRepoFromUsers(dbm *gorp.DbMap, us []user.UserWithRemoteIdentities) (
} }
type userRepo struct { type userRepo struct {
dbMap *gorp.DbMap *db
} }
func (r *userRepo) Get(tx repo.Transaction, userID string) (user.User, error) { func (r *userRepo) Get(tx repo.Transaction, userID string) (user.User, error) {
@ -107,9 +106,9 @@ func (r *userRepo) Disable(tx repo.Transaction, userID string, disable bool) err
return user.ErrorInvalidID return user.ErrorInvalidID
} }
qt := pq.QuoteIdentifier(userTableName) qt := r.quote(userTableName)
ex := r.executor(tx) ex := r.executor(tx)
result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $2 WHERE id = $1", qt), userID, disable) result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $1 WHERE id = $2;", qt), disable, userID)
if err != nil { if err != nil {
return err return err
} }
@ -241,9 +240,8 @@ func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]us
return nil, user.ErrorInvalidID return nil, user.ErrorInvalidID
} }
qt := pq.QuoteIdentifier(remoteIdentityMappingTableName) qt := r.quote(remoteIdentityMappingTableName)
rims, err := ex.Select(&remoteIdentityMappingModel{}, rims, err := ex.Select(&remoteIdentityMappingModel{}, fmt.Sprintf("SELECT * FROM %s WHERE user_id = $1", qt), userID)
fmt.Sprintf("select * from %s where user_id = $1", qt), userID)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
@ -273,9 +271,9 @@ func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]us
} }
func (r *userRepo) GetAdminCount(tx repo.Transaction) (int, error) { func (r *userRepo) GetAdminCount(tx repo.Transaction) (int, error) {
qt := pq.QuoteIdentifier(userTableName) qt := r.quote(userTableName)
ex := r.executor(tx) ex := r.executor(tx)
i, err := ex.SelectInt(fmt.Sprintf("SELECT count(*) FROM %s where admin=true", qt)) i, err := ex.SelectInt(fmt.Sprintf("SELECT count(*) FROM %s WHERE admin=true;", qt))
return int(i), err return int(i), err
} }
@ -290,12 +288,11 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults
} }
ex := r.executor(tx) ex := r.executor(tx)
qt := pq.QuoteIdentifier(userTableName) qt := r.quote(userTableName)
// Ask for one more than needed so we know if there's more results, and // Ask for one more than needed so we know if there's more results, and
// hence, whether a nextPageToken is necessary. // hence, whether a nextPageToken is necessary.
ums, err := ex.Select(&userModel{}, ums, err := ex.Select(&userModel{}, fmt.Sprintf("SELECT * FROM %s ORDER BY email LIMIT $1 OFFSET $2", qt), maxResults+1, offset)
fmt.Sprintf("SELECT * FROM %s ORDER BY email LIMIT $1 OFFSET $2 ", qt), maxResults+1, offset)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
@ -338,18 +335,6 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults
} }
func (r *userRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
if tx == nil {
return r.dbMap
}
gorpTx, ok := tx.(*gorp.Transaction)
if !ok {
panic("wrong kind of transaction passed to a DB repo")
}
return gorpTx
}
func (r *userRepo) insert(tx repo.Transaction, usr user.User) error { func (r *userRepo) insert(tx repo.Transaction, usr user.User) error {
ex := r.executor(tx) ex := r.executor(tx)
um, err := newUserModel(&usr) um, err := newUserModel(&usr)
@ -412,7 +397,7 @@ func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.Remot
} }
func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, error) { func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, error) {
qt := pq.QuoteIdentifier(userTableName) qt := r.quote(userTableName)
ex := r.executor(tx) ex := r.executor(tx)
var um userModel var um userModel
err := ex.SelectOne(&um, fmt.Sprintf("select * from %s where email = $1", qt), email) err := ex.SelectOne(&um, fmt.Sprintf("select * from %s where email = $1", qt), email)

10
env Normal file
View file

@ -0,0 +1,10 @@
export GOPATH=${PWD}/Godeps/_workspace
export GOBIN=${PWD}/bin
rm -rf $GOPATH/src/github.com/coreos/dex
mkdir -p $GOPATH/src/github.com/coreos/
# Only attempt to link dex into godeps if it isn't already there
[ -d $GOPATH/src/github.com/coreos/dex ] || ln -s ${PWD} $GOPATH/src/github.com/coreos/dex
LD_FLAGS="-X main.version=$(./git-version)"

View file

@ -36,17 +36,16 @@ func connect(t *testing.T) *gorp.DbMap {
if err != nil { if err != nil {
t.Fatalf("Unable to connect to database: %v", err) t.Fatalf("Unable to connect to database: %v", err)
} }
if err = c.DropTablesIfExists(); err != nil { if err = c.DropTablesIfExists(); err != nil {
t.Fatalf("Unable to drop database tables: %v", err) t.Fatalf("Unable to drop database tables: %v", err)
} }
if err = db.DropMigrationsTable(c); err != nil { if err = db.DropMigrationsTable(c); err != nil {
panic(fmt.Sprintf("Unable to drop migration table: %v", err)) t.Fatalf("Unable to drop migration table: %v", err)
} }
if _, err = db.MigrateToLatest(c); err != nil { if _, err = db.MigrateToLatest(c); err != nil {
panic(fmt.Sprintf("Unable to migrate: %v", err)) t.Fatalf("Unable to migrate: %v", err)
} }
return c return c
@ -157,12 +156,13 @@ func TestDBPrivateKeySetRepoSetGet(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
setRepo, err := db.NewPrivateKeySetRepo(connect(t), false, tt.setSecrets...) dbMap := connect(t)
setRepo, err := db.NewPrivateKeySetRepo(dbMap, false, tt.setSecrets...)
if err != nil { if err != nil {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
getRepo, err := db.NewPrivateKeySetRepo(connect(t), false, tt.getSecrets...) getRepo, err := db.NewPrivateKeySetRepo(dbMap, false, tt.getSecrets...)
if err != nil { if err != nil {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
@ -377,9 +377,24 @@ func TestDBRefreshRepoCreate(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
_, err := r.Create(tt.userID, tt.clientID) token, err := r.Create(tt.userID, tt.clientID)
if err != tt.err { if err != nil {
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err) if tt.err == nil {
t.Errorf("case %d: create failed: %v", i, err)
}
continue
}
if tt.err != nil {
t.Errorf("case %d: expected error, didn't get one", i)
continue
}
userID, err := r.Verify(tt.clientID, token)
if err != nil {
t.Errorf("case %d: failed to verify good token: %v", i, err)
continue
}
if userID != tt.userID {
t.Errorf("case %d: want userID=%s, got userID=%s", i, tt.userID, userID)
} }
} }
} }

View file

@ -28,6 +28,9 @@ var connConfigExample = []byte(`[
]`) ]`)
func TestDexctlCommands(t *testing.T) { func TestDexctlCommands(t *testing.T) {
if strings.HasPrefix(dsn, "sqlite3://") {
t.Skip("only test dexctl conmand with postgres")
}
tempFile, err := ioutil.TempFile("", "dexctl_functional_tests_") tempFile, err := ioutil.TempFile("", "dexctl_functional_tests_")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View file

@ -10,6 +10,7 @@ import (
"testing" "testing"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"gopkg.in/ldap.v2" "gopkg.in/ldap.v2"
@ -93,13 +94,17 @@ func TestConnectorLDAPConnectFail(t *testing.T) {
templates := template.New(connector.LDAPLoginPageTemplateName) templates := template.New(connector.LDAPLoginPageTemplateName)
ccr := connector.NewConnectorConfigRepoFromConfigs( ccr := db.NewConnectorConfigRepo(db.NewMemDB())
err := ccr.Set(
[]connector.ConnectorConfig{&connector.LDAPConnectorConfig{ []connector.ConnectorConfig{&connector.LDAPConnectorConfig{
ID: "ldap", ID: "ldap",
ServerHost: ldapHost, ServerHost: ldapHost,
ServerPort: ldapPort + 1, ServerPort: ldapPort + 1,
}}, }},
) )
if err != nil {
t.Fatal(err)
}
cc, err := ccr.GetConnectorByID(tx, "ldap") cc, err := ccr.GetConnectorByID(tx, "ldap")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -121,13 +126,17 @@ func TestConnectorLDAPConnectSuccess(t *testing.T) {
templates := template.New(connector.LDAPLoginPageTemplateName) templates := template.New(connector.LDAPLoginPageTemplateName)
ccr := connector.NewConnectorConfigRepoFromConfigs( ccr := db.NewConnectorConfigRepo(db.NewMemDB())
err := ccr.Set(
[]connector.ConnectorConfig{&connector.LDAPConnectorConfig{ []connector.ConnectorConfig{&connector.LDAPConnectorConfig{
ID: "ldap", ID: "ldap",
ServerHost: ldapHost, ServerHost: ldapHost,
ServerPort: ldapPort, ServerPort: ldapPort,
}}, }},
) )
if err != nil {
t.Fatal(err)
}
cc, err := ccr.GetConnectorByID(tx, "ldap") cc, err := ccr.GetConnectorByID(tx, "ldap")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -149,7 +158,8 @@ func TestConnectorLDAPcaFilecertFileConnectTLS(t *testing.T) {
templates := template.New(connector.LDAPLoginPageTemplateName) templates := template.New(connector.LDAPLoginPageTemplateName)
ccr := connector.NewConnectorConfigRepoFromConfigs( ccr := db.NewConnectorConfigRepo(db.NewMemDB())
err := ccr.Set(
[]connector.ConnectorConfig{&connector.LDAPConnectorConfig{ []connector.ConnectorConfig{&connector.LDAPConnectorConfig{
ID: "ldap", ID: "ldap",
ServerHost: ldapHost, ServerHost: ldapHost,
@ -160,6 +170,9 @@ func TestConnectorLDAPcaFilecertFileConnectTLS(t *testing.T) {
CaFile: "/tmp/openldap-ca.pem", CaFile: "/tmp/openldap-ca.pem",
}}, }},
) )
if err != nil {
t.Fatal(err)
}
cc, err := ccr.GetConnectorByID(tx, "ldap") cc, err := ccr.GetConnectorByID(tx, "ldap")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -181,7 +194,8 @@ func TestConnectorLDAPcaFilecertFileConnectSSL(t *testing.T) {
templates := template.New(connector.LDAPLoginPageTemplateName) templates := template.New(connector.LDAPLoginPageTemplateName)
ccr := connector.NewConnectorConfigRepoFromConfigs( ccr := db.NewConnectorConfigRepo(db.NewMemDB())
err := ccr.Set(
[]connector.ConnectorConfig{&connector.LDAPConnectorConfig{ []connector.ConnectorConfig{&connector.LDAPConnectorConfig{
ID: "ldap", ID: "ldap",
ServerHost: ldapHost, ServerHost: ldapHost,
@ -192,6 +206,9 @@ func TestConnectorLDAPcaFilecertFileConnectSSL(t *testing.T) {
CaFile: "/tmp/openldap-ca.pem", CaFile: "/tmp/openldap-ca.pem",
}}, }},
) )
if err != nil {
t.Fatal(err)
}
cc, err := ccr.GetConnectorByID(tx, "ldap") cc, err := ccr.GetConnectorByID(tx, "ldap")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View file

@ -1,25 +1,24 @@
package repo package repo
import ( import (
"fmt" "encoding/base64"
"net/url" "net/url"
"os" "os"
"testing" "testing"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/go-gorp/gorp"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
) )
var makeTestClientIdentityRepoFromClients func(clients []oidc.ClientIdentity) client.ClientIdentityRepo
var ( var (
testClients = []oidc.ClientIdentity{ testClients = []oidc.ClientIdentity{
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "client1", ID: "client1",
Secret: "secret-1", Secret: base64.URLEncoding.EncodeToString([]byte("secret-1")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -33,7 +32,7 @@ var (
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "client2", ID: "client2",
Secret: "secret-2", Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -47,34 +46,19 @@ var (
} }
) )
func init() { func newClientIdentityRepo(t *testing.T) client.ClientIdentityRepo {
dsn := os.Getenv("DEX_TEST_DSN") dsn := os.Getenv("DEX_TEST_DSN")
var dbMap *gorp.DbMap
if dsn == "" { if dsn == "" {
makeTestClientIdentityRepoFromClients = makeTestClientIdentityRepoMem dbMap = db.NewMemDB()
} else { } else {
makeTestClientIdentityRepoFromClients = makeTestClientIdentityRepoDB(dsn) dbMap = connect(t)
} }
} repo, err := db.NewClientIdentityRepoFromClients(dbMap, testClients)
if err != nil {
func makeTestClientIdentityRepoMem(clients []oidc.ClientIdentity) client.ClientIdentityRepo { t.Fatalf("failed to create client repo from clients: %v", err)
return client.NewClientIdentityRepo(clients)
}
func makeTestClientIdentityRepoDB(dsn string) func([]oidc.ClientIdentity) client.ClientIdentityRepo {
return func(clients []oidc.ClientIdentity) client.ClientIdentityRepo {
c := initDB(dsn)
repo, err := db.NewClientIdentityRepoFromClients(c, clients)
if err != nil {
panic(fmt.Sprintf("Unable to add clients: %v", err))
}
return repo
} }
return repo
}
func makeTestClientIdentityRepo() client.ClientIdentityRepo {
return makeTestClientIdentityRepoFromClients(testClients)
} }
func TestGetSetAdminClient(t *testing.T) { func TestGetSetAdminClient(t *testing.T) {
@ -113,12 +97,14 @@ func TestGetSetAdminClient(t *testing.T) {
}, },
} }
Tests:
for i, tt := range tests { for i, tt := range tests {
repo := makeTestClientIdentityRepo() repo := newClientIdentityRepo(t)
for _, cid := range startAdmins { for _, cid := range startAdmins {
err := repo.SetDexAdmin(cid, true) err := repo.SetDexAdmin(cid, true)
if err != nil { if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err) t.Errorf("case %d: failed to set dex admin: %v", i, err)
continue Tests
} }
} }
@ -130,7 +116,7 @@ func TestGetSetAdminClient(t *testing.T) {
continue continue
} }
if err != nil { if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err) t.Errorf("case %d: unexpected error: %v", i, err)
} }
if gotAdmin != tt.wantAdmin { if gotAdmin != tt.wantAdmin {
t.Errorf("case %d: want=%v, got=%v", i, tt.wantAdmin, gotAdmin) t.Errorf("case %d: want=%v, got=%v", i, tt.wantAdmin, gotAdmin)
@ -138,12 +124,12 @@ func TestGetSetAdminClient(t *testing.T) {
err = repo.SetDexAdmin(tt.cid, tt.setAdmin) err = repo.SetDexAdmin(tt.cid, tt.setAdmin)
if err != nil { if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err) t.Errorf("case %d: unexpected error: %v", i, err)
} }
gotAdmin, err = repo.IsDexAdmin(tt.cid) gotAdmin, err = repo.IsDexAdmin(tt.cid)
if err != nil { if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err) t.Errorf("case %d: unexpected error: %v", i, err)
} }
if gotAdmin != tt.setAdmin { if gotAdmin != tt.setAdmin {
t.Errorf("case %d: want=%v, got=%v", i, tt.setAdmin, gotAdmin) t.Errorf("case %d: want=%v, got=%v", i, tt.setAdmin, gotAdmin)

View file

@ -1,36 +1,27 @@
package repo package repo
import ( import (
"fmt"
"os" "os"
"testing" "testing"
"github.com/go-gorp/gorp"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
) )
type connectorConfigRepoFactory func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo func newConnectorConfigRepo(t *testing.T, configs []connector.ConnectorConfig) connector.ConnectorConfigRepo {
var dbMap *gorp.DbMap
var makeTestConnectorConfigRepoFromConfigs connectorConfigRepoFactory if os.Getenv("DEX_TEST_DSN") == "" {
dbMap = db.NewMemDB()
func init() {
if dsn := os.Getenv("DEX_TEST_DSN"); dsn == "" {
makeTestConnectorConfigRepoFromConfigs = connector.NewConnectorConfigRepoFromConfigs
} else { } else {
makeTestConnectorConfigRepoFromConfigs = makeTestConnectorConfigRepoMem(dsn) dbMap = connect(t)
} }
} repo := db.NewConnectorConfigRepo(dbMap)
if err := repo.Set(configs); err != nil {
func makeTestConnectorConfigRepoMem(dsn string) connectorConfigRepoFactory { t.Fatalf("Unable to set connector configs: %v", err)
return func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo {
dbMap := initDB(dsn)
repo := db.NewConnectorConfigRepo(dbMap)
if err := repo.Set(cfgs); err != nil {
panic(fmt.Sprintf("Unable to set connector configs: %v", err))
}
return repo
} }
return repo
} }
func TestConnectorConfigRepoGetByID(t *testing.T) { func TestConnectorConfigRepoGetByID(t *testing.T) {
@ -63,7 +54,7 @@ func TestConnectorConfigRepoGetByID(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
repo := makeTestConnectorConfigRepoFromConfigs(tt.cfgs) repo := newConnectorConfigRepo(t, tt.cfgs)
if _, err := repo.GetConnectorByID(nil, tt.id); err != tt.err { if _, err := repo.GetConnectorByID(nil, tt.id); err != tt.err {
t.Errorf("case %d: want=%v, got=%v", i, tt.err, err) t.Errorf("case %d: want=%v, got=%v", i, tt.err, err)
} }

View file

@ -1,19 +1,17 @@
package repo package repo
import ( import (
"fmt"
"os" "os"
"testing" "testing"
"time" "time"
"github.com/go-gorp/gorp"
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
) )
var makeTestPasswordInfoRepo func() user.PasswordInfoRepo
var ( var (
testPWs = []user.PasswordInfo{ testPWs = []user.PasswordInfo{
{ {
@ -23,30 +21,18 @@ var (
} }
) )
func init() { func newPasswordInfoRepo(t *testing.T) user.PasswordInfoRepo {
dsn := os.Getenv("DEX_TEST_DSN") var dbMap *gorp.DbMap
if dsn == "" { if os.Getenv("DEX_TEST_DSN") == "" {
makeTestPasswordInfoRepo = makeTestPasswordInfoRepoMem dbMap = db.NewMemDB()
} else { } else {
makeTestPasswordInfoRepo = makeTestPasswordInfoRepoDB(dsn) dbMap = connect(t)
} }
} repo, err := db.NewPasswordInfoRepoFromPasswordInfos(dbMap, testPWs)
if err != nil {
func makeTestPasswordInfoRepoMem() user.PasswordInfoRepo { t.Fatalf("Unable to add password infos: %v", err)
return user.NewPasswordInfoRepoFromPasswordInfos(testPWs)
}
func makeTestPasswordInfoRepoDB(dsn string) func() user.PasswordInfoRepo {
return func() user.PasswordInfoRepo {
c := initDB(dsn)
repo := db.NewPasswordInfoRepo(c)
err := user.LoadPasswordInfos(repo, testPWs)
if err != nil {
panic(fmt.Sprintf("Unable to add passwordInfos: %v", err))
}
return repo
} }
return repo
} }
func TestCreatePasswordInfo(t *testing.T) { func TestCreatePasswordInfo(t *testing.T) {
@ -87,7 +73,7 @@ func TestCreatePasswordInfo(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
repo := makeTestPasswordInfoRepo() repo := newPasswordInfoRepo(t)
err := repo.Create(nil, tt.pw) err := repo.Create(nil, tt.pw)
if tt.err != nil { if tt.err != nil {
if err != tt.err { if err != tt.err {
@ -142,7 +128,7 @@ func TestUpdatePasswordInfo(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
repo := makeTestPasswordInfoRepo() repo := newPasswordInfoRepo(t)
err := repo.Update(nil, tt.pw) err := repo.Update(nil, tt.pw)
if tt.err != nil { if tt.err != nil {
if err != tt.err { if err != tt.err {

View file

@ -12,48 +12,26 @@ import (
"github.com/coreos/dex/session" "github.com/coreos/dex/session"
) )
var makeTestSessionRepo func() (session.SessionRepo, clockwork.FakeClock) func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) {
var makeTestSessionKeyRepo func() (session.SessionKeyRepo, clockwork.FakeClock) clock := clockwork.NewFakeClock()
if os.Getenv("DEX_TEST_DSN") == "" {
func init() { return db.NewSessionRepoWithClock(db.NewMemDB(), clock), clock
dsn := os.Getenv("DEX_TEST_DSN")
if dsn == "" {
makeTestSessionRepo = makeTestSessionRepoMem
makeTestSessionKeyRepo = makeTestSessionKeyRepoMem
} else {
makeTestSessionRepo = makeTestSessionRepoDB(dsn)
makeTestSessionKeyRepo = makeTestSessionKeyRepoDB(dsn)
} }
dbMap := connect(t)
return db.NewSessionRepoWithClock(dbMap, clock), clock
} }
func makeTestSessionRepoMem() (session.SessionRepo, clockwork.FakeClock) { func newSessionKeyRepo(t *testing.T) (session.SessionKeyRepo, clockwork.FakeClock) {
fc := clockwork.NewFakeClock() clock := clockwork.NewFakeClock()
return session.NewSessionRepoWithClock(fc), fc if os.Getenv("DEX_TEST_DSN") == "" {
} return db.NewSessionKeyRepoWithClock(db.NewMemDB(), clock), clock
func makeTestSessionRepoDB(dsn string) func() (session.SessionRepo, clockwork.FakeClock) {
return func() (session.SessionRepo, clockwork.FakeClock) {
c := initDB(dsn)
fc := clockwork.NewFakeClock()
return db.NewSessionRepoWithClock(c, fc), fc
}
}
func makeTestSessionKeyRepoMem() (session.SessionKeyRepo, clockwork.FakeClock) {
fc := clockwork.NewFakeClock()
return session.NewSessionKeyRepoWithClock(fc), fc
}
func makeTestSessionKeyRepoDB(dsn string) func() (session.SessionKeyRepo, clockwork.FakeClock) {
return func() (session.SessionKeyRepo, clockwork.FakeClock) {
c := initDB(dsn)
fc := clockwork.NewFakeClock()
return db.NewSessionKeyRepoWithClock(c, fc), fc
} }
dbMap := connect(t)
return db.NewSessionKeyRepoWithClock(dbMap, clock), clock
} }
func TestSessionKeyRepoPopNoExist(t *testing.T) { func TestSessionKeyRepoPopNoExist(t *testing.T) {
r, _ := makeTestSessionKeyRepo() r, _ := newSessionKeyRepo(t)
_, err := r.Pop("123") _, err := r.Pop("123")
if err == nil { if err == nil {
@ -62,7 +40,7 @@ func TestSessionKeyRepoPopNoExist(t *testing.T) {
} }
func TestSessionKeyRepoPushPop(t *testing.T) { func TestSessionKeyRepoPushPop(t *testing.T) {
r, _ := makeTestSessionKeyRepo() r, _ := newSessionKeyRepo(t)
key := "123" key := "123"
sessionID := "456" sessionID := "456"
@ -80,7 +58,7 @@ func TestSessionKeyRepoPushPop(t *testing.T) {
} }
func TestSessionKeyRepoExpired(t *testing.T) { func TestSessionKeyRepoExpired(t *testing.T) {
r, fc := makeTestSessionKeyRepo() r, fc := newSessionKeyRepo(t)
key := "123" key := "123"
sessionID := "456" sessionID := "456"
@ -96,7 +74,7 @@ func TestSessionKeyRepoExpired(t *testing.T) {
} }
func TestSessionRepoGetNoExist(t *testing.T) { func TestSessionRepoGetNoExist(t *testing.T) {
r, _ := makeTestSessionRepo() r, _ := newSessionRepo(t)
ses, err := r.Get("123") ses, err := r.Get("123")
if ses != nil { if ses != nil {
@ -129,7 +107,7 @@ func TestSessionRepoCreateGet(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
r, _ := makeTestSessionRepo() r, _ := newSessionRepo(t)
r.Create(tt) r.Create(tt)
@ -166,7 +144,7 @@ func TestSessionRepoCreateUpdate(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
r, _ := makeTestSessionRepo() r, _ := newSessionRepo(t)
r.Create(tt.initial) r.Create(tt.initial)
ses, _ := r.Get(tt.initial.ID) ses, _ := r.Get(tt.initial.ID)
@ -186,7 +164,7 @@ func TestSessionRepoCreateUpdate(t *testing.T) {
} }
func TestSessionRepoUpdateNoExist(t *testing.T) { func TestSessionRepoUpdateNoExist(t *testing.T) {
r, _ := makeTestSessionRepo() r, _ := newSessionRepo(t)
err := r.Update(session.Session{ID: "123", ClientState: "boom"}) err := r.Update(session.Session{ID: "123", ClientState: "boom"})
if err == nil { if err == nil {

View file

@ -1,28 +1,38 @@
package repo package repo
import ( import (
"fmt" "os"
"testing"
"github.com/go-gorp/gorp"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/go-gorp/gorp"
) )
func initDB(dsn string) *gorp.DbMap { func connect(t *testing.T) *gorp.DbMap {
dsn := os.Getenv("DEX_TEST_DSN")
if dsn == "" {
t.Fatal("DEX_TEST_DSN environment variable not set")
}
c, err := db.NewConnection(db.Config{DSN: dsn}) c, err := db.NewConnection(db.Config{DSN: dsn})
if err != nil { if err != nil {
panic(fmt.Sprintf("Unable to connect to database: %v", err)) t.Fatalf("Unable to connect to database: %v", err)
} }
if err = c.DropTablesIfExists(); err != nil { if err = c.DropTablesIfExists(); err != nil {
panic(fmt.Sprintf("Unable to drop database tables: %v", err)) t.Fatalf("Unable to drop database tables: %v", err)
} }
if err = db.DropMigrationsTable(c); err != nil { if err = db.DropMigrationsTable(c); err != nil {
panic(fmt.Sprintf("Unable to drop migration table: %v", err)) t.Fatalf("Unable to drop migration table: %v", err)
} }
if _, err = db.MigrateToLatest(c); err != nil { n, err := db.MigrateToLatest(c)
panic(fmt.Sprintf("Unable to migrate: %v", err)) if err != nil {
t.Fatalf("Unable to migrate: %v", err)
} }
if n == 0 {
t.Fatalf("No migrations performed")
}
return c return c
} }

View file

@ -7,14 +7,13 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-gorp/gorp"
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
) )
var makeTestUserRepoFromUsers func(users []user.UserWithRemoteIdentities) user.UserRepo
var ( var (
testUsers = []user.UserWithRemoteIdentities{ testUsers = []user.UserWithRemoteIdentities{
{ {
@ -47,34 +46,21 @@ var (
} }
) )
func init() { func newUserRepo(t *testing.T, users []user.UserWithRemoteIdentities) user.UserRepo {
dsn := os.Getenv("DEX_TEST_DSN") if users == nil {
if dsn == "" { users = []user.UserWithRemoteIdentities{}
makeTestUserRepoFromUsers = makeTestUserRepoMem }
var dbMap *gorp.DbMap
if os.Getenv("DEX_TEST_DSN") == "" {
dbMap = db.NewMemDB()
} else { } else {
makeTestUserRepoFromUsers = makeTestUserRepoDB(dsn) dbMap = connect(t)
} }
} repo, err := db.NewUserRepoFromUsers(dbMap, users)
if err != nil {
func makeTestUserRepoMem(users []user.UserWithRemoteIdentities) user.UserRepo { t.Fatalf("Unable to add users: %v", err)
return user.NewUserRepoFromUsers(users)
}
func makeTestUserRepoDB(dsn string) func([]user.UserWithRemoteIdentities) user.UserRepo {
return func(users []user.UserWithRemoteIdentities) user.UserRepo {
c := initDB(dsn)
repo, err := db.NewUserRepoFromUsers(c, users)
if err != nil {
panic(fmt.Sprintf("Unable to add users: %v", err))
}
return repo
} }
return repo
}
func makeTestUserRepo() user.UserRepo {
return makeTestUserRepoFromUsers(testUsers)
} }
func TestNewUser(t *testing.T) { func TestNewUser(t *testing.T) {
@ -137,7 +123,7 @@ func TestNewUser(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
repo := makeTestUserRepo() repo := newUserRepo(t, testUsers)
err := repo.Create(nil, tt.user) err := repo.Create(nil, tt.user)
if tt.err != nil { if tt.err != nil {
if err != tt.err { if err != tt.err {
@ -209,7 +195,7 @@ func TestUpdateUser(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
repo := makeTestUserRepo() repo := newUserRepo(t, testUsers)
err := repo.Update(nil, tt.user) err := repo.Update(nil, tt.user)
if tt.err != nil { if tt.err != nil {
if err != tt.err { if err != tt.err {
@ -269,7 +255,7 @@ func TestDisableUser(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
repo := makeTestUserRepo() repo := newUserRepo(t, testUsers)
err := repo.Disable(nil, tt.id, tt.disable) err := repo.Disable(nil, tt.id, tt.disable)
switch { switch {
case err != tt.err: case err != tt.err:
@ -320,7 +306,7 @@ func TestAttachRemoteIdentity(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
repo := makeTestUserRepo() repo := newUserRepo(t, testUsers)
err := repo.AddRemoteIdentity(nil, tt.id, tt.rid) err := repo.AddRemoteIdentity(nil, tt.id, tt.rid)
if tt.err != nil { if tt.err != nil {
if err != tt.err { if err != tt.err {
@ -390,7 +376,7 @@ func TestRemoveRemoteIdentity(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
repo := makeTestUserRepo() repo := newUserRepo(t, testUsers)
err := repo.RemoveRemoteIdentity(nil, tt.id, tt.rid) err := repo.RemoveRemoteIdentity(nil, tt.id, tt.rid)
if tt.err != nil { if tt.err != nil {
if err != tt.err { if err != tt.err {
@ -433,59 +419,6 @@ func findRemoteIdentity(rids []user.RemoteIdentity, rid user.RemoteIdentity) int
return -1 return -1
} }
func TestNewUserRepoFromUsers(t *testing.T) {
tests := []struct {
users []user.UserWithRemoteIdentities
}{
{
users: []user.UserWithRemoteIdentities{
{
User: user.User{
ID: "123",
Email: "email123@example.com",
},
RemoteIdentities: []user.RemoteIdentity{},
},
{
User: user.User{
ID: "456",
Email: "email456@example.com",
},
RemoteIdentities: []user.RemoteIdentity{
{
ID: "remoteID",
ConnectorID: "connID",
},
},
},
},
},
}
for i, tt := range tests {
repo := user.NewUserRepoFromUsers(tt.users)
for _, want := range tt.users {
gotUser, err := repo.Get(nil, want.User.ID)
if err != nil {
t.Errorf("case %d: want nil err: %v", i, err)
}
gotRIDs, err := repo.GetRemoteIdentities(nil, want.User.ID)
if err != nil {
t.Errorf("case %d: want nil err: %v", i, err)
}
if !reflect.DeepEqual(want.User, gotUser) {
t.Errorf("case %d: want=%#v got=%#v", i, want.User, gotUser)
}
if !reflect.DeepEqual(want.RemoteIdentities, gotRIDs) {
t.Errorf("case %d: want=%#v got=%#v", i, want.RemoteIdentities, gotRIDs)
}
}
}
}
func TestGetByEmail(t *testing.T) { func TestGetByEmail(t *testing.T) {
tests := []struct { tests := []struct {
email string email string
@ -502,7 +435,7 @@ func TestGetByEmail(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
repo := makeTestUserRepo() repo := newUserRepo(t, testUsers)
gotUser, gotErr := repo.GetByEmail(nil, tt.email) gotUser, gotErr := repo.GetByEmail(nil, tt.email)
if tt.wantErr != nil { if tt.wantErr != nil {
if tt.wantErr != gotErr { if tt.wantErr != gotErr {
@ -566,7 +499,7 @@ func TestGetAdminCount(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
repo := makeTestUserRepo() repo := newUserRepo(t, testUsers)
for _, addUser := range tt.addUsers { for _, addUser := range tt.addUsers {
err := repo.Create(nil, addUser) err := repo.Create(nil, addUser)
if err != nil { if err != nil {
@ -621,7 +554,7 @@ func TestList(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
repo := makeTestUserRepoFromUsers(repoUsers) repo := newUserRepo(t, repoUsers)
var tok string var tok string
gotIDs := [][]string{} gotIDs := [][]string{}
done := false done := false
@ -651,7 +584,7 @@ func TestList(t *testing.T) {
} }
func TestListErrorNotFound(t *testing.T) { func TestListErrorNotFound(t *testing.T) {
repo := makeTestUserRepoFromUsers(nil) repo := newUserRepo(t, nil)
_, _, err := repo.List(nil, user.UserFilter{}, 10, "") _, _, err := repo.List(nil, user.UserFilter{}, 10, "")
if err != user.ErrorNotFound { if err != user.ErrorNotFound {
t.Errorf("want=%q, got=%q", user.ErrorNotFound, err) t.Errorf("want=%q, got=%q", user.ErrorNotFound, err)

View file

@ -1,7 +1,9 @@
package integration package integration
import ( import (
"encoding/base64"
"net/http" "net/http"
"net/url"
"reflect" "reflect"
"testing" "testing"
@ -13,7 +15,12 @@ func TestClientCreate(t *testing.T) {
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "72de74a9", ID: "72de74a9",
Secret: "XXX", Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
{Scheme: "https://", Host: "authn.example.com", Path: "/callback"},
},
}, },
} }
cis := []oidc.ClientIdentity{ci} cis := []oidc.ClientIdentity{ci}
@ -54,7 +61,7 @@ func TestClientCreate(t *testing.T) {
call := svc.Clients.Create(newClientInput) call := svc.Clients.Create(newClientInput)
newClient, err := call.Do() newClient, err := call.Do()
if err != nil { if err != nil {
t.Errorf("Call to create client API failed: %v", err) t.Fatalf("Call to create client API failed: %v", err)
} }
if newClient.Id == "" { if newClient.Id == "" {

View file

@ -1,6 +1,7 @@
package integration package integration
import ( import (
"encoding/base64"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -11,7 +12,7 @@ import (
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/repo" "github.com/coreos/dex/db"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager" "github.com/coreos/dex/user/manager"
) )
@ -21,7 +22,7 @@ var (
testIssuerURL = url.URL{Scheme: "https", Host: "auth.example.com"} testIssuerURL = url.URL{Scheme: "https", Host: "auth.example.com"}
testClientID = "XXX" testClientID = "XXX"
testClientSecret = "yyy" testClientSecret = base64.URLEncoding.EncodeToString([]byte("yyy"))
testRedirectURL = url.URL{Scheme: "https", Host: "client.example.com", Path: "/redirect"} testRedirectURL = url.URL{Scheme: "https", Host: "client.example.com", Path: "/redirect"}
testResetPasswordURL = url.URL{Scheme: "https", Host: "auth.example.com", Path: "/resetPassword"} testResetPasswordURL = url.URL{Scheme: "https", Host: "auth.example.com", Path: "/resetPassword"}
testPrivKey, _ = key.GeneratePrivateKey() testPrivKey, _ = key.GeneratePrivateKey()
@ -45,13 +46,32 @@ func (t *tokenHandlerTransport) RoundTrip(r *http.Request) (*http.Response, erro
} }
func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.PasswordInfo) (user.UserRepo, user.PasswordInfoRepo, *manager.UserManager) { func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.PasswordInfo) (user.UserRepo, user.PasswordInfoRepo, *manager.UserManager) {
ur := user.NewUserRepoFromUsers(users) dbMap := db.NewMemDB()
pwr := user.NewPasswordInfoRepoFromPasswordInfos(passwords) ur := func() user.UserRepo {
repo, err := db.NewUserRepoFromUsers(dbMap, users)
if err != nil {
panic("Failed to create user repo: " + err.Error())
}
return repo
}()
pwr := func() user.PasswordInfoRepo {
repo, err := db.NewPasswordInfoRepoFromPasswordInfos(dbMap, passwords)
if err != nil {
panic("Failed to create password info repo: " + err.Error())
}
return repo
}()
ccr := connector.NewConnectorConfigRepoFromConfigs( ccr := func() connector.ConnectorConfigRepo {
[]connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}}, repo := db.NewConnectorConfigRepo(dbMap)
) c := []connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}}
um := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{}) if err := repo.Set(c); err != nil {
panic(err)
}
return repo
}()
um := manager.NewUserManager(ur, pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
um.Clock = clock um.Clock = clock
return ur, pwr, um return ur, pwr, um
} }

View file

@ -1,6 +1,7 @@
package integration package integration
import ( import (
"encoding/base64"
"fmt" "fmt"
"html/template" "html/template"
"net/http" "net/http"
@ -8,12 +9,12 @@ import (
"testing" "testing"
"time" "time"
"github.com/coreos/dex/client"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
phttp "github.com/coreos/dex/pkg/http" phttp "github.com/coreos/dex/pkg/http"
"github.com/coreos/dex/refresh/refreshtest" "github.com/coreos/dex/refresh/refreshtest"
"github.com/coreos/dex/server" "github.com/coreos/dex/server"
"github.com/coreos/dex/session" "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
@ -22,6 +23,7 @@ import (
) )
func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) { func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) {
dbMap := db.NewMemDB()
k, err := key.GeneratePrivateKey() k, err := key.GeneratePrivateKey()
if err != nil { if err != nil {
return nil, fmt.Errorf("Unable to generate private key: %v", err) return nil, fmt.Errorf("Unable to generate private key: %v", err)
@ -32,12 +34,16 @@ func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
clientIdentityRepo, err := db.NewClientIdentityRepoFromClients(dbMap, cis)
if err != nil {
return nil, err
}
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := manager.NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap))
srv := &server.Server{ srv := &server.Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km, KeyManager: km,
ClientIdentityRepo: client.NewClientIdentityRepo(cis), ClientIdentityRepo: clientIdentityRepo,
SessionManager: sm, SessionManager: sm,
} }
@ -113,14 +119,18 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "72de74a9", ID: "72de74a9",
Secret: "XXX", Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
}, },
} }
cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) dbMap := db.NewMemDB()
cir, err := db.NewClientIdentityRepoFromClients(dbMap, []oidc.ClientIdentity{ci})
if err != nil {
t.Fatalf("Failed to create client identity repo: " + err.Error())
}
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"} issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := manager.NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap))
k, err := key.GeneratePrivateKey() k, err := key.GeneratePrivateKey()
if err != nil { if err != nil {
@ -138,16 +148,13 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
Email: "testemail@example.com", Email: "testemail@example.com",
DisplayName: "displayname", DisplayName: "displayname",
} }
userRepo := user.NewUserRepo() userRepo := db.NewUserRepo(db.NewMemDB())
if err := userRepo.Create(nil, usr); err != nil { if err := userRepo.Create(nil, usr); err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
passwordInfoRepo := user.NewPasswordInfoRepo() passwordInfoRepo := db.NewPasswordInfoRepo(db.NewMemDB())
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
srv := &server.Server{ srv := &server.Server{
IssuerURL: issuerURL, IssuerURL: issuerURL,
@ -255,7 +262,7 @@ func TestHTTPClientCredsToken(t *testing.T) {
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "72de74a9", ID: "72de74a9",
Secret: "XXX", Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
}, },
} }
cis := []oidc.ClientIdentity{ci} cis := []oidc.ClientIdentity{ci}

View file

@ -1,6 +1,7 @@
package integration package integration
import ( import (
"encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -15,6 +16,7 @@ import (
"google.golang.org/api/googleapi" "google.golang.org/api/googleapi"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/db"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/server" "github.com/coreos/dex/server"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
@ -97,30 +99,36 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
_, _, um := makeUserObjects(userUsers, userPasswords) _, _, um := makeUserObjects(userUsers, userPasswords)
cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ cir := func() client.ClientIdentityRepo {
oidc.ClientIdentity{ repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ oidc.ClientIdentity{
ID: testClientID, Credentials: oidc.ClientCredentials{
Secret: testClientSecret, ID: testClientID,
}, Secret: testClientSecret,
Metadata: oidc.ClientMetadata{ },
RedirectURIs: []url.URL{ Metadata: oidc.ClientMetadata{
testRedirectURL, RedirectURIs: []url.URL{
testRedirectURL,
},
}, },
}, },
}, oidc.ClientIdentity{
oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{
Credentials: oidc.ClientCredentials{ ID: userBadClientID,
ID: userBadClientID, Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
Secret: "secret", },
}, Metadata: oidc.ClientMetadata{
Metadata: oidc.ClientMetadata{ RedirectURIs: []url.URL{
RedirectURIs: []url.URL{ testRedirectURL,
testRedirectURL, },
}, },
}, },
}, })
}) if err != nil {
panic("Failed to create client identity repo: " + err.Error())
}
return repo
}()
cir.SetDexAdmin(testClientID, true) cir.SetDexAdmin(testClientID, true)

View file

@ -3,16 +3,17 @@ package refreshtest
import ( import (
"fmt" "fmt"
"github.com/coreos/dex/db"
"github.com/coreos/dex/refresh" "github.com/coreos/dex/refresh"
) )
// NewTestRefreshTokenRepo returns a test repo whose tokens monotonically increase. // NewTestRefreshTokenRepo returns a test repo whose tokens monotonically increase.
// The tokens are in the form { refresh-1, refresh-2 ... refresh-n}. // The tokens are in the form { refresh-1, refresh-2 ... refresh-n}.
func NewTestRefreshTokenRepo() (refresh.RefreshTokenRepo, error) { func NewTestRefreshTokenRepo() refresh.RefreshTokenRepo {
var tokenIdx int var tokenIdx int
tokenGenerator := func() ([]byte, error) { tokenGenerator := func() ([]byte, error) {
tokenIdx++ tokenIdx++
return []byte(fmt.Sprintf("refresh-%d", tokenIdx)), nil return []byte(fmt.Sprintf("refresh-%d", tokenIdx)), nil
} }
return refresh.NewRefreshTokenRepoWithTokenGenerator(tokenGenerator), nil return db.NewRefreshTokenRepoWithGenerator(db.NewMemDB(), tokenGenerator)
} }

View file

@ -1,13 +1,8 @@
package refresh package refresh
import ( import (
"bytes"
"crypto/rand" "crypto/rand"
"encoding/base64"
"errors" "errors"
"fmt"
"strconv"
"strings"
) )
const ( const (
@ -53,121 +48,3 @@ type RefreshTokenRepo interface {
// Revoke deletes the refresh token if the token belongs to the given userID. // Revoke deletes the refresh token if the token belongs to the given userID.
Revoke(userID, token string) error Revoke(userID, token string) error
} }
type refreshToken struct {
payload []byte
userID string
clientID string
}
type memRefreshTokenRepo struct {
store map[int]refreshToken
tokenGenerator RefreshTokenGenerator
}
// buildToken combines the token ID and token payload to create a new token.
func buildToken(tokenID int, tokenPayload []byte) string {
return fmt.Sprintf("%d%s%s", tokenID, TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
}
// parseToken parses a token and returns the token ID and token payload.
func parseToken(token string) (int, []byte, error) {
parts := strings.SplitN(token, TokenDelimer, 2)
if len(parts) != 2 {
return -1, nil, ErrorInvalidToken
}
id, err := strconv.Atoi(parts[0])
if err != nil {
return -1, nil, ErrorInvalidToken
}
tokenPayload, err := base64.URLEncoding.DecodeString(parts[1])
if err != nil {
return -1, nil, ErrorInvalidToken
}
return id, tokenPayload, nil
}
// NewRefreshTokenRepo returns an in-memory RefreshTokenRepo useful for development.
func NewRefreshTokenRepo() RefreshTokenRepo {
return NewRefreshTokenRepoWithTokenGenerator(DefaultRefreshTokenGenerator)
}
func NewRefreshTokenRepoWithTokenGenerator(tokenGenerator RefreshTokenGenerator) RefreshTokenRepo {
repo := &memRefreshTokenRepo{}
repo.store = make(map[int]refreshToken)
repo.tokenGenerator = tokenGenerator
return repo
}
func (r *memRefreshTokenRepo) Create(userID, clientID string) (string, error) {
// Validate userID.
if userID == "" {
return "", ErrorInvalidUserID
}
// Validate clientID.
if clientID == "" {
return "", ErrorInvalidClientID
}
// Generate and store token.
tokenPayload, err := r.tokenGenerator.Generate()
if err != nil {
return "", err
}
tokenID := len(r.store) // Should only be used in single threaded tests.
// No limits on the number of tokens per user/client for this in-memory repo.
r.store[tokenID] = refreshToken{
payload: tokenPayload,
userID: userID,
clientID: clientID,
}
return buildToken(tokenID, tokenPayload), nil
}
func (r *memRefreshTokenRepo) Verify(clientID, token string) (string, error) {
tokenID, tokenPayload, err := parseToken(token)
if err != nil {
return "", err
}
record, ok := r.store[tokenID]
if !ok {
return "", ErrorInvalidToken
}
if !bytes.Equal(record.payload, tokenPayload) {
return "", ErrorInvalidToken
}
if record.clientID != clientID {
return "", ErrorInvalidClientID
}
return record.userID, nil
}
func (r *memRefreshTokenRepo) Revoke(userID, token string) error {
tokenID, tokenPayload, err := parseToken(token)
if err != nil {
return err
}
record, ok := r.store[tokenID]
if !ok {
return ErrorInvalidToken
}
if !bytes.Equal(record.payload, tokenPayload) {
return ErrorInvalidToken
}
if record.userID != userID {
return ErrorInvalidUserID
}
delete(r.store, tokenID)
return nil
}

View file

@ -1,7 +1,5 @@
package repo package repo
import "errors"
// Transaction is an abstraction of transactions typically found in database systems. // Transaction is an abstraction of transactions typically found in database systems.
// One of Commit() or Rollback() must be called on each transaction. // One of Commit() or Rollback() must be called on each transaction.
type Transaction interface { type Transaction interface {
@ -13,29 +11,3 @@ type Transaction interface {
} }
type TransactionFactory func() (Transaction, error) type TransactionFactory func() (Transaction, error)
// InMemTransaction satisifies the Transaction interface for in-memory systems.
// However, the only thing it really does is ensure that the same transaction is
// can't be committed/rolled back more than once. As such, this can lead to data
// corruption and should not be used in production systems.
type InMemTransaction bool
func InMemTransactionFactory() (Transaction, error) {
return new(InMemTransaction), nil
}
func (i *InMemTransaction) Commit() error {
return i.commitOrRollback()
}
func (i *InMemTransaction) Rollback() error {
return i.commitOrRollback()
}
func (i *InMemTransaction) commitOrRollback() error {
if *i {
return errors.New("Already committed/rolled-back.")
}
*i = true
return nil
}

View file

@ -1,13 +1,16 @@
package server package server
import ( import (
"encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"time" "time"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/db"
"github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
@ -25,10 +28,19 @@ func TestClientToken(t *testing.T) {
validClientID := "valid-client" validClientID := "valid-client"
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: validClientID, ID: validClientID,
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
{Scheme: "https", Host: "authn.example.com", Path: "/callback"},
},
}, },
} }
repo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
}
privKey, err := key.GeneratePrivateKey() privKey, err := key.GeneratePrivateKey()
if err != nil { if err != nil {
@ -102,7 +114,7 @@ func TestClientToken(t *testing.T) {
// empty repo // empty repo
{ {
keys: []key.PublicKey{pubKey}, keys: []key.PublicKey{pubKey},
repo: client.NewClientIdentityRepo(nil), repo: db.NewClientIdentityRepo(db.NewMemDB()),
header: fmt.Sprintf("BEARER %s", validJWT), header: fmt.Sprintf("BEARER %s", validJWT),
wantCode: http.StatusUnauthorized, wantCode: http.StatusUnauthorized,
}, },

View file

@ -1,6 +1,7 @@
package server package server
import ( import (
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -9,12 +10,14 @@ import (
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"reflect" "reflect"
"sort"
"strings" "strings"
"testing" "testing"
"github.com/coreos/dex/client" "github.com/coreos/dex/db"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/kylelemons/godebug/pretty"
) )
func makeBody(s string) io.ReadCloser { func makeBody(s string) io.ReadCloser {
@ -24,7 +27,7 @@ func makeBody(s string) io.ReadCloser {
func TestCreateInvalidRequest(t *testing.T) { func TestCreateInvalidRequest(t *testing.T) {
u := &url.URL{Scheme: "http", Host: "example.com", Path: "clients"} u := &url.URL{Scheme: "http", Host: "example.com", Path: "clients"}
h := http.Header{"Content-Type": []string{"application/json"}} h := http.Header{"Content-Type": []string{"application/json"}}
repo := client.NewClientIdentityRepo(nil) repo := db.NewClientIdentityRepo(db.NewMemDB())
res := &clientResource{repo: repo} res := &clientResource{repo: repo}
tests := []struct { tests := []struct {
req *http.Request req *http.Request
@ -115,7 +118,7 @@ func TestCreateInvalidRequest(t *testing.T) {
} }
func TestCreate(t *testing.T) { func TestCreate(t *testing.T) {
repo := client.NewClientIdentityRepo(nil) repo := db.NewClientIdentityRepo(db.NewMemDB())
res := &clientResource{repo: repo} res := &clientResource{repo: repo}
tests := [][]string{ tests := [][]string{
[]string{"http://example.com"}, []string{"http://example.com"},
@ -168,6 +171,11 @@ func TestCreate(t *testing.T) {
} }
func TestList(t *testing.T) { func TestList(t *testing.T) {
b64Encode := func(s string) string {
return base64.URLEncoding.EncodeToString([]byte(s))
}
tests := []struct { tests := []struct {
cs []oidc.ClientIdentity cs []oidc.ClientIdentity
want []*schema.Client want []*schema.Client
@ -181,7 +189,7 @@ func TestList(t *testing.T) {
{ {
cs: []oidc.ClientIdentity{ cs: []oidc.ClientIdentity{
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ID: "foo", Secret: "bar"}, Credentials: oidc.ClientCredentials{ID: "foo", Secret: b64Encode("bar")},
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
url.URL{Scheme: "http", Host: "example.com"}, url.URL{Scheme: "http", Host: "example.com"},
@ -200,7 +208,7 @@ func TestList(t *testing.T) {
{ {
cs: []oidc.ClientIdentity{ cs: []oidc.ClientIdentity{
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ID: "foo", Secret: "bar"}, Credentials: oidc.ClientCredentials{ID: "foo", Secret: b64Encode("bar")},
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
url.URL{Scheme: "http", Host: "example.com"}, url.URL{Scheme: "http", Host: "example.com"},
@ -208,7 +216,7 @@ func TestList(t *testing.T) {
}, },
}, },
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ID: "biz", Secret: "bang"}, Credentials: oidc.ClientCredentials{ID: "biz", Secret: b64Encode("bang")},
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
url.URL{Scheme: "https", Host: "example.com", Path: "one/two/three"}, url.URL{Scheme: "https", Host: "example.com", Path: "one/two/three"},
@ -230,7 +238,11 @@ func TestList(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
repo := client.NewClientIdentityRepo(tt.cs) repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), tt.cs)
if err != nil {
t.Errorf("case %d: failed to create client identity repo: %v", i, err)
continue
}
res := &clientResource{repo: repo} res := &clientResource{repo: repo}
r, err := http.NewRequest("GET", "http://example.com/clients", nil) r, err := http.NewRequest("GET", "http://example.com/clients", nil)
@ -248,9 +260,17 @@ func TestList(t *testing.T) {
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Errorf("case %d: unexpected error=%v", i, err) t.Errorf("case %d: unexpected error=%v", i, err)
} }
sort.Sort(byClientId(tt.want))
sort.Sort(byClientId(resp.Clients))
if !reflect.DeepEqual(tt.want, resp.Clients) { if diff := pretty.Compare(tt.want, resp.Clients); diff != "" {
t.Errorf("case %d: invalid response body, want=%#v, got=%#v", i, tt.want, resp.Clients) t.Errorf("case %d: invalid response body: %s", i, diff)
} }
} }
} }
type byClientId []*schema.Client
func (b byClientId) Len() int { return len(b) }
func (b byClientId) Less(i, j int) bool { return b[i].Id < b[j].Id }
func (b byClientId) Swap(i, j int) { b[i], b[j] = b[j], b[i] }

View file

@ -1,6 +1,7 @@
package server package server
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"html/template" "html/template"
@ -11,18 +12,17 @@ import (
"time" "time"
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
"github.com/coreos/go-oidc/oidc"
"github.com/coreos/pkg/health" "github.com/coreos/pkg/health"
"github.com/go-gorp/gorp"
"github.com/coreos/dex/client"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/email" "github.com/coreos/dex/email"
"github.com/coreos/dex/refresh" sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/repo"
"github.com/coreos/dex/session"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email" useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager" usermanager "github.com/coreos/dex/user/manager"
) )
type ServerConfig struct { type ServerConfig struct {
@ -101,20 +101,21 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
return err return err
} }
dbMap := db.NewMemDB()
ks := key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(24*time.Hour)) ks := key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(24*time.Hour))
kRepo := key.NewPrivateKeySetRepo() kRepo := key.NewPrivateKeySetRepo()
if err = kRepo.Set(ks); err != nil { if err = kRepo.Set(ks); err != nil {
return err return err
} }
cf, err := os.Open(cfg.ClientsFile) clients, err := loadClients(cfg.ClientsFile)
if err != nil { if err != nil {
return fmt.Errorf("unable to read clients from file %s: %v", cfg.ClientsFile, err) return fmt.Errorf("unable to read clients from file %s: %v", cfg.ClientsFile, err)
} }
defer cf.Close() ciRepo, err := db.NewClientIdentityRepoFromClients(dbMap, clients)
ciRepo, err := client.NewClientIdentityRepoFromReader(cf)
if err != nil { if err != nil {
return fmt.Errorf("unable to read client identities from file %s: %v", cfg.ClientsFile, err) return fmt.Errorf("failed to create client identity repo: %v", err)
} }
f, err := os.Open(cfg.ConnectorsFile) f, err := os.Open(cfg.ConnectorsFile)
@ -126,23 +127,30 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
if err != nil { if err != nil {
return fmt.Errorf("decoding connector configs: %v", err) return fmt.Errorf("decoding connector configs: %v", err)
} }
cfgRepo := connector.NewConnectorConfigRepoFromConfigs(cfgs) cfgRepo := db.NewConnectorConfigRepo(dbMap)
if err := cfgRepo.Set(cfgs); err != nil {
return fmt.Errorf("failed to set connectors: %v", err)
}
sRepo := session.NewSessionRepo() sRepo := db.NewSessionRepo(dbMap)
skRepo := session.NewSessionKeyRepo() skRepo := db.NewSessionKeyRepo(dbMap)
sm := session.NewSessionManager(sRepo, skRepo) sm := sessionmanager.NewSessionManager(sRepo, skRepo)
userRepo, err := user.NewUserRepoFromFile(cfg.UsersFile) users, err := loadUsers(cfg.UsersFile)
if err != nil { if err != nil {
return fmt.Errorf("unable to read users from file: %v", err) return fmt.Errorf("unable to read users from file: %v", err)
} }
userRepo, err := db.NewUserRepoFromUsers(dbMap, users)
if err != nil {
return err
}
pwiRepo := user.NewPasswordInfoRepo() pwiRepo := db.NewPasswordInfoRepo(dbMap)
refTokRepo := refresh.NewRefreshTokenRepo() refTokRepo := db.NewRefreshTokenRepo(dbMap)
txnFactory := repo.InMemTransactionFactory txnFactory := db.TransactionFactory(dbMap)
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, manager.ManagerOptions{}) userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
srv.ClientIdentityRepo = ciRepo srv.ClientIdentityRepo = ciRepo
srv.KeySetRepo = kRepo srv.KeySetRepo = kRepo
srv.ConnectorConfigRepo = cfgRepo srv.ConnectorConfigRepo = cfgRepo
@ -152,7 +160,54 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
srv.SessionManager = sm srv.SessionManager = sm
srv.RefreshTokenRepo = refTokRepo srv.RefreshTokenRepo = refTokRepo
return nil return nil
}
func loadUsers(filepath string) (users []user.UserWithRemoteIdentities, err error) {
f, err := os.Open(filepath)
if err != nil {
return
}
defer f.Close()
err = json.NewDecoder(f).Decode(&users)
return
}
func loadClients(filepath string) ([]oidc.ClientIdentity, error) {
f, err := os.Open(filepath)
if err != nil {
return nil, err
}
defer f.Close()
var c []struct {
ID string `json:"id"`
Secret string `json:"secret"`
RedirectURLs []string `json:"redirectURLs"`
}
if err := json.NewDecoder(f).Decode(&c); err != nil {
return nil, err
}
clients := make([]oidc.ClientIdentity, len(c))
for i, client := range c {
redirectURIs := make([]url.URL, len(client.RedirectURLs))
for j, u := range client.RedirectURLs {
uri, err := url.Parse(u)
if err != nil {
return nil, err
}
redirectURIs[j] = *uri
}
clients[i] = oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{
ID: client.ID,
Secret: client.Secret,
},
Metadata: oidc.ClientMetadata{
RedirectURIs: redirectURIs,
},
}
}
return clients, nil
} }
func (cfg *MultiServerConfig) Configure(srv *Server) error { func (cfg *MultiServerConfig) Configure(srv *Server) error {
@ -168,6 +223,9 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
if err != nil { if err != nil {
return fmt.Errorf("unable to initialize database connection: %v", err) return fmt.Errorf("unable to initialize database connection: %v", err)
} }
if _, ok := dbc.Dialect.(gorp.PostgresDialect); !ok {
return errors.New("only postgres backend supported for multi server configurations")
}
kRepo, err := db.NewPrivateKeySetRepo(dbc, cfg.UseOldFormat, cfg.KeySecrets...) kRepo, err := db.NewPrivateKeySetRepo(dbc, cfg.UseOldFormat, cfg.KeySecrets...)
if err != nil { if err != nil {
@ -180,10 +238,10 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
cfgRepo := db.NewConnectorConfigRepo(dbc) cfgRepo := db.NewConnectorConfigRepo(dbc)
userRepo := db.NewUserRepo(dbc) userRepo := db.NewUserRepo(dbc)
pwiRepo := db.NewPasswordInfoRepo(dbc) pwiRepo := db.NewPasswordInfoRepo(dbc)
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{}) userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), usermanager.ManagerOptions{})
refreshTokenRepo := db.NewRefreshTokenRepo(dbc) refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
sm := session.NewSessionManager(sRepo, skRepo) sm := sessionmanager.NewSessionManager(sRepo, skRepo)
srv.ClientIdentityRepo = ciRepo srv.ClientIdentityRepo = ciRepo
srv.KeySetRepo = kRepo srv.KeySetRepo = kRepo

View file

@ -1,6 +1,7 @@
package server package server
import ( import (
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -17,7 +18,8 @@ import (
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/session" "github.com/coreos/dex/db"
"github.com/coreos/dex/session/manager"
"github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/oauth2" "github.com/coreos/go-oidc/oauth2"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
@ -75,20 +77,26 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
} }
srv := &Server{ srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()), SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{ ClientIdentityRepo: func() client.ClientIdentityRepo {
oidc.ClientIdentity{ repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ oidc.ClientIdentity{
ID: "XXX", Credentials: oidc.ClientCredentials{
Secret: "secrete", ID: "XXX",
}, Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
Metadata: oidc.ClientMetadata{ },
RedirectURIs: []url.URL{ Metadata: oidc.ClientMetadata{
url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"}, RedirectURIs: []url.URL{
url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"},
},
}, },
}, },
}, })
}), if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
}
return repo
}(),
} }
tests := []struct { tests := []struct {
@ -198,21 +206,27 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
} }
srv := &Server{ srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()), SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{ ClientIdentityRepo: func() client.ClientIdentityRepo {
oidc.ClientIdentity{ repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ oidc.ClientIdentity{
ID: "XXX", Credentials: oidc.ClientCredentials{
Secret: "secrete", ID: "XXX",
}, Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
Metadata: oidc.ClientMetadata{ },
RedirectURIs: []url.URL{ Metadata: oidc.ClientMetadata{
url.URL{Scheme: "http", Host: "foo.example.com", Path: "/callback"}, RedirectURIs: []url.URL{
url.URL{Scheme: "http", Host: "bar.example.com", Path: "/callback"}, url.URL{Scheme: "http", Host: "foo.example.com", Path: "/callback"},
url.URL{Scheme: "http", Host: "bar.example.com", Path: "/callback"},
},
}, },
}, },
}, })
}), if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
}
return repo
}(),
} }
tests := []struct { tests := []struct {

View file

@ -9,10 +9,10 @@ import (
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/session" sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email" useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager" usermanager "github.com/coreos/dex/user/manager"
) )
type sendResetPasswordEmailData struct { type sendResetPasswordEmailData struct {
@ -28,7 +28,7 @@ type sendResetPasswordEmailData struct {
type SendResetPasswordEmailHandler struct { type SendResetPasswordEmailHandler struct {
tpl *template.Template tpl *template.Template
emailer *useremail.UserEmailer emailer *useremail.UserEmailer
sm *session.SessionManager sm *sessionmanager.SessionManager
cr client.ClientIdentityRepo cr client.ClientIdentityRepo
} }
@ -182,7 +182,7 @@ type resetPasswordTemplateData struct {
type ResetPasswordHandler struct { type ResetPasswordHandler struct {
tpl *template.Template tpl *template.Template
issuerURL url.URL issuerURL url.URL
um *manager.UserManager um *usermanager.UserManager
keysFunc func() ([]key.PublicKey, error) keysFunc func() ([]key.PublicKey, error)
} }
@ -238,7 +238,7 @@ func (r *resetPasswordRequest) handlePOST() {
cbURL, err := r.h.um.ChangePassword(r.pwReset, plaintext) cbURL, err := r.h.um.ChangePassword(r.pwReset, plaintext)
if err != nil { if err != nil {
switch err { switch err {
case manager.ErrorPasswordAlreadyChanged: case usermanager.ErrorPasswordAlreadyChanged:
r.data.Error = "Link Expired" r.data.Error = "Link Expired"
r.data.Message = "The link in your email is no longer valid. If you need to change your password, generate a new email." r.data.Message = "The link in your email is no longer valid. If you need to change your password, generate a new email."
r.data.DontShowForm = true r.data.DontShowForm = true

View file

@ -10,8 +10,9 @@ import (
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/session" "github.com/coreos/dex/session"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager" usermanager "github.com/coreos/dex/user/manager"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
) )
@ -274,7 +275,7 @@ func makeClientRedirectURL(baseRedirURL url.URL, code, clientState string) *url.
return &ru return &ru
} }
func registerFromLocalConnector(userManager *manager.UserManager, sessionManager *session.SessionManager, ses *session.Session, email, password string) (string, error) { func registerFromLocalConnector(userManager *usermanager.UserManager, sessionManager *sessionmanager.SessionManager, ses *session.Session, email, password string) (string, error) {
userID, err := userManager.RegisterWithPassword(email, password, ses.ConnectorID) userID, err := userManager.RegisterWithPassword(email, password, ses.ConnectorID)
if err != nil { if err != nil {
return "", err return "", err
@ -289,7 +290,7 @@ func registerFromLocalConnector(userManager *manager.UserManager, sessionManager
return userID, nil return userID, nil
} }
func registerFromRemoteConnector(userManager *manager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) { func registerFromRemoteConnector(userManager *usermanager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) {
if ses.Identity.ID == "" { if ses.Identity.ID == "" {
return "", errors.New("No Identity found in session.") return "", errors.New("No Identity found in session.")
} }

View file

@ -22,10 +22,11 @@ import (
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh" "github.com/coreos/dex/refresh"
"github.com/coreos/dex/session" "github.com/coreos/dex/session"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
usersapi "github.com/coreos/dex/user/api" usersapi "github.com/coreos/dex/user/api"
useremail "github.com/coreos/dex/user/email" useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager" usermanager "github.com/coreos/dex/user/manager"
) )
const ( const (
@ -57,7 +58,7 @@ type Server struct {
IssuerURL url.URL IssuerURL url.URL
KeyManager key.PrivateKeyManager KeyManager key.PrivateKeyManager
KeySetRepo key.PrivateKeySetRepo KeySetRepo key.PrivateKeySetRepo
SessionManager *session.SessionManager SessionManager *sessionmanager.SessionManager
ClientIdentityRepo client.ClientIdentityRepo ClientIdentityRepo client.ClientIdentityRepo
ConnectorConfigRepo connector.ConnectorConfigRepo ConnectorConfigRepo connector.ConnectorConfigRepo
Templates *template.Template Templates *template.Template
@ -69,7 +70,7 @@ type Server struct {
HealthChecks []health.Checkable HealthChecks []health.Checkable
Connectors []connector.Connector Connectors []connector.Connector
UserRepo user.UserRepo UserRepo user.UserRepo
UserManager *manager.UserManager UserManager *usermanager.UserManager
PasswordInfoRepo user.PasswordInfoRepo PasswordInfoRepo user.PasswordInfoRepo
RefreshTokenRepo refresh.RefreshTokenRepo RefreshTokenRepo refresh.RefreshTokenRepo
UserEmailer *useremail.UserEmailer UserEmailer *useremail.UserEmailer

View file

@ -10,8 +10,9 @@ import (
"time" "time"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/db"
"github.com/coreos/dex/refresh/refreshtest" "github.com/coreos/dex/refresh/refreshtest"
"github.com/coreos/dex/session" "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
@ -20,6 +21,8 @@ import (
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
) )
var clientTestSecret = base64.URLEncoding.EncodeToString([]byte("secrete"))
type StaticKeyManager struct { type StaticKeyManager struct {
key.PrivateKeyManager key.PrivateKeyManager
expiresAt time.Time expiresAt time.Time
@ -68,14 +71,14 @@ func (ss *StaticSigner) JWK() jose.JWK {
return jose.JWK{} return jose.JWK{}
} }
func staticGenerateCodeFunc(code string) session.GenerateCodeFunc { func staticGenerateCodeFunc(code string) manager.GenerateCodeFunc {
return func() (string, error) { return func() (string, error) {
return code, nil return code, nil
} }
} }
func makeNewUserRepo() (user.UserRepo, error) { func makeNewUserRepo() (user.UserRepo, error) {
userRepo := user.NewUserRepo() userRepo := db.NewUserRepo(db.NewMemDB())
id := "testid-1" id := "testid-1"
err := userRepo.Create(nil, user.User{ err := userRepo.Create(nil, user.User{
@ -120,7 +123,7 @@ func TestServerProviderConfig(t *testing.T) {
} }
func TestServerNewSession(t *testing.T) { func TestServerNewSession(t *testing.T) {
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
srv := &Server{ srv := &Server{
SessionManager: sm, SessionManager: sm,
} }
@ -179,7 +182,7 @@ func TestServerLogin(t *testing.T) {
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: "XXX",
Secret: "secrete", Secret: clientTestSecret,
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -191,13 +194,19 @@ func TestServerLogin(t *testing.T) {
}, },
}, },
} }
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) ciRepo := func() client.ClientIdentityRepo {
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
}
return repo
}()
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil}, signer: &StaticSigner{sig: []byte("beer"), err: nil},
} }
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sm.GenerateCode = staticGenerateCodeFunc("fakecode") sm.GenerateCode = staticGenerateCodeFunc("fakecode")
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"}) sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"})
if err != nil { if err != nil {
@ -235,17 +244,24 @@ func TestServerLogin(t *testing.T) {
} }
func TestServerLoginUnrecognizedSessionKey(t *testing.T) { func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ ciRepo := func() client.ClientIdentityRepo {
oidc.ClientIdentity{ repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ oidc.ClientIdentity{
ID: "XXX", Secret: "secrete", Credentials: oidc.ClientCredentials{
ID: "XXX", Secret: clientTestSecret,
},
}, },
}, })
}) if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
}
return repo
}()
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: &StaticSigner{sig: nil, err: errors.New("fail")}, signer: &StaticSigner{sig: nil, err: errors.New("fail")},
} }
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
srv := &Server{ srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km, KeyManager: km,
@ -268,7 +284,7 @@ func TestServerLoginDisabledUser(t *testing.T) {
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: "XXX",
Secret: "secrete", Secret: clientTestSecret,
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -280,13 +296,19 @@ func TestServerLoginDisabledUser(t *testing.T) {
}, },
}, },
} }
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) ciRepo := func() client.ClientIdentityRepo {
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
}
return repo
}()
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil}, signer: &StaticSigner{sig: []byte("beer"), err: nil},
} }
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sm.GenerateCode = staticGenerateCodeFunc("fakecode") sm.GenerateCode = staticGenerateCodeFunc("fakecode")
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"}) sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"})
if err != nil { if err != nil {
@ -336,24 +358,27 @@ func TestServerCodeToken(t *testing.T) {
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: "XXX",
Secret: "secrete", Secret: clientTestSecret,
}, },
} }
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) ciRepo := func() client.ClientIdentityRepo {
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
}
return repo
}()
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil}, signer: &StaticSigner{sig: []byte("beer"), err: nil},
} }
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
userRepo, err := makeNewUserRepo() userRepo, err := makeNewUserRepo()
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
srv := &Server{ srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
@ -375,8 +400,10 @@ func TestServerCodeToken(t *testing.T) {
}, },
// Have 'offline_access' in scope, should get non-empty refresh token. // Have 'offline_access' in scope, should get non-empty refresh token.
{ {
// NOTE(ericchiang): This test assumes that the database ID of the first
// refresh token will be "1".
scope: []string{"openid", "offline_access"}, scope: []string{"openid", "offline_access"},
refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
}, },
} }
@ -417,14 +444,20 @@ func TestServerTokenUnrecognizedKey(t *testing.T) {
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: "XXX",
Secret: "secrete", Secret: clientTestSecret,
}, },
} }
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) ciRepo := func() client.ClientIdentityRepo {
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
}
return repo
}()
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil}, signer: &StaticSigner{sig: []byte("beer"), err: nil},
} }
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
srv := &Server{ srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
@ -460,7 +493,7 @@ func TestServerTokenFail(t *testing.T) {
keyFixture := "goodkey" keyFixture := "goodkey"
ccFixture := oidc.ClientCredentials{ ccFixture := oidc.ClientCredentials{
ID: "XXX", ID: "XXX",
Secret: "secrete", Secret: clientTestSecret,
} }
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
@ -474,11 +507,13 @@ func TestServerTokenFail(t *testing.T) {
}{ }{
// control test case to make sure fixtures check out // control test case to make sure fixtures check out
{ {
// NOTE(ericchiang): This test assumes that the database ID of the first
// refresh token will be "1".
signer: signerFixture, signer: signerFixture,
argCC: ccFixture, argCC: ccFixture,
argKey: keyFixture, argKey: keyFixture,
scope: []string{"openid", "offline_access"}, scope: []string{"openid", "offline_access"},
refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
}, },
// no 'offline_access' in 'scope', should get empty refresh token // no 'offline_access' in 'scope', should get empty refresh token
@ -518,7 +553,7 @@ func TestServerTokenFail(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sm.GenerateCode = func() (string, error) { return keyFixture, nil } sm.GenerateCode = func() (string, error) { return keyFixture, nil }
sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, tt.scope) sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, tt.scope)
@ -534,9 +569,13 @@ func TestServerTokenFail(t *testing.T) {
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: tt.signer, signer: tt.signer,
} }
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ ciRepo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
oidc.ClientIdentity{Credentials: ccFixture}, oidc.ClientIdentity{Credentials: ccFixture},
}) })
if err != nil {
t.Errorf("case %d: failed to create client identity repo: %v", i, err)
continue
}
_, err = sm.AttachUser(sessionID, "testid-1") _, err = sm.AttachUser(sessionID, "testid-1")
if err != nil { if err != nil {
@ -548,10 +587,7 @@ func TestServerTokenFail(t *testing.T) {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
srv := &Server{ srv := &Server{
IssuerURL: issuerURL, IssuerURL: issuerURL,
@ -590,15 +626,17 @@ func TestServerRefreshToken(t *testing.T) {
credXXX := oidc.ClientCredentials{ credXXX := oidc.ClientCredentials{
ID: "XXX", ID: "XXX",
Secret: "secret", Secret: clientTestSecret,
} }
credYYY := oidc.ClientCredentials{ credYYY := oidc.ClientCredentials{
ID: "YYY", ID: "YYY",
Secret: "secret", Secret: clientTestSecret,
} }
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
// NOTE(ericchiang): These tests assume that the database ID of the first
// refresh token will be "1".
tests := []struct { tests := []struct {
token string token string
clientID string // The client that associates with the token. clientID string // The client that associates with the token.
@ -608,7 +646,7 @@ func TestServerRefreshToken(t *testing.T) {
}{ }{
// Everything is good. // Everything is good.
{ {
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
credXXX, credXXX,
signerFixture, signerFixture,
@ -624,7 +662,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid refresh token(invalid payload content). // Invalid refresh token(invalid payload content).
{ {
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
"XXX", "XXX",
credXXX, credXXX,
signerFixture, signerFixture,
@ -632,7 +670,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid refresh token(invalid ID content). // Invalid refresh token(invalid ID content).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
credXXX, credXXX,
signerFixture, signerFixture,
@ -640,7 +678,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(client is not associated with the token). // Invalid client(client is not associated with the token).
{ {
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
credYYY, credYYY,
signerFixture, signerFixture,
@ -648,7 +686,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(no client ID). // Invalid client(no client ID).
{ {
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
oidc.ClientCredentials{ID: "", Secret: "aaa"}, oidc.ClientCredentials{ID: "", Secret: "aaa"},
signerFixture, signerFixture,
@ -656,7 +694,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(no such client). // Invalid client(no such client).
{ {
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
oidc.ClientCredentials{ID: "AAA", Secret: "aaa"}, oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
signerFixture, signerFixture,
@ -664,7 +702,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(no secrets). // Invalid client(no secrets).
{ {
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
oidc.ClientCredentials{ID: "XXX"}, oidc.ClientCredentials{ID: "XXX"},
signerFixture, signerFixture,
@ -672,7 +710,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(invalid secret). // Invalid client(invalid secret).
{ {
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"}, oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"},
signerFixture, signerFixture,
@ -680,7 +718,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Signing operation fails. // Signing operation fails.
{ {
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
credXXX, credXXX,
&StaticSigner{sig: nil, err: errors.New("fail")}, &StaticSigner{sig: nil, err: errors.New("fail")},
@ -693,20 +731,21 @@ func TestServerRefreshToken(t *testing.T) {
signer: tt.signer, signer: tt.signer,
} }
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ ciRepo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
oidc.ClientIdentity{Credentials: credXXX}, oidc.ClientIdentity{Credentials: credXXX},
oidc.ClientIdentity{Credentials: credYYY}, oidc.ClientIdentity{Credentials: credYYY},
}) })
if err != nil {
t.Errorf("case %d: failed to create client identity repo: %v", i, err)
continue
}
userRepo, err := makeNewUserRepo() userRepo, err := makeNewUserRepo()
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
srv := &Server{ srv := &Server{
IssuerURL: issuerURL, IssuerURL: issuerURL,
@ -745,10 +784,13 @@ func TestServerRefreshToken(t *testing.T) {
signer: signerFixture, signer: signerFixture,
} }
ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ ciRepo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
oidc.ClientIdentity{Credentials: credXXX}, oidc.ClientIdentity{Credentials: credXXX},
oidc.ClientIdentity{Credentials: credYYY}, oidc.ClientIdentity{Credentials: credYYY},
}) })
if err != nil {
t.Fatalf("failed to create client identity repo: %v", err)
}
userRepo, err := makeNewUserRepo() userRepo, err := makeNewUserRepo()
if err != nil { if err != nil {
@ -763,10 +805,7 @@ func TestServerRefreshToken(t *testing.T) {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
srv := &Server{ srv := &Server{
IssuerURL: issuerURL, IssuerURL: issuerURL,
@ -787,7 +826,7 @@ func TestServerRefreshToken(t *testing.T) {
} }
srv.UserRepo = userRepo srv.UserRepo = userRepo
_, err = srv.RefreshToken(credXXX, fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1")))) _, err = srv.RefreshToken(credXXX, fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))))
if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) { if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) {
t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err) t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err)
} }

View file

@ -1,6 +1,7 @@
package server package server
import ( import (
"encoding/base64"
"fmt" "fmt"
"net/url" "net/url"
"time" "time"
@ -10,12 +11,12 @@ import (
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
"github.com/coreos/dex/email" "github.com/coreos/dex/email"
"github.com/coreos/dex/repo" sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/session"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email" useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager" usermanager "github.com/coreos/dex/user/manager"
) )
const ( const (
@ -24,9 +25,8 @@ const (
) )
var ( var (
testIssuerURL = url.URL{Scheme: "http", Host: "server.example.com"} testIssuerURL = url.URL{Scheme: "http", Host: "server.example.com"}
testClientID = "XXX" testClientID = "XXX"
testClientSecret = "secrete"
testRedirectURL = url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"} testRedirectURL = url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"}
@ -75,13 +75,13 @@ var (
type testFixtures struct { type testFixtures struct {
srv *Server srv *Server
userRepo user.UserRepo userRepo user.UserRepo
sessionManager *session.SessionManager sessionManager *sessionmanager.SessionManager
emailer *email.TemplatizedEmailer emailer *email.TemplatizedEmailer
redirectURL url.URL redirectURL url.URL
clientIdentityRepo client.ClientIdentityRepo clientIdentityRepo client.ClientIdentityRepo
} }
func sequentialGenerateCodeFunc() session.GenerateCodeFunc { func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc {
x := 0 x := 0
return func() (string, error) { return func() (string, error) {
x += 1 x += 1
@ -90,8 +90,15 @@ func sequentialGenerateCodeFunc() session.GenerateCodeFunc {
} }
func makeTestFixtures() (*testFixtures, error) { func makeTestFixtures() (*testFixtures, error) {
userRepo := user.NewUserRepoFromUsers(testUsers) dbMap := db.NewMemDB()
pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos) userRepo, err := db.NewUserRepoFromUsers(dbMap, testUsers)
if err != nil {
return nil, err
}
pwRepo, err := db.NewPasswordInfoRepoFromPasswordInfos(dbMap, testPasswordInfos)
if err != nil {
return nil, err
}
connConfigs := []connector.ConnectorConfig{ connConfigs := []connector.ConnectorConfig{
&connector.OIDCConnectorConfig{ &connector.OIDCConnectorConfig{
@ -111,11 +118,14 @@ func makeTestFixtures() (*testFixtures, error) {
ID: "local", ID: "local",
}, },
} }
connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs) connCfgRepo := db.NewConnectorConfigRepo(dbMap)
if err := connCfgRepo.Set(connConfigs); err != nil {
return nil, err
}
manager := manager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, manager.ManagerOptions{}) manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, db.TransactionFactory(dbMap), usermanager.ManagerOptions{})
sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sessionManager.GenerateCode = sequentialGenerateCodeFunc() sessionManager.GenerateCode = sequentialGenerateCodeFunc()
emailer, err := email.NewTemplatizedEmailerFromGlobs( emailer, err := email.NewTemplatizedEmailerFromGlobs(
@ -126,11 +136,11 @@ func makeTestFixtures() (*testFixtures, error) {
return nil, err return nil, err
} }
clientIdentityRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ clientIdentityRepo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: "XXX",
Secret: testClientSecret, Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -139,6 +149,9 @@ func makeTestFixtures() (*testFixtures, error) {
}, },
}, },
}) })
if err != nil {
return nil, err
}
km := key.NewPrivateKeyManager() km := key.NewPrivateKeyManager()
err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey}, time.Now().Add(time.Minute))) err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey}, time.Now().Add(time.Minute)))

View file

@ -1,4 +1,4 @@
package session package manager
import ( import (
"crypto/rand" "crypto/rand"
@ -10,6 +10,7 @@ import (
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/coreos/dex/session"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
) )
@ -27,11 +28,11 @@ func DefaultGenerateCode() (string, error) {
return base64.URLEncoding.EncodeToString(b), nil return base64.URLEncoding.EncodeToString(b), nil
} }
func NewSessionManager(sRepo SessionRepo, skRepo SessionKeyRepo) *SessionManager { func NewSessionManager(sRepo session.SessionRepo, skRepo session.SessionKeyRepo) *SessionManager {
return &SessionManager{ return &SessionManager{
GenerateCode: DefaultGenerateCode, GenerateCode: DefaultGenerateCode,
Clock: clockwork.NewRealClock(), Clock: clockwork.NewRealClock(),
ValidityWindow: DefaultSessionValidityWindow, ValidityWindow: session.DefaultSessionValidityWindow,
sessions: sRepo, sessions: sRepo,
keys: skRepo, keys: skRepo,
} }
@ -41,8 +42,8 @@ type SessionManager struct {
GenerateCode GenerateCodeFunc GenerateCode GenerateCodeFunc
Clock clockwork.Clock Clock clockwork.Clock
ValidityWindow time.Duration ValidityWindow time.Duration
sessions SessionRepo sessions session.SessionRepo
keys SessionKeyRepo keys session.SessionKeyRepo
} }
func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) { func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
@ -52,10 +53,10 @@ func (m *SessionManager) NewSession(connectorID, clientID, clientState string, r
} }
now := m.Clock.Now() now := m.Clock.Now()
s := Session{ s := session.Session{
ConnectorID: connectorID, ConnectorID: connectorID,
ID: sID, ID: sID,
State: SessionStateNew, State: session.SessionStateNew,
CreatedAt: now, CreatedAt: now,
ExpiresAt: now.Add(m.ValidityWindow), ExpiresAt: now.Add(m.ValidityWindow),
ClientID: clientID, ClientID: clientID,
@ -80,11 +81,12 @@ func (m *SessionManager) NewSessionKey(sessionID string) (string, error) {
return "", err return "", err
} }
k := SessionKey{ k := session.SessionKey{
Key: key, Key: key,
SessionID: sessionID, SessionID: sessionID,
} }
sessionKeyValidityWindow := 10 * time.Minute //RFC6749
err = m.keys.Push(k, sessionKeyValidityWindow) err = m.keys.Push(k, sessionKeyValidityWindow)
if err != nil { if err != nil {
return "", err return "", err
@ -97,7 +99,7 @@ func (m *SessionManager) ExchangeKey(key string) (string, error) {
return m.keys.Pop(key) return m.keys.Pop(key)
} }
func (m *SessionManager) getSessionInState(sessionID string, state SessionState) (*Session, error) { func (m *SessionManager) getSessionInState(sessionID string, state session.SessionState) (*session.Session, error) {
s, err := m.sessions.Get(sessionID) s, err := m.sessions.Get(sessionID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -110,14 +112,14 @@ func (m *SessionManager) getSessionInState(sessionID string, state SessionState)
return s, nil return s, nil
} }
func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*Session, error) { func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*session.Session, error) {
s, err := m.getSessionInState(sessionID, SessionStateNew) s, err := m.getSessionInState(sessionID, session.SessionStateNew)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.Identity = ident s.Identity = ident
s.State = SessionStateRemoteAttached s.State = session.SessionStateRemoteAttached
if err = m.sessions.Update(*s); err != nil { if err = m.sessions.Update(*s); err != nil {
return nil, err return nil, err
@ -126,14 +128,14 @@ func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Ident
return s, nil return s, nil
} }
func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session, error) { func (m *SessionManager) AttachUser(sessionID string, userID string) (*session.Session, error) {
s, err := m.getSessionInState(sessionID, SessionStateRemoteAttached) s, err := m.getSessionInState(sessionID, session.SessionStateRemoteAttached)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.UserID = userID s.UserID = userID
s.State = SessionStateIdentified s.State = session.SessionStateIdentified
if err = m.sessions.Update(*s); err != nil { if err = m.sessions.Update(*s); err != nil {
return nil, err return nil, err
@ -142,13 +144,13 @@ func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session,
return s, nil return s, nil
} }
func (m *SessionManager) Kill(sessionID string) (*Session, error) { func (m *SessionManager) Kill(sessionID string) (*session.Session, error) {
s, err := m.sessions.Get(sessionID) s, err := m.sessions.Get(sessionID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.State = SessionStateDead s.State = session.SessionStateDead
if err = m.sessions.Update(*s); err != nil { if err = m.sessions.Update(*s); err != nil {
return nil, err return nil, err
@ -157,6 +159,6 @@ func (m *SessionManager) Kill(sessionID string) (*Session, error) {
return s, nil return s, nil
} }
func (m *SessionManager) Get(sessionID string) (*Session, error) { func (m *SessionManager) Get(sessionID string) (*session.Session, error) {
return m.sessions.Get(sessionID) return m.sessions.Get(sessionID)
} }

View file

@ -1,9 +1,11 @@
package session package manager
import ( import (
"net/url" "net/url"
"testing" "testing"
"github.com/coreos/dex/db"
"github.com/coreos/dex/session"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
) )
@ -13,8 +15,13 @@ func staticGenerateCodeFunc(code string) GenerateCodeFunc {
} }
} }
func newManager() *SessionManager {
dbMap := db.NewMemDB()
return NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap))
}
func TestSessionManagerNewSession(t *testing.T) { func TestSessionManagerNewSession(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) sm := newManager()
sm.GenerateCode = staticGenerateCodeFunc("boo") sm.GenerateCode = staticGenerateCodeFunc("boo")
got, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) got, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil { if err != nil {
@ -26,7 +33,7 @@ func TestSessionManagerNewSession(t *testing.T) {
} }
func TestSessionAttachRemoteIdentityTwice(t *testing.T) { func TestSessionAttachRemoteIdentityTwice(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) sm := newManager()
sessionID, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) sessionID, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
@ -43,7 +50,7 @@ func TestSessionAttachRemoteIdentityTwice(t *testing.T) {
} }
func TestSessionManagerExchangeKey(t *testing.T) { func TestSessionManagerExchangeKey(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) sm := newManager()
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
@ -68,8 +75,8 @@ func TestSessionManagerExchangeKey(t *testing.T) {
} }
func TestSessionManagerGetSessionInStateNoExist(t *testing.T) { func TestSessionManagerGetSessionInStateNoExist(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) sm := newManager()
ses, err := sm.getSessionInState("123", SessionStateNew) ses, err := sm.getSessionInState("123", session.SessionStateNew)
if err == nil { if err == nil {
t.Errorf("Expected non-nil error") t.Errorf("Expected non-nil error")
} }
@ -79,12 +86,12 @@ func TestSessionManagerGetSessionInStateNoExist(t *testing.T) {
} }
func TestSessionManagerGetSessionInStateWrongState(t *testing.T) { func TestSessionManagerGetSessionInStateWrongState(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) sm := newManager()
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
ses, err := sm.getSessionInState(sessionID, SessionStateDead) ses, err := sm.getSessionInState(sessionID, session.SessionStateDead)
if err == nil { if err == nil {
t.Errorf("Expected non-nil error") t.Errorf("Expected non-nil error")
} }
@ -94,7 +101,7 @@ func TestSessionManagerGetSessionInStateWrongState(t *testing.T) {
} }
func TestSessionManagerKill(t *testing.T) { func TestSessionManagerKill(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) sm := newManager()
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)

View file

@ -1,11 +1,6 @@
package session package session
import ( import "time"
"errors"
"time"
"github.com/jonboulle/clockwork"
)
type SessionRepo interface { type SessionRepo interface {
Get(string) (*Session, error) Get(string) (*Session, error)
@ -17,87 +12,3 @@ type SessionKeyRepo interface {
Push(SessionKey, time.Duration) error Push(SessionKey, time.Duration) error
Pop(string) (string, error) Pop(string) (string, error)
} }
func NewSessionRepo() SessionRepo {
return NewSessionRepoWithClock(clockwork.NewRealClock())
}
func NewSessionRepoWithClock(clock clockwork.Clock) SessionRepo {
return &memSessionRepo{
store: make(map[string]Session),
clock: clock,
}
}
type memSessionRepo struct {
store map[string]Session
clock clockwork.Clock
}
func (m *memSessionRepo) Get(sessionID string) (*Session, error) {
s, ok := m.store[sessionID]
if !ok || s.ExpiresAt.Before(m.clock.Now()) {
return nil, errors.New("unrecognized ID")
}
return &s, nil
}
func (m *memSessionRepo) Create(s Session) error {
if _, ok := m.store[s.ID]; ok {
return errors.New("ID exists")
}
m.store[s.ID] = s
return nil
}
func (m *memSessionRepo) Update(s Session) error {
if _, ok := m.store[s.ID]; !ok {
return errors.New("unrecognized ID")
}
m.store[s.ID] = s
return nil
}
type expiringSessionKey struct {
SessionKey
expiresAt time.Time
}
func NewSessionKeyRepo() SessionKeyRepo {
return NewSessionKeyRepoWithClock(clockwork.NewRealClock())
}
func NewSessionKeyRepoWithClock(clock clockwork.Clock) SessionKeyRepo {
return &memSessionKeyRepo{
store: make(map[string]expiringSessionKey),
clock: clock,
}
}
type memSessionKeyRepo struct {
store map[string]expiringSessionKey
clock clockwork.Clock
}
func (m *memSessionKeyRepo) Pop(key string) (string, error) {
esk, ok := m.store[key]
if !ok {
return "", errors.New("unrecognized key")
}
defer delete(m.store, key)
if esk.expiresAt.Before(m.clock.Now()) {
return "", errors.New("expired key")
}
return esk.SessionKey.SessionID, nil
}
func (m *memSessionKeyRepo) Push(sk SessionKey, ttl time.Duration) error {
m.store[sk.Key] = expiringSessionKey{
SessionKey: sk,
expiresAt: m.clock.Now().Add(ttl),
}
return nil
}

View file

@ -1,22 +1,22 @@
[ [
{ {
"id": "XXX", "id": "XXX",
"secret": "secrete", "secret": "c2VjcmV0ZQ==",
"redirectURLs": ["http://127.0.0.1:5555/callback"] "redirectURLs": ["http://127.0.0.1:5555/callback"]
}, },
{ {
"id": "example-app", "id": "example-app",
"secret": "example-app-secret", "secret": "ZXhhbXBsZS1hcHAtc2VjcmV0",
"redirectURLs": ["http://127.0.0.1:5555/callback"] "redirectURLs": ["http://127.0.0.1:5555/callback"]
}, },
{ {
"id": "example-cli", "id": "example-cli",
"secret": "example-cli-secret", "secret": "ZXhhbXBsZS1jbGktc2VjcmV0",
"redirectURLs": ["http://127.0.0.1:8000/admin/v1/oauth/login"] "redirectURLs": ["http://127.0.0.1:8000/admin/v1/oauth/login"]
}, },
{ {
"id": "oauth2_proxy", "id": "oauth2_proxy",
"secret": "proxy", "secret": "cHJveHk=",
"redirectURLs": ["http://127.0.0.1:4180/oauth2/callback"] "redirectURLs": ["http://127.0.0.1:4180/oauth2/callback"]
} }
] ]

8
test
View file

@ -12,9 +12,13 @@
# Invoke ./cover for HTML output # Invoke ./cover for HTML output
COVER=${COVER:-"-cover"} COVER=${COVER:-"-cover"}
source ./build source ./env
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api user/manager email admin" if [ ! -d $GOPATH/pkg ]; then
echo "WARNING: No cached builds detected. Please run the ./build script to speed up future tests."
fi
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session session/manager user user/api user/manager user/email email admin"
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log" FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
# user has not provided PKG override # user has not provided PKG override

View file

@ -1,4 +1,6 @@
#!/bin/bash -e #!/bin/bash -e
source ./build
source ./env
go test $@ github.com/coreos/dex/functional go test $@ github.com/coreos/dex/functional
go test $@ github.com/coreos/dex/functional/repo go test $@ github.com/coreos/dex/functional/repo

View file

@ -1,6 +1,7 @@
package api package api
import ( import (
"encoding/base64"
"net/url" "net/url"
"testing" "testing"
"time" "time"
@ -11,7 +12,7 @@ import (
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/repo" "github.com/coreos/dex/db"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager" "github.com/coreos/dex/user/manager"
@ -86,54 +87,77 @@ var (
) )
func makeTestFixtures() (*UsersAPI, *testEmailer) { func makeTestFixtures() (*UsersAPI, *testEmailer) {
ur := user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{ dbMap := db.NewMemDB()
{ ur := func() user.UserRepo {
User: user.User{ repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{
ID: "ID-1", {
Email: "id1@example.com", User: user.User{
Admin: true, ID: "ID-1",
CreatedAt: clock.Now(), Email: "id1@example.com",
Admin: true,
CreatedAt: clock.Now(),
},
}, {
User: user.User{
ID: "ID-2",
Email: "id2@example.com",
CreatedAt: clock.Now(),
},
}, {
User: user.User{
ID: "ID-3",
Email: "id3@example.com",
CreatedAt: clock.Now(),
},
}, {
User: user.User{
ID: "ID-4",
Email: "id4@example.com",
CreatedAt: clock.Now(),
Disabled: true,
},
}, },
}, { })
User: user.User{ if err != nil {
ID: "ID-2", panic("Failed to create user repo: " + err.Error())
Email: "id2@example.com", }
CreatedAt: clock.Now(), return repo
}()
pwr := func() user.PasswordInfoRepo {
repo, err := db.NewPasswordInfoRepoFromPasswordInfos(dbMap, []user.PasswordInfo{
{
UserID: "ID-1",
Password: []byte("password-1"),
}, },
}, { {
User: user.User{ UserID: "ID-2",
ID: "ID-3", Password: []byte("password-2"),
Email: "id3@example.com",
CreatedAt: clock.Now(),
}, },
}, { })
User: user.User{ if err != nil {
ID: "ID-4", panic("Failed to create user repo: " + err.Error())
Email: "id4@example.com", }
CreatedAt: clock.Now(), return repo
Disabled: true, }()
},
}, ccr := func() connector.ConnectorConfigRepo {
}) repo := db.NewConnectorConfigRepo(dbMap)
pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{ c := []connector.ConnectorConfig{
{ &connector.LocalConnectorConfig{ID: "local"},
UserID: "ID-1", }
Password: []byte("password-1"), if err := repo.Set(c); err != nil {
}, panic(err)
{ }
UserID: "ID-2", return repo
Password: []byte("password-2"), }()
},
}) mgr := manager.NewUserManager(ur, pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local"},
})
mgr := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
mgr.Clock = clock mgr.Clock = clock
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: "XXX",
Secret: "secrete", Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -141,7 +165,13 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
}, },
}, },
} }
cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) cir := func() client.ClientIdentityRepo {
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
if err != nil {
panic("Failed to create client identity repo: " + err.Error())
}
return repo
}()
emailer := &testEmailer{} emailer := &testEmailer{}
api := NewUsersAPI(mgr, cir, emailer, "local") api := NewUsersAPI(mgr, cir, emailer, "local")

View file

@ -12,6 +12,7 @@ import (
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/db"
"github.com/coreos/dex/email" "github.com/coreos/dex/email"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
) )
@ -45,35 +46,49 @@ func (t *testEmailer) SendMail(from, subject, text, html string, to ...string) e
} }
func makeTestFixtures() (*UserEmailer, *testEmailer, *key.PublicKey) { func makeTestFixtures() (*UserEmailer, *testEmailer, *key.PublicKey) {
ur := user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{ dbMap := db.NewMemDB()
{ ur := func() user.UserRepo {
User: user.User{ repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{
ID: "ID-1", {
Email: "id1@example.com", User: user.User{
Admin: true, ID: "ID-1",
Email: "id1@example.com",
Admin: true,
},
}, {
User: user.User{
ID: "ID-2",
Email: "id2@example.com",
},
}, {
User: user.User{
ID: "ID-3",
Email: "id3@example.com",
},
}, },
}, { })
User: user.User{ if err != nil {
ID: "ID-2", panic("Failed to create user repo: " + err.Error())
Email: "id2@example.com", }
return repo
}()
pwr := func() user.PasswordInfoRepo {
repo, err := db.NewPasswordInfoRepoFromPasswordInfos(dbMap, []user.PasswordInfo{
{
UserID: "ID-1",
Password: []byte("password-1"),
}, },
}, { {
User: user.User{ UserID: "ID-2",
ID: "ID-3", Password: []byte("password-2"),
Email: "id3@example.com",
}, },
}, })
}) if err != nil {
pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{ panic("Failed to create user repo: " + err.Error())
{ }
UserID: "ID-1", return repo
Password: []byte("password-1"), }()
},
{
UserID: "ID-2",
Password: []byte("password-2"),
},
})
privKey, err := key.GeneratePrivateKey() privKey, err := key.GeneratePrivateKey()
if err != nil { if err != nil {

View file

@ -10,7 +10,7 @@ import (
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/repo" "github.com/coreos/dex/db"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
) )
@ -26,46 +26,69 @@ func makeTestFixtures() *testFixtures {
f := &testFixtures{} f := &testFixtures{}
f.clock = clockwork.NewFakeClock() f.clock = clockwork.NewFakeClock()
f.ur = user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{ dbMap := db.NewMemDB()
{ f.ur = func() user.UserRepo {
User: user.User{ repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{
ID: "ID-1", {
Email: "Email-1@example.com", User: user.User{
}, ID: "ID-1",
RemoteIdentities: []user.RemoteIdentity{ Email: "Email-1@example.com",
{ },
ConnectorID: "local", RemoteIdentities: []user.RemoteIdentity{
ID: "1", {
ConnectorID: "local",
ID: "1",
},
},
}, {
User: user.User{
ID: "ID-2",
Email: "Email-2@example.com",
EmailVerified: true,
},
RemoteIdentities: []user.RemoteIdentity{
{
ConnectorID: "local",
ID: "2",
},
}, },
}, },
}, { })
User: user.User{ if err != nil {
ID: "ID-2", panic("Failed to create user repo: " + err.Error())
Email: "Email-2@example.com", }
EmailVerified: true, return repo
}()
f.pwr = func() user.PasswordInfoRepo {
repo, err := db.NewPasswordInfoRepoFromPasswordInfos(dbMap, []user.PasswordInfo{
{
UserID: "ID-1",
Password: []byte("password-1"),
}, },
RemoteIdentities: []user.RemoteIdentity{ {
{ UserID: "ID-2",
ConnectorID: "local", Password: []byte("password-2"),
ID: "2",
},
}, },
}, })
}) if err != nil {
f.pwr = user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{ panic("Failed to create user repo: " + err.Error())
{ }
UserID: "ID-1", return repo
Password: []byte("password-1"), }()
},
{ f.ccr = func() connector.ConnectorConfigRepo {
UserID: "ID-2", repo := db.NewConnectorConfigRepo(dbMap)
Password: []byte("password-2"), c := []connector.ConnectorConfig{
}, &connector.LocalConnectorConfig{ID: "local"},
}) }
f.ccr = connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{ if err := repo.Set(c); err != nil {
&connector.LocalConnectorConfig{ID: "local"}, panic(err)
}) }
f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, repo.InMemTransactionFactory, ManagerOptions{}) return repo
}()
f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, db.TransactionFactory(dbMap), ManagerOptions{})
f.mgr.Clock = f.clock f.mgr.Clock = f.clock
return f return f
} }
@ -207,18 +230,22 @@ func TestRegisterWithPassword(t *testing.T) {
} }
if diff := pretty.Compare(usr, ridUSR); diff != "" { if diff := pretty.Compare(usr, ridUSR); diff != "" {
t.Errorf("case %d: Compare(want, got) = %v", i, diff) t.Errorf("case %d: Compare(want, got) = %v", i, diff)
continue
} }
pwi, err := f.pwr.Get(nil, userID) pwi, err := f.pwr.Get(nil, userID)
if err != nil { if err != nil {
t.Errorf("case %d: err != nil: %q", i, err) t.Errorf("case %d: err != nil: %q", i, err)
continue
} }
ident, err := pwi.Authenticate(tt.plaintext) ident, err := pwi.Authenticate(tt.plaintext)
if err != nil { if err != nil {
t.Errorf("case %d: err != nil: %q", i, err) t.Errorf("case %d: err != nil: %q", i, err)
continue
} }
if ident.ID != userID { if ident.ID != userID {
t.Errorf("case %d: ident.ID: want=%q, got=%q", i, userID, ident.ID) t.Errorf("case %d: ident.ID: want=%q, got=%q", i, userID, ident.ID)
continue
} }
_, err = pwi.Authenticate(tt.plaintext + "WRONG") _, err = pwi.Authenticate(tt.plaintext + "WRONG")
@ -274,7 +301,7 @@ func TestVerifyEmail(t *testing.T) {
for i, tt := range tests { for i, tt := range tests {
f := makeTestFixtures() f := makeTestFixtures()
cb, err := f.mgr.VerifyEmail(user.EmailVerification{tt.evClaims}) cb, err := f.mgr.VerifyEmail(user.EmailVerification{Claims: tt.evClaims})
if tt.wantErr { if tt.wantErr {
if err == nil { if err == nil {
t.Errorf("case %d: want non-nil err", i) t.Errorf("case %d: want non-nil err", i)
@ -344,7 +371,7 @@ func TestChangePassword(t *testing.T) {
for i, tt := range tests { for i, tt := range tests {
f := makeTestFixtures() f := makeTestFixtures()
cb, err := f.mgr.ChangePassword(user.PasswordReset{tt.pwrClaims}, tt.newPassword) cb, err := f.mgr.ChangePassword(user.PasswordReset{Claims: tt.pwrClaims}, tt.newPassword)
if tt.wantErr { if tt.wantErr {
if err == nil { if err == nil {
t.Errorf("case %d: want non-nil err", i) t.Errorf("case %d: want non-nil err", i)

View file

@ -4,9 +4,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net/url" "net/url"
"os"
"time" "time"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@ -85,60 +83,6 @@ type PasswordInfoRepo interface {
Create(repo.Transaction, PasswordInfo) error Create(repo.Transaction, PasswordInfo) error
} }
func NewPasswordInfoRepo() PasswordInfoRepo {
return &memPasswordInfoRepo{
pws: make(map[string]PasswordInfo),
}
}
type memPasswordInfoRepo struct {
pws map[string]PasswordInfo
}
func (m *memPasswordInfoRepo) Get(_ repo.Transaction, id string) (PasswordInfo, error) {
pw, ok := m.pws[id]
if !ok {
return PasswordInfo{}, ErrorNotFound
}
return pw, nil
}
func (m *memPasswordInfoRepo) Create(_ repo.Transaction, pw PasswordInfo) error {
_, ok := m.pws[pw.UserID]
if ok {
return ErrorDuplicateID
}
if pw.UserID == "" {
return ErrorInvalidID
}
if len(pw.Password) == 0 {
return ErrorInvalidPassword
}
m.pws[pw.UserID] = pw
return nil
}
func (m *memPasswordInfoRepo) Update(_ repo.Transaction, pw PasswordInfo) error {
if pw.UserID == "" {
return ErrorInvalidID
}
_, ok := m.pws[pw.UserID]
if !ok {
return ErrorNotFound
}
if len(pw.Password) == 0 {
return ErrorInvalidPassword
}
m.pws[pw.UserID] = pw
return nil
}
func (u *PasswordInfo) UnmarshalJSON(data []byte) error { func (u *PasswordInfo) UnmarshalJSON(data []byte) error {
var dec struct { var dec struct {
UserID string `json:"userId"` UserID string `json:"userId"`
@ -172,21 +116,6 @@ func (u *PasswordInfo) UnmarshalJSON(data []byte) error {
return nil return nil
} }
func newPasswordInfosFromReader(r io.Reader) ([]PasswordInfo, error) {
var pws []PasswordInfo
err := json.NewDecoder(r).Decode(&pws)
return pws, err
}
func readPasswordInfosFromFile(loc string) ([]PasswordInfo, error) {
pwf, err := os.Open(loc)
if err != nil {
return nil, fmt.Errorf("unable to read password info from file %q: %v", loc, err)
}
return newPasswordInfosFromReader(pwf)
}
func LoadPasswordInfos(repo PasswordInfoRepo, pws []PasswordInfo) error { func LoadPasswordInfos(repo PasswordInfoRepo, pws []PasswordInfo) error {
for i, pw := range pws { for i, pw := range pws {
err := repo.Create(nil, pw) err := repo.Create(nil, pw)
@ -197,23 +126,6 @@ func LoadPasswordInfos(repo PasswordInfoRepo, pws []PasswordInfo) error {
return nil return nil
} }
func NewPasswordInfoRepoFromPasswordInfos(pws []PasswordInfo) PasswordInfoRepo {
memRepo := NewPasswordInfoRepo().(*memPasswordInfoRepo)
for _, pw := range pws {
memRepo.pws[pw.UserID] = pw
}
return memRepo
}
func NewPasswordInfoRepoFromFile(loc string) (PasswordInfoRepo, error) {
pws, err := readPasswordInfosFromFile(loc)
if err != nil {
return nil, err
}
return NewPasswordInfoRepoFromPasswordInfos(pws), nil
}
func NewPasswordReset(userID string, password Password, issuer url.URL, clientID string, callback url.URL, expires time.Duration) PasswordReset { func NewPasswordReset(userID string, password Password, issuer url.URL, clientID string, callback url.URL, expires time.Duration) PasswordReset {
claims := oidc.NewClaims(issuer.String(), userID, clientID, clock.Now(), clock.Now().Add(expires)) claims := oidc.NewClaims(issuer.String(), userID, clientID, clock.Now(), clock.Now().Add(expires))
claims.Add(ClaimPasswordResetPassword, string(password)) claims.Add(ClaimPasswordResetPassword, string(password))

View file

@ -2,7 +2,6 @@ package user
import ( import (
"net/url" "net/url"
"strings"
"testing" "testing"
"time" "time"
@ -14,48 +13,6 @@ import (
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
) )
func TestNewPasswordInfosFromReader(t *testing.T) {
PasswordHasher = func(plaintext string) ([]byte, error) {
return []byte(strings.ToUpper(plaintext)), nil
}
defer func() {
PasswordHasher = DefaultPasswordHasher
}()
tests := []struct {
json string
want []PasswordInfo
}{
{
json: `[{"userId":"12345","passwordPlaintext":"password"},{"userId":"78901","passwordHash":"WORDPASS", "passwordExpires":"2006-01-01T15:04:05Z"}]`,
want: []PasswordInfo{
{
UserID: "12345",
Password: []byte("PASSWORD"),
},
{
UserID: "78901",
Password: []byte("WORDPASS"),
PasswordExpires: time.Date(2006,
1, 1, 15, 4, 5, 0, time.UTC),
},
},
},
}
for i, tt := range tests {
r := strings.NewReader(tt.json)
us, err := newPasswordInfosFromReader(r)
if err != nil {
t.Errorf("case %d: want nil err: %v", i, err)
continue
}
if diff := pretty.Compare(tt.want, us); diff != "" {
t.Errorf("case %d: Compare(want, got): %v", i, diff)
}
}
}
func TestNewPasswordFromHash(t *testing.T) { func TestNewPasswordFromHash(t *testing.T) {
tests := []string{ tests := []string{
"test", "test",

View file

@ -4,13 +4,10 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"time" "time"
"net/mail" "net/mail"
"net/url" "net/url"
"os"
"sort"
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/pborman/uuid" "github.com/pborman/uuid"
@ -172,262 +169,11 @@ func ValidPassword(plaintext string) bool {
return len(plaintext) > 5 return len(plaintext) > 5
} }
// NewUserRepo returns an in-memory UserRepo useful for development.
func NewUserRepo() UserRepo {
return &memUserRepo{
usersByID: make(map[string]User),
userIDsByEmail: make(map[string]string),
userIDsByRemoteID: make(map[RemoteIdentity]string),
remoteIDsByUserID: make(map[string]map[RemoteIdentity]struct{}),
}
}
type memUserRepo struct {
usersByID map[string]User
userIDsByEmail map[string]string
userIDsByRemoteID map[RemoteIdentity]string
remoteIDsByUserID map[string]map[RemoteIdentity]struct{}
}
func (r *memUserRepo) Get(_ repo.Transaction, id string) (User, error) {
user, ok := r.usersByID[id]
if !ok {
return User{}, ErrorNotFound
}
return user, nil
}
type usersByEmail []User
func (s usersByEmail) Len() int { return len(s) }
func (s usersByEmail) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s usersByEmail) Less(i, j int) bool { return s[i].Email < s[j].Email }
func (r *memUserRepo) List(tx repo.Transaction, filter UserFilter, maxResults int, nextPageToken string) ([]User, string, error) {
var offset int
var err error
if nextPageToken != "" {
filter, maxResults, offset, err = DecodeNextPageToken(nextPageToken)
}
if err != nil {
return nil, "", err
}
users := []User{}
for _, usr := range r.usersByID {
users = append(users, usr)
}
sort.Sort(usersByEmail(users))
high := offset + maxResults
var tok string
if high >= len(users) {
high = len(users)
} else {
tok, err = EncodeNextPageToken(filter, maxResults, high)
}
if err != nil {
return nil, "", err
}
if len(users[offset:high]) == 0 {
return nil, "", ErrorNotFound
}
return users[offset:high], tok, nil
}
func (r *memUserRepo) GetByEmail(tx repo.Transaction, email string) (User, error) {
userID, ok := r.userIDsByEmail[email]
if !ok {
return User{}, ErrorNotFound
}
return r.Get(tx, userID)
}
func (r *memUserRepo) Create(_ repo.Transaction, user User) error {
if user.ID == "" {
return ErrorInvalidID
}
if !ValidEmail(user.Email) {
return ErrorInvalidEmail
}
// make sure no one has the same ID; if using UUID the chances of this
// happening are astronomically small.
_, ok := r.usersByID[user.ID]
if ok {
return ErrorDuplicateID
}
// make sure there's no other user with the same Email
_, ok = r.userIDsByEmail[user.Email]
if ok {
return ErrorDuplicateEmail
}
r.set(user)
return nil
}
func (r *memUserRepo) Update(_ repo.Transaction, user User) error {
if user.ID == "" {
return ErrorInvalidID
}
if !ValidEmail(user.Email) {
return ErrorInvalidEmail
}
// make sure this user exists already
_, ok := r.usersByID[user.ID]
if !ok {
return ErrorNotFound
}
// make sure there's no other user with the same Email
otherID, ok := r.userIDsByEmail[user.Email]
if ok && otherID != user.ID {
return ErrorDuplicateEmail
}
r.set(user)
return nil
}
func (r *memUserRepo) Disable(_ repo.Transaction, id string, disable bool) error {
if id == "" {
return ErrorInvalidID
}
user, ok := r.usersByID[id]
if !ok {
return ErrorNotFound
}
user.Disabled = disable
r.set(user)
return nil
}
func (r *memUserRepo) AddRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error {
_, ok := r.usersByID[userID]
if !ok {
return ErrorNotFound
}
_, ok = r.userIDsByRemoteID[ri]
if ok {
return ErrorDuplicateRemoteIdentity
}
r.userIDsByRemoteID[ri] = userID
rIDs, ok := r.remoteIDsByUserID[userID]
if !ok {
rIDs = make(map[RemoteIdentity]struct{})
r.remoteIDsByUserID[userID] = rIDs
}
rIDs[ri] = struct{}{}
return nil
}
func (r *memUserRepo) RemoveRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error {
otherID, ok := r.userIDsByRemoteID[ri]
if !ok {
return ErrorNotFound
}
if otherID != userID {
return ErrorNotFound
}
delete(r.userIDsByRemoteID, ri)
delete(r.remoteIDsByUserID[userID], ri)
return nil
}
func (r *memUserRepo) GetByRemoteIdentity(_ repo.Transaction, ri RemoteIdentity) (User, error) {
userID, ok := r.userIDsByRemoteID[ri]
if !ok {
return User{}, ErrorNotFound
}
user, ok := r.usersByID[userID]
if !ok {
return User{}, ErrorNotFound
}
return user, nil
}
func (r *memUserRepo) GetRemoteIdentities(_ repo.Transaction, userID string) ([]RemoteIdentity, error) {
ids := []RemoteIdentity{}
for id := range r.remoteIDsByUserID[userID] {
ids = append(ids, id)
}
return ids, nil
}
func (r *memUserRepo) GetAdminCount(_ repo.Transaction) (int, error) {
var i int
for _, usr := range r.usersByID {
if usr.Admin {
i++
}
}
return i, nil
}
func (r *memUserRepo) set(user User) error {
r.usersByID[user.ID] = user
r.userIDsByEmail[user.Email] = user.ID
return nil
}
type UserWithRemoteIdentities struct { type UserWithRemoteIdentities struct {
User User `json:"user"` User User `json:"user"`
RemoteIdentities []RemoteIdentity `json:"remoteIdentities"` RemoteIdentities []RemoteIdentity `json:"remoteIdentities"`
} }
// NewUserRepoFromFile returns an in-memory UserRepo useful for development given a JSON serialized file of Users.
func NewUserRepoFromFile(loc string) (UserRepo, error) {
us, err := readUsersFromFile(loc)
if err != nil {
return nil, err
}
return NewUserRepoFromUsers(us), nil
}
func NewUserRepoFromUsers(us []UserWithRemoteIdentities) UserRepo {
memUserRepo := NewUserRepo().(*memUserRepo)
for _, u := range us {
memUserRepo.set(u.User)
for _, ri := range u.RemoteIdentities {
memUserRepo.AddRemoteIdentity(nil, u.User.ID, ri)
}
}
return memUserRepo
}
func newUsersFromReader(r io.Reader) ([]UserWithRemoteIdentities, error) {
var us []UserWithRemoteIdentities
err := json.NewDecoder(r).Decode(&us)
return us, err
}
func readUsersFromFile(loc string) ([]UserWithRemoteIdentities, error) {
uf, err := os.Open(loc)
if err != nil {
return nil, fmt.Errorf("unable to read users from file %q: %v", loc, err)
}
defer uf.Close()
us, err := newUsersFromReader(uf)
if err != nil {
return nil, err
}
return us, err
}
func (u *User) UnmarshalJSON(data []byte) error { func (u *User) UnmarshalJSON(data []byte) error {
var dec struct { var dec struct {
ID string `json:"id"` ID string `json:"id"`

View file

@ -2,7 +2,6 @@ package user
import ( import (
"reflect" "reflect"
"strings"
"testing" "testing"
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
@ -10,44 +9,6 @@ import (
"github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/jose"
) )
func TestNewUsersFromReader(t *testing.T) {
tests := []struct {
json string
want []UserWithRemoteIdentities
}{
{
json: `[{"user":{"id":"12345", "displayName": "Elroy Canis", "email":"elroy23@example.com"}, "remoteIdentities":[{"connectorID":"google", "id":"elroy@example.com"}] }]`,
want: []UserWithRemoteIdentities{
{
User: User{
ID: "12345",
DisplayName: "Elroy Canis",
Email: "elroy23@example.com",
},
RemoteIdentities: []RemoteIdentity{
{
ConnectorID: "google",
ID: "elroy@example.com",
},
},
},
},
},
}
for i, tt := range tests {
r := strings.NewReader(tt.json)
us, err := newUsersFromReader(r)
if err != nil {
t.Errorf("case %d: want nil err: %v", i, err)
continue
}
if diff := pretty.Compare(tt.want, us); diff != "" {
t.Errorf("case %d: Compare(want, got): %v", i, diff)
}
}
}
func TestAddToClaims(t *testing.T) { func TestAddToClaims(t *testing.T) {
tests := []struct { tests := []struct {
user User user User