storage/sql: add a SQL storage implementation

This change adds support for SQLite3, and Postgres.
This commit is contained in:
Eric Chiang 2016-09-14 18:11:57 -07:00 committed by Eric Chiang
parent 82a55cf785
commit 87a7d093b2
10 changed files with 1116 additions and 0 deletions

113
storage/sql/config.go Normal file
View File

@ -0,0 +1,113 @@
package sql
import (
"database/sql"
"fmt"
"net/url"
"strconv"
"github.com/coreos/dex/storage"
)
// SQLite3 options for creating an SQL db.
type SQLite3 struct {
// File to
File string `yaml:"file"`
}
// Open creates a new storage implementation backed by SQLite3
func (s *SQLite3) Open() (storage.Storage, error) {
return s.open()
}
func (s *SQLite3) open() (*conn, error) {
db, err := sql.Open("sqlite3", s.File)
if err != nil {
return nil, err
}
if s.File == ":memory:" {
// sqlite3 uses file locks to coordinate concurrent access. In memory
// doesn't support this, so limit the number of connections to 1.
db.SetMaxOpenConns(1)
}
c := &conn{db, flavorSQLite3}
if _, err := c.migrate(); err != nil {
return nil, fmt.Errorf("failed to perform migrations: %v", err)
}
return c, nil
}
const (
sslDisable = "disable"
sslRequire = "require"
sslVerifyCA = "verify-ca"
sslVerifyFull = "verify-full"
)
// PostgresSSL represents SSL options for Postgres databases.
type PostgresSSL struct {
Mode string
CAFile string
// Files for client auth.
KeyFile string
CertFile string
}
// Postgres options for creating an SQL db.
type Postgres struct {
Database string
User string
Password string
Host string
SSL PostgresSSL `json:"ssl" yaml:"ssl"`
ConnectionTimeout int // Seconds
}
// Open creates a new storage implementation backed by Postgres.
func (p *Postgres) Open() (storage.Storage, error) {
return p.open()
}
func (p *Postgres) open() (*conn, error) {
v := url.Values{}
set := func(key, val string) {
if val != "" {
v.Set(key, val)
}
}
set("connect_timeout", strconv.Itoa(p.ConnectionTimeout))
set("sslkey", p.SSL.KeyFile)
set("sslcert", p.SSL.CertFile)
set("sslrootcert", p.SSL.CAFile)
if p.SSL.Mode == "" {
// Assume the strictest mode if unspecified.
p.SSL.Mode = sslVerifyFull
}
set("sslmode", p.SSL.Mode)
u := url.URL{
Scheme: "postgres",
Host: p.Host,
Path: "/" + p.Database,
RawQuery: v.Encode(),
}
if p.User != "" {
if p.Password != "" {
u.User = url.UserPassword(p.User, p.Password)
} else {
u.User = url.User(p.User)
}
}
db, err := sql.Open("postgres", u.String())
if err != nil {
return nil, err
}
c := &conn{db, flavorPostgres}
if _, err := c.migrate(); err != nil {
return nil, fmt.Errorf("failed to perform migrations: %v", err)
}
return c, nil
}

View File

@ -0,0 +1 @@
package sql

487
storage/sql/crud.go Normal file
View File

@ -0,0 +1,487 @@
package sql
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"github.com/coreos/dex/storage"
)
// TODO(ericchiang): The update, insert, and select methods queries are all
// very repetivite. Consider creating them programatically.
// keysRowID is the ID of the only row we expect to populate the "keys" table.
const keysRowID = "keys"
// encoder wraps the underlying value in a JSON marshaler which is automatically
// called by the database/sql package.
//
// s := []string{"planes", "bears"}
// err := db.Exec(`insert into t1 (id, things) values (1, $1)`, encoder(s))
// if err != nil {
// // handle error
// }
//
// var r []byte
// err = db.QueryRow(`select things from t1 where id = 1;`).Scan(&r)
// if err != nil {
// // handle error
// }
// fmt.Printf("%s\n", r) // ["planes","bears"]
//
func encoder(i interface{}) driver.Valuer {
return jsonEncoder{i}
}
// decoder wraps the underlying value in a JSON unmarshaler which can then be passed
// to a database Scan() method.
func decoder(i interface{}) sql.Scanner {
return jsonDecoder{i}
}
type jsonEncoder struct {
i interface{}
}
func (j jsonEncoder) Value() (driver.Value, error) {
b, err := json.Marshal(j.i)
if err != nil {
return nil, fmt.Errorf("marshal: %v", err)
}
return b, nil
}
type jsonDecoder struct {
i interface{}
}
func (j jsonDecoder) Scan(dest interface{}) error {
if dest == nil {
return errors.New("nil value")
}
b, ok := dest.([]byte)
if !ok {
return fmt.Errorf("expected []byte got %T", dest)
}
if err := json.Unmarshal(b, &j.i); err != nil {
return fmt.Errorf("unmarshal: %v", err)
}
return nil
}
// Abstract conn vs trans.
type querier interface {
QueryRow(query string, args ...interface{}) *sql.Row
}
// Abstract row vs rows.
type scanner interface {
Scan(dest ...interface{}) error
}
func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
_, err := c.Exec(`
insert into auth_request (
id, client_id, response_types, scopes, redirect_uri, nonce, state,
force_approval_prompt, logged_in,
claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups,
connector_id, connector_data,
expiry
)
values (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17
);
`,
a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State,
a.ForceApprovalPrompt, a.LoggedIn,
a.Claims.UserID, a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified,
encoder(a.Claims.Groups),
a.ConnectorID, a.ConnectorData,
a.Expiry,
)
if err != nil {
return fmt.Errorf("insert auth request: %v", err)
}
return nil
}
func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
return c.ExecTx(func(tx *trans) error {
r, err := getAuthRequest(tx, id)
if err != nil {
return err
}
a, err := updater(r)
if err != nil {
return err
}
_, err = tx.Exec(`
update auth_request
set
client_id = $1, response_types = $2, scopes = $3, redirect_uri = $4,
nonce = $5, state = $6, force_approval_prompt = $7, logged_in = $8,
claims_user_id = $9, claims_username = $10, claims_email = $11,
claims_email_verified = $12,
claims_groups = $13,
connector_id = $14, connector_data = $15,
expiry = $16
where id = $17;
`,
a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State,
a.ForceApprovalPrompt, a.LoggedIn,
a.Claims.UserID, a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified,
encoder(a.Claims.Groups),
a.ConnectorID, a.ConnectorData,
a.Expiry, a.ID,
)
if err != nil {
return fmt.Errorf("update auth request: %v", err)
}
return nil
})
}
func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) {
return getAuthRequest(c, id)
}
func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
err = q.QueryRow(`
select
id, client_id, response_types, scopes, redirect_uri, nonce, state,
force_approval_prompt, logged_in,
claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups,
connector_id, connector_data, expiry
from auth_request where id = $1;
`, id).Scan(
&a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes), &a.RedirectURI, &a.Nonce, &a.State,
&a.ForceApprovalPrompt, &a.LoggedIn,
&a.Claims.UserID, &a.Claims.Username, &a.Claims.Email, &a.Claims.EmailVerified,
decoder(&a.Claims.Groups),
&a.ConnectorID, &a.ConnectorData, &a.Expiry,
)
if err != nil {
if err == sql.ErrNoRows {
return a, storage.ErrNotFound
}
return a, fmt.Errorf("select auth request: %v", err)
}
return a, nil
}
func (c *conn) CreateAuthCode(a storage.AuthCode) error {
_, err := c.Exec(`
insert into auth_code (
id, client_id, scopes, nonce, redirect_uri,
claims_user_id, claims_username,
claims_email, claims_email_verified, claims_groups,
connector_id, connector_data,
expiry
)
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13);
`,
a.ID, a.ClientID, encoder(a.Scopes), a.Nonce, a.RedirectURI, a.Claims.UserID,
a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups),
a.ConnectorID, a.ConnectorData, a.Expiry,
)
return err
}
func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
err = c.QueryRow(`
select
id, client_id, scopes, nonce, redirect_uri,
claims_user_id, claims_username,
claims_email, claims_email_verified, claims_groups,
connector_id, connector_data,
expiry
from auth_code where id = $1;
`, id).Scan(
&a.ID, &a.ClientID, decoder(&a.Scopes), &a.Nonce, &a.RedirectURI, &a.Claims.UserID,
&a.Claims.Username, &a.Claims.Email, &a.Claims.EmailVerified, decoder(&a.Claims.Groups),
&a.ConnectorID, &a.ConnectorData, &a.Expiry,
)
if err != nil {
if err == sql.ErrNoRows {
return a, storage.ErrNotFound
}
return a, fmt.Errorf("select auth code: %v", err)
}
return a, nil
}
func (c *conn) CreateRefresh(r storage.RefreshToken) error {
_, err := c.Exec(`
insert into refresh_token (
id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups,
connector_id, connector_data
)
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11);
`,
r.RefreshToken, r.ClientID, encoder(r.Scopes), r.Nonce,
r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified,
encoder(r.Claims.Groups),
r.ConnectorID, r.ConnectorData,
)
if err != nil {
return fmt.Errorf("insert refresh_token: %v", err)
}
return nil
}
func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) {
return scanRefresh(c.QueryRow(`
select
id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups,
connector_id, connector_data
from refresh_token where id = $1;
`, id))
}
func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
rows, err := c.Query(`
select
id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups,
connector_id, connector_data
from refresh_token;
`)
if err != nil {
return nil, fmt.Errorf("query: %v", err)
}
var tokens []storage.RefreshToken
for rows.Next() {
r, err := scanRefresh(rows)
if err != nil {
return nil, err
}
tokens = append(tokens, r)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("scan: %v", err)
}
return tokens, nil
}
func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
err = s.Scan(
&r.RefreshToken, &r.ClientID, decoder(&r.Scopes), &r.Nonce,
&r.Claims.UserID, &r.Claims.Username, &r.Claims.Email, &r.Claims.EmailVerified,
decoder(&r.Claims.Groups),
&r.ConnectorID, &r.ConnectorData,
)
if err != nil {
if err == sql.ErrNoRows {
return r, storage.ErrNotFound
}
return r, fmt.Errorf("scan refresh_token: %v", err)
}
return r, nil
}
func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
return c.ExecTx(func(tx *trans) error {
firstUpdate := false
// TODO(ericchiang): errors may cause a transaction be rolled back by the SQL
// server. Test this, and consider adding a COUNT() command beforehand.
old, err := getKeys(tx)
if err != nil {
if err != storage.ErrNotFound {
return fmt.Errorf("get keys: %v", err)
}
firstUpdate = true
old = storage.Keys{}
}
nk, err := updater(old)
if err != nil {
return err
}
if firstUpdate {
_, err = tx.Exec(`
insert into keys (
id, verification_keys, signing_key, signing_key_pub, next_rotation
)
values ($1, $2, $3, $4, $5);
`,
keysRowID, encoder(nk.VerificationKeys), encoder(nk.SigningKey),
encoder(nk.SigningKeyPub), nk.NextRotation,
)
if err != nil {
return fmt.Errorf("insert: %v", err)
}
} else {
_, err = tx.Exec(`
update keys
set
verification_keys = $1,
signing_key = $2,
singing_key_pub = $3,
next_rotation = $4
where id = $5;
`,
encoder(nk.VerificationKeys), encoder(nk.SigningKey),
encoder(nk.SigningKeyPub), nk.NextRotation, keysRowID,
)
if err != nil {
return fmt.Errorf("update: %v", err)
}
}
return nil
})
}
func (c *conn) GetKeys() (keys storage.Keys, err error) {
return getKeys(c)
}
func getKeys(q querier) (keys storage.Keys, err error) {
err = q.QueryRow(`
select
verification_keys, signing_key, signing_key_pub, next_rotation
from keys
where id=$q
`, keysRowID).Scan(
decoder(&keys.VerificationKeys), decoder(&keys.SigningKey),
decoder(&keys.SigningKeyPub), &keys.NextRotation,
)
if err != nil {
if err == sql.ErrNoRows {
return keys, storage.ErrNotFound
}
return keys, fmt.Errorf("query keys: %v", err)
}
return keys, nil
}
func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
return c.ExecTx(func(tx *trans) error {
cli, err := getClient(tx, id)
if err != nil {
return err
}
nc, err := updater(cli)
if err != nil {
return err
}
_, err = tx.Exec(`
update client
set
secret = $1,
redirect_uris = $2,
trusted_peers = $3,
public = $4,
name = $5,
logo_url = $6
where id = $7;
`, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, id,
)
if err != nil {
return fmt.Errorf("update client: %v", err)
}
return nil
})
}
func (c *conn) CreateClient(cli storage.Client) error {
_, err := c.Exec(`
insert into client (
id, secret, redirect_uris, trusted_peers, public, name, logo_url
)
values ($1, $2, $3, $4, $5, $6, $7);
`,
cli.ID, cli.Secret, encoder(cli.RedirectURIs), encoder(cli.TrustedPeers),
cli.Public, cli.Name, cli.LogoURL,
)
if err != nil {
return fmt.Errorf("insert client: %v", err)
}
return nil
}
func getClient(q querier, id string) (storage.Client, error) {
return scanClient(q.QueryRow(`
select
id, secret, redirect_uris, trusted_peers, public, name, logo_url
from client where id = $1;
`, id))
}
func (c *conn) GetClient(id string) (storage.Client, error) {
return getClient(c, id)
}
func (c *conn) ListClients() ([]storage.Client, error) {
rows, err := c.Query(`
select
id, secret, redirect_uris, trusted_peers, public, name, logo_url
from client;
`)
if err != nil {
return nil, err
}
var clients []storage.Client
for rows.Next() {
cli, err := scanClient(rows)
if err != nil {
return nil, err
}
clients = append(clients, cli)
}
if err := rows.Err(); err != nil {
return nil, err
}
return clients, nil
}
func scanClient(s scanner) (cli storage.Client, err error) {
err = s.Scan(
&cli.ID, &cli.Secret, decoder(&cli.RedirectURIs), decoder(&cli.TrustedPeers),
&cli.Public, &cli.Name, &cli.LogoURL,
)
if err != nil {
if err == sql.ErrNoRows {
return cli, storage.ErrNotFound
}
return cli, fmt.Errorf("get client: %v", err)
}
return cli, nil
}
func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", id) }
func (c *conn) DeleteAuthCode(id string) error { return c.delete("auth_code", id) }
func (c *conn) DeleteClient(id string) error { return c.delete("client", id) }
func (c *conn) DeleteRefresh(id string) error { return c.delete("refresh_token", id) }
// Do NOT call directly. Does not escape table.
func (c *conn) delete(table, id string) error {
result, err := c.Exec(`delete from `+table+` where id = $1`, id)
if err != nil {
return fmt.Errorf("delete %s: %v", table, id)
}
// For now mandate that the driver implements RowsAffected. If we ever need to support
// a driver that doesn't implement this, we can run this in a transaction with a get beforehand.
n, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("rows affected: %v", err)
}
if n < 1 {
return storage.ErrNotFound
}
return nil
}

55
storage/sql/crud_test.go Normal file
View File

@ -0,0 +1,55 @@
package sql
import (
"database/sql"
"reflect"
"testing"
)
func TestDecoder(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
if _, err := db.Exec(`create table foo ( id integer primary key, bar blob );`); err != nil {
t.Fatal(err)
}
if _, err := db.Exec(`insert into foo ( id, bar ) values (1, ?);`, []byte(`["a", "b"]`)); err != nil {
t.Fatal(err)
}
var got []string
if err := db.QueryRow(`select bar from foo where id = 1;`).Scan(decoder(&got)); err != nil {
t.Fatal(err)
}
want := []string{"a", "b"}
if !reflect.DeepEqual(got, want) {
t.Errorf("wanted %q got %q", want, got)
}
}
func TestEncoder(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
if _, err := db.Exec(`create table foo ( id integer primary key, bar blob );`); err != nil {
t.Fatal(err)
}
put := []string{"a", "b"}
if _, err := db.Exec(`insert into foo ( id, bar ) values (1, ?)`, encoder(put)); err != nil {
t.Fatal(err)
}
var got []byte
if err := db.QueryRow(`select bar from foo where id = 1;`).Scan(&got); err != nil {
t.Fatal(err)
}
want := []byte(`["a","b"]`)
if !reflect.DeepEqual(got, want) {
t.Errorf("wanted %q got %q", want, got)
}
}

24
storage/sql/gc.go Normal file
View File

@ -0,0 +1,24 @@
package sql
import (
"fmt"
"time"
)
type gc struct {
now func() time.Time
conn *conn
}
var tablesWithGC = []string{"auth_request", "auth_code"}
func (gc gc) run() error {
for _, table := range tablesWithGC {
_, err := gc.conn.Exec(`delete from `+table+` where expiry < $1`, gc.now())
if err != nil {
return fmt.Errorf("gc %s: %v", table, err)
}
// TODO(ericchiang): when we have levelled logging print how many rows were gc'd
}
return nil
}

53
storage/sql/gc_test.go Normal file
View File

@ -0,0 +1,53 @@
package sql
import (
"testing"
"time"
"github.com/coreos/dex/storage"
)
func TestGC(t *testing.T) {
// TODO(ericchiang): Add a GarbageCollect method to the storage interface so
// we can write conformance tests instead of directly testing each implementation.
s := &SQLite3{":memory:"}
conn, err := s.open()
if err != nil {
t.Fatal(err)
}
defer conn.Close()
clock := time.Now()
now := func() time.Time { return clock }
runGC := (gc{now, conn}).run
a := storage.AuthRequest{
ID: storage.NewID(),
Expiry: now().Add(time.Second),
}
if err := conn.CreateAuthRequest(a); err != nil {
t.Fatal(err)
}
if err := runGC(); err != nil {
t.Errorf("gc failed: %v", err)
}
if _, err := conn.GetAuthRequest(a.ID); err != nil {
t.Errorf("failed to get auth request after gc: %v", err)
}
clock = clock.Add(time.Minute)
if err := runGC(); err != nil {
t.Errorf("gc failed: %v", err)
}
if _, err := conn.GetAuthRequest(a.ID); err == nil {
t.Errorf("expected error after gc'ing auth request: %v", err)
} else if err != storage.ErrNotFound {
t.Errorf("expected error storage.NotFound got: %v", err)
}
}

151
storage/sql/migrate.go Normal file
View File

@ -0,0 +1,151 @@
package sql
import (
"database/sql"
"fmt"
)
func (c *conn) migrate() (int, error) {
_, err := c.Exec(`
create table if not exists migrations (
num integer not null,
at timestamp not null
);
`)
if err != nil {
return 0, fmt.Errorf("creating migration table: %v", err)
}
i := 0
done := false
for {
err := c.ExecTx(func(tx *trans) error {
// Within a transaction, perform a single migration.
var (
num sql.NullInt64
n int
)
if err := tx.QueryRow(`select max(num) from migrations;`).Scan(&num); err != nil {
return fmt.Errorf("select max migration: %v", err)
}
if num.Valid {
n = int(num.Int64)
}
if n >= len(migrations) {
done = true
return nil
}
migrationNum := n + 1
m := migrations[n]
if _, err := tx.Exec(m.stmt); err != nil {
return fmt.Errorf("migration %d failed: %v", migrationNum, err)
}
q := `insert into migrations (num, at) values ($1, now());`
if _, err := tx.Exec(q, migrationNum); err != nil {
return fmt.Errorf("update migration table: %v", err)
}
return nil
})
if err != nil {
return i, err
}
if done {
break
}
i++
}
return i, nil
}
type migration struct {
stmt string
// TODO(ericchiang): consider adding additional fields like "forDrivers"
}
// All SQL flavors share migration strategies.
var migrations = []migration{
{
stmt: `
create table client (
id text not null primary key,
secret text not null,
redirect_uris bytea not null, -- JSON array of strings
trusted_peers bytea not null, -- JSON array of strings
public boolean not null,
name text not null,
logo_url text not null
);
create table auth_request (
id text not null primary key,
client_id text not null,
response_types bytea not null, -- JSON array of strings
scopes bytea not null, -- JSON array of strings
redirect_uri text not null,
nonce text not null,
state text not null,
force_approval_prompt boolean not null,
logged_in boolean not null,
claims_user_id text not null,
claims_username text not null,
claims_email text not null,
claims_email_verified boolean not null,
claims_groups bytea not null, -- JSON array of strings
connector_id text not null,
connector_data bytea,
expiry timestamp not null
);
create table auth_code (
id text not null primary key,
client_id text not null,
scopes bytea not null, -- JSON array of strings
nonce text not null,
redirect_uri text not null,
claims_user_id text not null,
claims_username text not null,
claims_email text not null,
claims_email_verified boolean not null,
claims_groups bytea not null, -- JSON array of strings
connector_id text not null,
connector_data bytea,
expiry timestamp not null
);
create table refresh_token (
id text not null primary key,
client_id text not null,
scopes bytea not null, -- JSON array of strings
nonce text not null,
claims_user_id text not null,
claims_username text not null,
claims_email text not null,
claims_email_verified boolean not null,
claims_groups bytea not null, -- JSON array of strings
connector_id text not null,
connector_data bytea
);
-- keys is a weird table because we only ever expect there to be a single row
create table keys (
id text not null primary key,
verification_keys bytea not null, -- JSON array
signing_key bytea not null, -- JSON object
signing_key_pub bytea not null, -- JSON object
next_rotation timestamp not null
);
`,
},
}

View File

@ -0,0 +1,25 @@
package sql
import (
"database/sql"
"testing"
)
func TestMigrate(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
c := &conn{db, flavorSQLite3}
for _, want := range []int{len(migrations), 0} {
got, err := c.migrate()
if err != nil {
t.Fatal(err)
}
if got != want {
t.Errorf("expected %d migrations, got %d", want, got)
}
}
}

152
storage/sql/sql.go Normal file
View File

@ -0,0 +1,152 @@
// Package sql provides SQL implementations of the storage interface.
package sql
import (
"database/sql"
"regexp"
"github.com/cockroachdb/cockroach-go/crdb"
// import third party drivers
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
// flavor represents a specific SQL implementation, and is used to translate query strings
// between different drivers. Flavors shouldn't aim to translate all possible SQL statements,
// only the specific queries used by the SQL storages.
type flavor struct {
queryReplacers []replacer
// Optional function to create and finish a transaction. This is mainly for
// cockroachdb support which requires special retry logic provided by their
// client package.
//
// This will be nil for most flavors.
//
// See: https://github.com/cockroachdb/docs/blob/63761c2e/_includes/app/txn-sample.go#L41-L44
executeTx func(db *sql.DB, fn func(*sql.Tx) error) error
}
// A regexp with a replacement string.
type replacer struct {
re *regexp.Regexp
with string
}
// Match a postgres query binds. E.g. "$1", "$12", etc.
var bindRegexp = regexp.MustCompile(`\$\d+`)
func matchLiteral(s string) *regexp.Regexp {
return regexp.MustCompile(`\b` + regexp.QuoteMeta(s) + `\b`)
}
var (
// The "github.com/lib/pq" driver is the default flavor. All others are
// translations of this.
flavorPostgres = flavor{}
flavorSQLite3 = flavor{
queryReplacers: []replacer{
{bindRegexp, "?"},
// Translate for booleans to integers.
{matchLiteral("true"), "1"},
{matchLiteral("false"), "0"},
{matchLiteral("boolean"), "integer"},
// Translate other types.
{matchLiteral("bytea"), "blob"},
// {matchLiteral("timestamp"), "integer"},
// SQLite doesn't have a "now()" method, replace with "date('now')"
{regexp.MustCompile(`\bnow\(\)`), "date('now')"},
},
}
// Incomplete.
flavorMySQL = flavor{
queryReplacers: []replacer{
{bindRegexp, "?"},
},
}
// Not tested.
flavorCockroach = flavor{
executeTx: crdb.ExecuteTx,
}
)
func (f flavor) translate(query string) string {
// TODO(ericchiang): Heavy cashing.
for _, r := range f.queryReplacers {
query = r.re.ReplaceAllString(query, r.with)
}
return query
}
// conn is the main database connection.
type conn struct {
db *sql.DB
flavor flavor
}
func (c *conn) Close() error {
return c.db.Close()
}
// conn implements the same method signatures as encoding/sql.DB.
func (c *conn) Exec(query string, args ...interface{}) (sql.Result, error) {
query = c.flavor.translate(query)
return c.db.Exec(query, args...)
}
func (c *conn) Query(query string, args ...interface{}) (*sql.Rows, error) {
query = c.flavor.translate(query)
return c.db.Query(query, args...)
}
func (c *conn) QueryRow(query string, args ...interface{}) *sql.Row {
query = c.flavor.translate(query)
return c.db.QueryRow(query, args...)
}
// ExecTx runs a method which operates on a transaction.
func (c *conn) ExecTx(fn func(tx *trans) error) error {
if c.flavor.executeTx != nil {
return c.flavor.executeTx(c.db, func(sqlTx *sql.Tx) error {
return fn(&trans{sqlTx, c})
})
}
sqlTx, err := c.db.Begin()
if err != nil {
return err
}
if err := fn(&trans{sqlTx, c}); err != nil {
sqlTx.Rollback()
return err
}
return sqlTx.Commit()
}
type trans struct {
tx *sql.Tx
c *conn
}
// trans implements the same method signatures as encoding/sql.Tx.
func (t *trans) Exec(query string, args ...interface{}) (sql.Result, error) {
query = t.c.flavor.translate(query)
return t.tx.Exec(query, args...)
}
func (t *trans) Query(query string, args ...interface{}) (*sql.Rows, error) {
query = t.c.flavor.translate(query)
return t.tx.Query(query, args...)
}
func (t *trans) QueryRow(query string, args ...interface{}) *sql.Row {
query = t.c.flavor.translate(query)
return t.tx.QueryRow(query, args...)
}

55
storage/sql/sql_test.go Normal file
View File

@ -0,0 +1,55 @@
package sql
import "testing"
func TestTranslate(t *testing.T) {
tests := []struct {
testCase string
flavor flavor
query string
exp string
}{
{
"sqlite3 query bind replacement",
flavorSQLite3,
`select foo from bar where foo.zam = $1;`,
`select foo from bar where foo.zam = ?;`,
},
{
"sqlite3 query bind replacement at newline",
flavorSQLite3,
`select foo from bar where foo.zam = $1`,
`select foo from bar where foo.zam = ?`,
},
{
"sqlite3 query true",
flavorSQLite3,
`select foo from bar where foo.zam = true`,
`select foo from bar where foo.zam = 1`,
},
{
"sqlite3 query false",
flavorSQLite3,
`select foo from bar where foo.zam = false`,
`select foo from bar where foo.zam = 0`,
},
{
"sqlite3 bytea",
flavorSQLite3,
`"connector_data" bytea not null,`,
`"connector_data" blob not null,`,
},
{
"sqlite3 now",
flavorSQLite3,
`now(),`,
`date('now'),`,
},
}
for _, tc := range tests {
if got := tc.flavor.translate(tc.query); got != tc.exp {
t.Errorf("%s: want=%q, got=%q", tc.testCase, tc.exp, got)
}
}
}