forked from mystiq/dex
feat: Add ent-based sqlite3 storage
Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
parent
674631c9ab
commit
11859166d0
31 changed files with 1878 additions and 4 deletions
7
Makefile
7
Makefile
|
@ -26,7 +26,10 @@ PROTOC_VERSION = 3.15.6
|
||||||
PROTOC_GEN_GO_VERSION = 1.26.0
|
PROTOC_GEN_GO_VERSION = 1.26.0
|
||||||
PROTOC_GEN_GO_GRPC_VERSION = 1.1.0
|
PROTOC_GEN_GO_GRPC_VERSION = 1.1.0
|
||||||
|
|
||||||
build: bin/dex
|
generate:
|
||||||
|
@go generate $(REPO_PATH)/storage/ent/
|
||||||
|
|
||||||
|
build: generate bin/dex
|
||||||
|
|
||||||
bin/dex:
|
bin/dex:
|
||||||
@mkdir -p bin/
|
@mkdir -p bin/
|
||||||
|
@ -42,7 +45,7 @@ bin/example-app:
|
||||||
@mkdir -p bin/
|
@mkdir -p bin/
|
||||||
@cd examples/ && go install -v -ldflags $(LD_FLAGS) $(REPO_PATH)/examples/example-app
|
@cd examples/ && go install -v -ldflags $(LD_FLAGS) $(REPO_PATH)/examples/example-app
|
||||||
|
|
||||||
.PHONY: release-binary
|
.PHONY: generate release-binary
|
||||||
release-binary:
|
release-binary:
|
||||||
@go build -o /go/bin/dex -v -ldflags $(LD_FLAGS) $(REPO_PATH)/cmd/dex
|
@go build -o /go/bin/dex -v -ldflags $(LD_FLAGS) $(REPO_PATH)/cmd/dex
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/dexidp/dex/pkg/log"
|
"github.com/dexidp/dex/pkg/log"
|
||||||
"github.com/dexidp/dex/server"
|
"github.com/dexidp/dex/server"
|
||||||
"github.com/dexidp/dex/storage"
|
"github.com/dexidp/dex/storage"
|
||||||
|
"github.com/dexidp/dex/storage/ent"
|
||||||
"github.com/dexidp/dex/storage/etcd"
|
"github.com/dexidp/dex/storage/etcd"
|
||||||
"github.com/dexidp/dex/storage/kubernetes"
|
"github.com/dexidp/dex/storage/kubernetes"
|
||||||
"github.com/dexidp/dex/storage/memory"
|
"github.com/dexidp/dex/storage/memory"
|
||||||
|
@ -173,13 +174,32 @@ type StorageConfig interface {
|
||||||
Open(logger log.Logger) (storage.Storage, error)
|
Open(logger log.Logger) (storage.Storage, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ StorageConfig = (*etcd.Etcd)(nil)
|
||||||
|
_ StorageConfig = (*kubernetes.Config)(nil)
|
||||||
|
_ StorageConfig = (*memory.Config)(nil)
|
||||||
|
_ StorageConfig = (*sql.SQLite3)(nil)
|
||||||
|
_ StorageConfig = (*sql.Postgres)(nil)
|
||||||
|
_ StorageConfig = (*sql.MySQL)(nil)
|
||||||
|
_ StorageConfig = (*ent.SQLite3)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
func getORMBasedSQLiteStorage() StorageConfig {
|
||||||
|
switch os.Getenv("DEX_ENT_ENABLED") {
|
||||||
|
case "true", "yes":
|
||||||
|
return new(ent.SQLite3)
|
||||||
|
default:
|
||||||
|
return new(sql.SQLite3)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var storages = map[string]func() StorageConfig{
|
var storages = map[string]func() StorageConfig{
|
||||||
"etcd": func() StorageConfig { return new(etcd.Etcd) },
|
"etcd": func() StorageConfig { return new(etcd.Etcd) },
|
||||||
"kubernetes": func() StorageConfig { return new(kubernetes.Config) },
|
"kubernetes": func() StorageConfig { return new(kubernetes.Config) },
|
||||||
"memory": func() StorageConfig { return new(memory.Config) },
|
"memory": func() StorageConfig { return new(memory.Config) },
|
||||||
"sqlite3": func() StorageConfig { return new(sql.SQLite3) },
|
|
||||||
"postgres": func() StorageConfig { return new(sql.Postgres) },
|
"postgres": func() StorageConfig { return new(sql.Postgres) },
|
||||||
"mysql": func() StorageConfig { return new(sql.MySQL) },
|
"mysql": func() StorageConfig { return new(sql.MySQL) },
|
||||||
|
"sqlite3": getORMBasedSQLiteStorage,
|
||||||
}
|
}
|
||||||
|
|
||||||
// isExpandEnvEnabled returns if os.ExpandEnv should be used for each storage and connector config.
|
// isExpandEnvEnabled returns if os.ExpandEnv should be used for each storage and connector config.
|
||||||
|
|
3
go.mod
3
go.mod
|
@ -7,7 +7,8 @@ require (
|
||||||
github.com/beevik/etree v1.1.0
|
github.com/beevik/etree v1.1.0
|
||||||
github.com/coreos/go-oidc/v3 v3.0.0
|
github.com/coreos/go-oidc/v3 v3.0.0
|
||||||
github.com/dexidp/dex/api/v2 v2.0.0
|
github.com/dexidp/dex/api/v2 v2.0.0
|
||||||
github.com/felixge/httpsnoop v1.0.2
|
github.com/facebook/ent v0.5.3
|
||||||
|
github.com/felixge/httpsnoop v1.0.1
|
||||||
github.com/ghodss/yaml v1.0.0
|
github.com/ghodss/yaml v1.0.0
|
||||||
github.com/go-ldap/ldap/v3 v3.3.0
|
github.com/go-ldap/ldap/v3 v3.3.0
|
||||||
github.com/go-sql-driver/mysql v1.6.0
|
github.com/go-sql-driver/mysql v1.6.0
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -128,6 +128,8 @@ github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5y
|
||||||
github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
|
github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
|
||||||
github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
|
github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
|
||||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||||
|
github.com/facebook/ent v0.5.3 h1:YT3Sl28n7gGGOkQeYgeJsZmizJ1Iiy7psgkOtEk0aq4=
|
||||||
|
github.com/facebook/ent v0.5.3/go.mod h1:tlWP+qCd3x2EeO7B/EqlJQ4dWu/2IeYFhP/szzDKAi8=
|
||||||
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
|
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
|
||||||
github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||||
github.com/felixge/httpsnoop v1.0.2 h1:+nS9g82KMXccJ/wp0zyRW9ZBHFETmMGtkk+2CTTrW4o=
|
github.com/felixge/httpsnoop v1.0.2 h1:+nS9g82KMXccJ/wp0zyRW9ZBHFETmMGtkk+2CTTrW4o=
|
||||||
|
|
52
storage/ent/client/authcode.go
Normal file
52
storage/ent/client/authcode.go
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateAuthCode saves provided auth code into the database.
|
||||||
|
func (d *Database) CreateAuthCode(code storage.AuthCode) error {
|
||||||
|
_, err := d.client.AuthCode.Create().
|
||||||
|
SetID(code.ID).
|
||||||
|
SetClientID(code.ClientID).
|
||||||
|
SetScopes(code.Scopes).
|
||||||
|
SetRedirectURI(code.RedirectURI).
|
||||||
|
SetNonce(code.Nonce).
|
||||||
|
SetClaimsUserID(code.Claims.UserID).
|
||||||
|
SetClaimsEmail(code.Claims.Email).
|
||||||
|
SetClaimsEmailVerified(code.Claims.EmailVerified).
|
||||||
|
SetClaimsUsername(code.Claims.Username).
|
||||||
|
SetClaimsPreferredUsername(code.Claims.PreferredUsername).
|
||||||
|
SetClaimsGroups(code.Claims.Groups).
|
||||||
|
SetCodeChallenge(code.PKCE.CodeChallenge).
|
||||||
|
SetCodeChallengeMethod(code.PKCE.CodeChallengeMethod).
|
||||||
|
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||||
|
SetExpiry(code.Expiry.UTC()).
|
||||||
|
SetConnectorID(code.ConnectorID).
|
||||||
|
SetConnectorData(code.ConnectorData).
|
||||||
|
Save(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("create auth code: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthCode extracts an auth code from the database by id.
|
||||||
|
func (d *Database) GetAuthCode(id string) (storage.AuthCode, error) {
|
||||||
|
authCode, err := d.client.AuthCode.Get(context.TODO(), id)
|
||||||
|
if err != nil {
|
||||||
|
return storage.AuthCode{}, convertDBError("get auth code: %w", err)
|
||||||
|
}
|
||||||
|
return toStorageAuthCode(authCode), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAuthCode deletes an auth code from the database by id.
|
||||||
|
func (d *Database) DeleteAuthCode(id string) error {
|
||||||
|
err := d.client.AuthCode.DeleteOneID(id).Exec(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("delete auth code: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
107
storage/ent/client/authrequest.go
Normal file
107
storage/ent/client/authrequest.go
Normal file
|
@ -0,0 +1,107 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateAuthRequest saves provided auth request into the database.
|
||||||
|
func (d *Database) CreateAuthRequest(authRequest storage.AuthRequest) error {
|
||||||
|
_, err := d.client.AuthRequest.Create().
|
||||||
|
SetID(authRequest.ID).
|
||||||
|
SetClientID(authRequest.ClientID).
|
||||||
|
SetScopes(authRequest.Scopes).
|
||||||
|
SetResponseTypes(authRequest.ResponseTypes).
|
||||||
|
SetRedirectURI(authRequest.RedirectURI).
|
||||||
|
SetState(authRequest.State).
|
||||||
|
SetNonce(authRequest.Nonce).
|
||||||
|
SetForceApprovalPrompt(authRequest.ForceApprovalPrompt).
|
||||||
|
SetLoggedIn(authRequest.LoggedIn).
|
||||||
|
SetClaimsUserID(authRequest.Claims.UserID).
|
||||||
|
SetClaimsEmail(authRequest.Claims.Email).
|
||||||
|
SetClaimsEmailVerified(authRequest.Claims.EmailVerified).
|
||||||
|
SetClaimsUsername(authRequest.Claims.Username).
|
||||||
|
SetClaimsPreferredUsername(authRequest.Claims.PreferredUsername).
|
||||||
|
SetClaimsGroups(authRequest.Claims.Groups).
|
||||||
|
SetCodeChallenge(authRequest.PKCE.CodeChallenge).
|
||||||
|
SetCodeChallengeMethod(authRequest.PKCE.CodeChallengeMethod).
|
||||||
|
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||||
|
SetExpiry(authRequest.Expiry.UTC()).
|
||||||
|
SetConnectorID(authRequest.ConnectorID).
|
||||||
|
SetConnectorData(authRequest.ConnectorData).
|
||||||
|
Save(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("create auth request: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthRequest extracts an auth request from the database by id.
|
||||||
|
func (d *Database) GetAuthRequest(id string) (storage.AuthRequest, error) {
|
||||||
|
authRequest, err := d.client.AuthRequest.Get(context.TODO(), id)
|
||||||
|
if err != nil {
|
||||||
|
return storage.AuthRequest{}, convertDBError("get auth request: %w", err)
|
||||||
|
}
|
||||||
|
return toStorageAuthRequest(authRequest), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAuthRequest deletes an auth request from the database by id.
|
||||||
|
func (d *Database) DeleteAuthRequest(id string) error {
|
||||||
|
err := d.client.AuthRequest.DeleteOneID(id).Exec(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("delete auth request: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAuthRequest changes an auth request by id using an updater function and saves it to the database.
|
||||||
|
func (d *Database) UpdateAuthRequest(id string, updater func(old storage.AuthRequest) (storage.AuthRequest, error)) error {
|
||||||
|
tx, err := d.client.Tx(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update auth request tx: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
authRequest, err := tx.AuthRequest.Get(context.TODO(), id)
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update auth request database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newAuthRequest, err := updater(toStorageAuthRequest(authRequest))
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update auth request updating: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.AuthRequest.UpdateOneID(newAuthRequest.ID).
|
||||||
|
SetClientID(newAuthRequest.ClientID).
|
||||||
|
SetScopes(newAuthRequest.Scopes).
|
||||||
|
SetResponseTypes(newAuthRequest.ResponseTypes).
|
||||||
|
SetRedirectURI(newAuthRequest.RedirectURI).
|
||||||
|
SetState(newAuthRequest.State).
|
||||||
|
SetNonce(newAuthRequest.Nonce).
|
||||||
|
SetForceApprovalPrompt(newAuthRequest.ForceApprovalPrompt).
|
||||||
|
SetLoggedIn(newAuthRequest.LoggedIn).
|
||||||
|
SetClaimsUserID(newAuthRequest.Claims.UserID).
|
||||||
|
SetClaimsEmail(newAuthRequest.Claims.Email).
|
||||||
|
SetClaimsEmailVerified(newAuthRequest.Claims.EmailVerified).
|
||||||
|
SetClaimsUsername(newAuthRequest.Claims.Username).
|
||||||
|
SetClaimsPreferredUsername(newAuthRequest.Claims.PreferredUsername).
|
||||||
|
SetClaimsGroups(newAuthRequest.Claims.Groups).
|
||||||
|
SetCodeChallenge(newAuthRequest.PKCE.CodeChallenge).
|
||||||
|
SetCodeChallengeMethod(newAuthRequest.PKCE.CodeChallengeMethod).
|
||||||
|
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||||
|
SetExpiry(newAuthRequest.Expiry.UTC()).
|
||||||
|
SetConnectorID(newAuthRequest.ConnectorID).
|
||||||
|
SetConnectorData(newAuthRequest.ConnectorData).
|
||||||
|
Save(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update auth request uploading: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.Commit(); err != nil {
|
||||||
|
return rollback(tx, "update auth request commit: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
92
storage/ent/client/client.go
Normal file
92
storage/ent/client/client.go
Normal file
|
@ -0,0 +1,92 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateClient saves provided oauth2 client settings into the database.
|
||||||
|
func (d *Database) CreateClient(client storage.Client) error {
|
||||||
|
_, err := d.client.OAuth2Client.Create().
|
||||||
|
SetID(client.ID).
|
||||||
|
SetName(client.Name).
|
||||||
|
SetSecret(client.Secret).
|
||||||
|
SetPublic(client.Public).
|
||||||
|
SetLogoURL(client.LogoURL).
|
||||||
|
SetRedirectUris(client.RedirectURIs).
|
||||||
|
SetTrustedPeers(client.TrustedPeers).
|
||||||
|
Save(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("create oauth2 client: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListClients extracts an array of oauth2 clients from the database.
|
||||||
|
func (d *Database) ListClients() ([]storage.Client, error) {
|
||||||
|
clients, err := d.client.OAuth2Client.Query().All(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return nil, convertDBError("list clients: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
storageClients := make([]storage.Client, 0, len(clients))
|
||||||
|
for _, c := range clients {
|
||||||
|
storageClients = append(storageClients, toStorageClient(c))
|
||||||
|
}
|
||||||
|
return storageClients, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClient extracts an oauth2 client from the database by id.
|
||||||
|
func (d *Database) GetClient(id string) (storage.Client, error) {
|
||||||
|
client, err := d.client.OAuth2Client.Get(context.TODO(), id)
|
||||||
|
if err != nil {
|
||||||
|
return storage.Client{}, convertDBError("get client: %w", err)
|
||||||
|
}
|
||||||
|
return toStorageClient(client), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteClient deletes an oauth2 client from the database by id.
|
||||||
|
func (d *Database) DeleteClient(id string) error {
|
||||||
|
err := d.client.OAuth2Client.DeleteOneID(id).Exec(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("delete client: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateClient changes an oauth2 client by id using an updater function and saves it to the database.
|
||||||
|
func (d *Database) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
|
||||||
|
tx, err := d.client.Tx(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("update client tx: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := tx.OAuth2Client.Get(context.TODO(), id)
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update client database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newClient, err := updater(toStorageClient(client))
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update client updating: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.OAuth2Client.UpdateOneID(newClient.ID).
|
||||||
|
SetName(newClient.Name).
|
||||||
|
SetSecret(newClient.Secret).
|
||||||
|
SetPublic(newClient.Public).
|
||||||
|
SetLogoURL(newClient.LogoURL).
|
||||||
|
SetRedirectUris(newClient.RedirectURIs).
|
||||||
|
SetTrustedPeers(newClient.TrustedPeers).
|
||||||
|
Save(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update client uploading: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.Commit(); err != nil {
|
||||||
|
return rollback(tx, "update auth request commit: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
88
storage/ent/client/connector.go
Normal file
88
storage/ent/client/connector.go
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateConnector saves a connector into the database.
|
||||||
|
func (d *Database) CreateConnector(connector storage.Connector) error {
|
||||||
|
_, err := d.client.Connector.Create().
|
||||||
|
SetID(connector.ID).
|
||||||
|
SetName(connector.Name).
|
||||||
|
SetType(connector.Type).
|
||||||
|
SetResourceVersion(connector.ResourceVersion).
|
||||||
|
SetConfig(connector.Config).
|
||||||
|
Save(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("create connector: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListConnectors extracts an array of connectors from the database.
|
||||||
|
func (d *Database) ListConnectors() ([]storage.Connector, error) {
|
||||||
|
connectors, err := d.client.Connector.Query().All(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return nil, convertDBError("list connectors: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
storageConnectors := make([]storage.Connector, 0, len(connectors))
|
||||||
|
for _, c := range connectors {
|
||||||
|
storageConnectors = append(storageConnectors, toStorageConnector(c))
|
||||||
|
}
|
||||||
|
return storageConnectors, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConnector extracts a connector from the database by id.
|
||||||
|
func (d *Database) GetConnector(id string) (storage.Connector, error) {
|
||||||
|
connector, err := d.client.Connector.Get(context.TODO(), id)
|
||||||
|
if err != nil {
|
||||||
|
return storage.Connector{}, convertDBError("get connector: %w", err)
|
||||||
|
}
|
||||||
|
return toStorageConnector(connector), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteConnector deletes a connector from the database by id.
|
||||||
|
func (d *Database) DeleteConnector(id string) error {
|
||||||
|
err := d.client.Connector.DeleteOneID(id).Exec(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("delete connector: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateConnector changes a connector by id using an updater function and saves it to the database.
|
||||||
|
func (d *Database) UpdateConnector(id string, updater func(old storage.Connector) (storage.Connector, error)) error {
|
||||||
|
tx, err := d.client.Tx(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("update connector tx: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
connector, err := tx.Connector.Get(context.TODO(), id)
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update connector database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newConnector, err := updater(toStorageConnector(connector))
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update connector updating: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.Connector.UpdateOneID(newConnector.ID).
|
||||||
|
SetName(newConnector.Name).
|
||||||
|
SetType(newConnector.Type).
|
||||||
|
SetResourceVersion(newConnector.ResourceVersion).
|
||||||
|
SetConfig(newConnector.Config).
|
||||||
|
Save(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update connector uploading: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.Commit(); err != nil {
|
||||||
|
return rollback(tx, "update connector commit: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
36
storage/ent/client/devicerequest.go
Normal file
36
storage/ent/client/devicerequest.go
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
"github.com/dexidp/dex/storage/ent/db/devicerequest"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateDeviceRequest saves provided device request into the database.
|
||||||
|
func (d *Database) CreateDeviceRequest(request storage.DeviceRequest) error {
|
||||||
|
_, err := d.client.DeviceRequest.Create().
|
||||||
|
SetClientID(request.ClientID).
|
||||||
|
SetClientSecret(request.ClientSecret).
|
||||||
|
SetScopes(request.Scopes).
|
||||||
|
SetUserCode(request.UserCode).
|
||||||
|
SetDeviceCode(request.DeviceCode).
|
||||||
|
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||||
|
SetExpiry(request.Expiry.UTC()).
|
||||||
|
Save(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("create device request: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceRequest extracts a device request from the database by user code.
|
||||||
|
func (d *Database) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
|
||||||
|
deviceRequest, err := d.client.DeviceRequest.Query().
|
||||||
|
Where(devicerequest.UserCode(userCode)).
|
||||||
|
Only(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return storage.DeviceRequest{}, convertDBError("get device request: %w", err)
|
||||||
|
}
|
||||||
|
return toStorageDeviceRequest(deviceRequest), nil
|
||||||
|
}
|
76
storage/ent/client/devicetoken.go
Normal file
76
storage/ent/client/devicetoken.go
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
"github.com/dexidp/dex/storage/ent/db/devicetoken"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateDeviceToken saves provided token into the database.
|
||||||
|
func (d *Database) CreateDeviceToken(token storage.DeviceToken) error {
|
||||||
|
_, err := d.client.DeviceToken.Create().
|
||||||
|
SetDeviceCode(token.DeviceCode).
|
||||||
|
SetToken([]byte(token.Token)).
|
||||||
|
SetPollInterval(token.PollIntervalSeconds).
|
||||||
|
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||||
|
SetExpiry(token.Expiry.UTC()).
|
||||||
|
SetLastRequest(token.LastRequestTime.UTC()).
|
||||||
|
SetStatus(token.Status).
|
||||||
|
Save(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("create device token: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceToken extracts a token from the database by device code.
|
||||||
|
func (d *Database) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
|
||||||
|
deviceToken, err := d.client.DeviceToken.Query().
|
||||||
|
Where(devicetoken.DeviceCode(deviceCode)).
|
||||||
|
Only(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return storage.DeviceToken{}, convertDBError("get device token: %w", err)
|
||||||
|
}
|
||||||
|
return toStorageDeviceToken(deviceToken), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDeviceToken changes a token by device code using an updater function and saves it to the database.
|
||||||
|
func (d *Database) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
|
||||||
|
tx, err := d.client.Tx(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("update device token tx: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := tx.DeviceToken.Query().
|
||||||
|
Where(devicetoken.DeviceCode(deviceCode)).
|
||||||
|
Only(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update device token database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newToken, err := updater(toStorageDeviceToken(token))
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update device token updating: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.DeviceToken.Update().
|
||||||
|
Where(devicetoken.DeviceCode(newToken.DeviceCode)).
|
||||||
|
SetDeviceCode(newToken.DeviceCode).
|
||||||
|
SetToken([]byte(newToken.Token)).
|
||||||
|
SetPollInterval(newToken.PollIntervalSeconds).
|
||||||
|
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||||
|
SetExpiry(newToken.Expiry.UTC()).
|
||||||
|
SetLastRequest(newToken.LastRequestTime.UTC()).
|
||||||
|
SetStatus(newToken.Status).
|
||||||
|
Save(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update device token uploading: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.Commit(); err != nil {
|
||||||
|
return rollback(tx, "update device token commit: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
81
storage/ent/client/keys.go
Normal file
81
storage/ent/client/keys.go
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
"github.com/dexidp/dex/storage/ent/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getKeys(client *db.KeysClient) (storage.Keys, error) {
|
||||||
|
rawKeys, err := client.Get(context.TODO(), keysRowID)
|
||||||
|
if err != nil {
|
||||||
|
return storage.Keys{}, convertDBError("get keys: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return toStorageKeys(rawKeys), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetKeys returns signing keys, public keys and verification keys from the database.
|
||||||
|
func (d *Database) GetKeys() (storage.Keys, error) {
|
||||||
|
return getKeys(d.client.Keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateKeys rotates keys using updater function.
|
||||||
|
func (d *Database) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
|
||||||
|
firstUpdate := false
|
||||||
|
|
||||||
|
tx, err := d.client.Tx(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("update keys tx: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
storageKeys, err := getKeys(tx.Keys)
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, storage.ErrNotFound) {
|
||||||
|
return rollback(tx, "update keys get: %w", err)
|
||||||
|
}
|
||||||
|
firstUpdate = true
|
||||||
|
}
|
||||||
|
|
||||||
|
newKeys, err := updater(storageKeys)
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update keys updating: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ent doesn't have an upsert support yet
|
||||||
|
// https://github.com/facebook/ent/issues/139
|
||||||
|
if firstUpdate {
|
||||||
|
_, err = tx.Keys.Create().
|
||||||
|
SetID(keysRowID).
|
||||||
|
SetNextRotation(newKeys.NextRotation).
|
||||||
|
SetSigningKey(*newKeys.SigningKey).
|
||||||
|
SetSigningKeyPub(*newKeys.SigningKeyPub).
|
||||||
|
SetVerificationKeys(newKeys.VerificationKeys).
|
||||||
|
Save(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "create keys: %w", err)
|
||||||
|
}
|
||||||
|
if err = tx.Commit(); err != nil {
|
||||||
|
return rollback(tx, "update keys commit: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Keys.UpdateOneID(keysRowID).
|
||||||
|
SetNextRotation(newKeys.NextRotation.UTC()).
|
||||||
|
SetSigningKey(*newKeys.SigningKey).
|
||||||
|
SetSigningKeyPub(*newKeys.SigningKeyPub).
|
||||||
|
SetVerificationKeys(newKeys.VerificationKeys).
|
||||||
|
Exec(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update keys uploading: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.Commit(); err != nil {
|
||||||
|
return rollback(tx, "update keys commit: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
95
storage/ent/client/main.go
Normal file
95
storage/ent/client/main.go
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"hash"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
"github.com/dexidp/dex/storage/ent/db"
|
||||||
|
"github.com/dexidp/dex/storage/ent/db/authcode"
|
||||||
|
"github.com/dexidp/dex/storage/ent/db/authrequest"
|
||||||
|
"github.com/dexidp/dex/storage/ent/db/devicerequest"
|
||||||
|
"github.com/dexidp/dex/storage/ent/db/devicetoken"
|
||||||
|
"github.com/dexidp/dex/storage/ent/db/migrate"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ storage.Storage = (*Database)(nil)
|
||||||
|
|
||||||
|
type Database struct {
|
||||||
|
client *db.Client
|
||||||
|
hasher func() hash.Hash
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDatabase returns new database client with set options.
|
||||||
|
func NewDatabase(opts ...func(*Database)) *Database {
|
||||||
|
database := &Database{}
|
||||||
|
for _, f := range opts {
|
||||||
|
f(database)
|
||||||
|
}
|
||||||
|
return database
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithClient sets client option of a Database object.
|
||||||
|
func WithClient(c *db.Client) func(*Database) {
|
||||||
|
return func(s *Database) {
|
||||||
|
s.client = c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHasher sets client option of a Database object.
|
||||||
|
func WithHasher(h func() hash.Hash) func(*Database) {
|
||||||
|
return func(s *Database) {
|
||||||
|
s.hasher = h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Schema exposes migration schema to perform migrations.
|
||||||
|
func (d *Database) Schema() *migrate.Schema {
|
||||||
|
return d.client.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close calls the corresponding method of the ent database client.
|
||||||
|
func (d *Database) Close() error {
|
||||||
|
return d.client.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GarbageCollect removes expired entities from the database.
|
||||||
|
func (d *Database) GarbageCollect(now time.Time) (storage.GCResult, error) {
|
||||||
|
result := storage.GCResult{}
|
||||||
|
utcNow := now.UTC()
|
||||||
|
|
||||||
|
q, err := d.client.AuthRequest.Delete().
|
||||||
|
Where(authrequest.ExpiryLT(utcNow)).
|
||||||
|
Exec(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return result, convertDBError("gc auth request: %w", err)
|
||||||
|
}
|
||||||
|
result.AuthRequests = int64(q)
|
||||||
|
|
||||||
|
q, err = d.client.AuthCode.Delete().
|
||||||
|
Where(authcode.ExpiryLT(utcNow)).
|
||||||
|
Exec(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return result, convertDBError("gc auth code: %w", err)
|
||||||
|
}
|
||||||
|
result.AuthCodes = int64(q)
|
||||||
|
|
||||||
|
q, err = d.client.DeviceRequest.Delete().
|
||||||
|
Where(devicerequest.ExpiryLT(utcNow)).
|
||||||
|
Exec(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return result, convertDBError("gc device request: %w", err)
|
||||||
|
}
|
||||||
|
result.DeviceRequests = int64(q)
|
||||||
|
|
||||||
|
q, err = d.client.DeviceToken.Delete().
|
||||||
|
Where(devicetoken.ExpiryLT(utcNow)).
|
||||||
|
Exec(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return result, convertDBError("gc device token: %w", err)
|
||||||
|
}
|
||||||
|
result.DeviceTokens = int64(q)
|
||||||
|
|
||||||
|
return result, err
|
||||||
|
}
|
93
storage/ent/client/offlinesession.go
Normal file
93
storage/ent/client/offlinesession.go
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateOfflineSessions saves provided offline session into the database.
|
||||||
|
func (d *Database) CreateOfflineSessions(session storage.OfflineSessions) error {
|
||||||
|
encodedRefresh, err := json.Marshal(session.Refresh)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("encode refresh offline session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
id := offlineSessionID(session.UserID, session.ConnID, d.hasher)
|
||||||
|
_, err = d.client.OfflineSession.Create().
|
||||||
|
SetID(id).
|
||||||
|
SetUserID(session.UserID).
|
||||||
|
SetConnID(session.ConnID).
|
||||||
|
SetConnectorData(session.ConnectorData).
|
||||||
|
SetRefresh(encodedRefresh).
|
||||||
|
Save(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("create offline session: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOfflineSessions extracts an offline session from the database by user id and connector id.
|
||||||
|
func (d *Database) GetOfflineSessions(userID, connID string) (storage.OfflineSessions, error) {
|
||||||
|
id := offlineSessionID(userID, connID, d.hasher)
|
||||||
|
|
||||||
|
offlineSession, err := d.client.OfflineSession.Get(context.TODO(), id)
|
||||||
|
if err != nil {
|
||||||
|
return storage.OfflineSessions{}, convertDBError("get offline session: %w", err)
|
||||||
|
}
|
||||||
|
return toStorageOfflineSession(offlineSession), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteOfflineSessions deletes an offline session from the database by user id and connector id.
|
||||||
|
func (d *Database) DeleteOfflineSessions(userID, connID string) error {
|
||||||
|
id := offlineSessionID(userID, connID, d.hasher)
|
||||||
|
|
||||||
|
err := d.client.OfflineSession.DeleteOneID(id).Exec(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("delete offline session: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePassword changes an offline session by user id and connector id using an updater function.
|
||||||
|
func (d *Database) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
|
||||||
|
id := offlineSessionID(userID, connID, d.hasher)
|
||||||
|
|
||||||
|
tx, err := d.client.Tx(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("update offline session tx: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
offlineSession, err := tx.OfflineSession.Get(context.TODO(), id)
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update offline session database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newOfflineSession, err := updater(toStorageOfflineSession(offlineSession))
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update offline session updating: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
encodedRefresh, err := json.Marshal(newOfflineSession.Refresh)
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "encode refresh offline session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.OfflineSession.UpdateOneID(id).
|
||||||
|
SetUserID(newOfflineSession.UserID).
|
||||||
|
SetConnID(newOfflineSession.ConnID).
|
||||||
|
SetConnectorData(newOfflineSession.ConnectorData).
|
||||||
|
SetRefresh(encodedRefresh).
|
||||||
|
Save(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update offline session uploading: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.Commit(); err != nil {
|
||||||
|
return rollback(tx, "update password commit: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
100
storage/ent/client/password.go
Normal file
100
storage/ent/client/password.go
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
"github.com/dexidp/dex/storage/ent/db/password"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreatePassword saves provided password into the database.
|
||||||
|
func (d *Database) CreatePassword(password storage.Password) error {
|
||||||
|
_, err := d.client.Password.Create().
|
||||||
|
SetEmail(password.Email).
|
||||||
|
SetHash(password.Hash).
|
||||||
|
SetUsername(password.Username).
|
||||||
|
SetUserID(password.UserID).
|
||||||
|
Save(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("create password: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListPasswords extracts an array of passwords from the database.
|
||||||
|
func (d *Database) ListPasswords() ([]storage.Password, error) {
|
||||||
|
passwords, err := d.client.Password.Query().All(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return nil, convertDBError("list passwords: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
storagePasswords := make([]storage.Password, 0, len(passwords))
|
||||||
|
for _, p := range passwords {
|
||||||
|
storagePasswords = append(storagePasswords, toStoragePassword(p))
|
||||||
|
}
|
||||||
|
return storagePasswords, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPassword extracts a password from the database by email.
|
||||||
|
func (d *Database) GetPassword(email string) (storage.Password, error) {
|
||||||
|
email = strings.ToLower(email)
|
||||||
|
passwordFromStorage, err := d.client.Password.Query().
|
||||||
|
Where(password.Email(email)).
|
||||||
|
Only(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return storage.Password{}, convertDBError("get password: %w", err)
|
||||||
|
}
|
||||||
|
return toStoragePassword(passwordFromStorage), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePassword deletes a password from the database by email.
|
||||||
|
func (d *Database) DeletePassword(email string) error {
|
||||||
|
email = strings.ToLower(email)
|
||||||
|
_, err := d.client.Password.Delete().
|
||||||
|
Where(password.Email(email)).
|
||||||
|
Exec(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("delete password: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePassword changes a password by email using an updater function and saves it to the database.
|
||||||
|
func (d *Database) UpdatePassword(email string, updater func(old storage.Password) (storage.Password, error)) error {
|
||||||
|
email = strings.ToLower(email)
|
||||||
|
|
||||||
|
tx, err := d.client.Tx(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("update connector tx: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
passwordToUpdate, err := tx.Password.Query().
|
||||||
|
Where(password.Email(email)).
|
||||||
|
Only(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update password database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newPassword, err := updater(toStoragePassword(passwordToUpdate))
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update password updating: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.Password.Update().
|
||||||
|
Where(password.Email(newPassword.Email)).
|
||||||
|
SetEmail(newPassword.Email).
|
||||||
|
SetHash(newPassword.Hash).
|
||||||
|
SetUsername(newPassword.Username).
|
||||||
|
SetUserID(newPassword.UserID).
|
||||||
|
Save(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update password uploading: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.Commit(); err != nil {
|
||||||
|
return rollback(tx, "update password commit: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
109
storage/ent/client/refreshtoken.go
Normal file
109
storage/ent/client/refreshtoken.go
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateRefresh saves provided refresh token into the database.
|
||||||
|
func (d *Database) CreateRefresh(refresh storage.RefreshToken) error {
|
||||||
|
_, err := d.client.RefreshToken.Create().
|
||||||
|
SetID(refresh.ID).
|
||||||
|
SetClientID(refresh.ClientID).
|
||||||
|
SetScopes(refresh.Scopes).
|
||||||
|
SetNonce(refresh.Nonce).
|
||||||
|
SetClaimsUserID(refresh.Claims.UserID).
|
||||||
|
SetClaimsEmail(refresh.Claims.Email).
|
||||||
|
SetClaimsEmailVerified(refresh.Claims.EmailVerified).
|
||||||
|
SetClaimsUsername(refresh.Claims.Username).
|
||||||
|
SetClaimsPreferredUsername(refresh.Claims.PreferredUsername).
|
||||||
|
SetClaimsGroups(refresh.Claims.Groups).
|
||||||
|
SetConnectorID(refresh.ConnectorID).
|
||||||
|
SetConnectorData(refresh.ConnectorData).
|
||||||
|
SetToken(refresh.Token).
|
||||||
|
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||||
|
SetLastUsed(refresh.LastUsed.UTC()).
|
||||||
|
SetCreatedAt(refresh.CreatedAt.UTC()).
|
||||||
|
Save(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("create refresh token: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListRefreshTokens extracts an array of refresh tokens from the database.
|
||||||
|
func (d *Database) ListRefreshTokens() ([]storage.RefreshToken, error) {
|
||||||
|
refreshTokens, err := d.client.RefreshToken.Query().All(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return nil, convertDBError("list refresh tokens: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
storageRefreshTokens := make([]storage.RefreshToken, 0, len(refreshTokens))
|
||||||
|
for _, r := range refreshTokens {
|
||||||
|
storageRefreshTokens = append(storageRefreshTokens, toStorageRefreshToken(r))
|
||||||
|
}
|
||||||
|
return storageRefreshTokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRefresh extracts a refresh token from the database by id.
|
||||||
|
func (d *Database) GetRefresh(id string) (storage.RefreshToken, error) {
|
||||||
|
refreshToken, err := d.client.RefreshToken.Get(context.TODO(), id)
|
||||||
|
if err != nil {
|
||||||
|
return storage.RefreshToken{}, convertDBError("get refresh token: %w", err)
|
||||||
|
}
|
||||||
|
return toStorageRefreshToken(refreshToken), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRefresh deletes a refresh token from the database by id.
|
||||||
|
func (d *Database) DeleteRefresh(id string) error {
|
||||||
|
err := d.client.RefreshToken.DeleteOneID(id).Exec(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("delete refresh token: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRefreshToken changes a refresh token by id using an updater function and saves it to the database.
|
||||||
|
func (d *Database) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
|
||||||
|
tx, err := d.client.Tx(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return convertDBError("update refresh token tx: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := tx.RefreshToken.Get(context.TODO(), id)
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update refresh token database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newtToken, err := updater(toStorageRefreshToken(token))
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update refresh token updating: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.RefreshToken.UpdateOneID(newtToken.ID).
|
||||||
|
SetClientID(newtToken.ClientID).
|
||||||
|
SetScopes(newtToken.Scopes).
|
||||||
|
SetNonce(newtToken.Nonce).
|
||||||
|
SetClaimsUserID(newtToken.Claims.UserID).
|
||||||
|
SetClaimsEmail(newtToken.Claims.Email).
|
||||||
|
SetClaimsEmailVerified(newtToken.Claims.EmailVerified).
|
||||||
|
SetClaimsUsername(newtToken.Claims.Username).
|
||||||
|
SetClaimsPreferredUsername(newtToken.Claims.PreferredUsername).
|
||||||
|
SetClaimsGroups(newtToken.Claims.Groups).
|
||||||
|
SetConnectorID(newtToken.ConnectorID).
|
||||||
|
SetConnectorData(newtToken.ConnectorData).
|
||||||
|
SetToken(newtToken.Token).
|
||||||
|
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||||
|
SetLastUsed(newtToken.LastUsed.UTC()).
|
||||||
|
SetCreatedAt(newtToken.CreatedAt.UTC()).
|
||||||
|
Save(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return rollback(tx, "update refresh token uploading: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.Commit(); err != nil {
|
||||||
|
return rollback(tx, "update refresh token commit: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
167
storage/ent/client/types.go
Normal file
167
storage/ent/client/types.go
Normal file
|
@ -0,0 +1,167 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
"github.com/dexidp/dex/storage/ent/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
const keysRowID = "keys"
|
||||||
|
|
||||||
|
func toStorageKeys(keys *db.Keys) storage.Keys {
|
||||||
|
return storage.Keys{
|
||||||
|
SigningKey: &keys.SigningKey,
|
||||||
|
SigningKeyPub: &keys.SigningKeyPub,
|
||||||
|
VerificationKeys: keys.VerificationKeys,
|
||||||
|
NextRotation: keys.NextRotation,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toStorageAuthRequest(a *db.AuthRequest) storage.AuthRequest {
|
||||||
|
return storage.AuthRequest{
|
||||||
|
ID: a.ID,
|
||||||
|
ClientID: a.ClientID,
|
||||||
|
ResponseTypes: a.ResponseTypes,
|
||||||
|
Scopes: a.Scopes,
|
||||||
|
RedirectURI: a.RedirectURI,
|
||||||
|
Nonce: a.Nonce,
|
||||||
|
State: a.State,
|
||||||
|
ForceApprovalPrompt: a.ForceApprovalPrompt,
|
||||||
|
LoggedIn: a.LoggedIn,
|
||||||
|
ConnectorID: a.ConnectorID,
|
||||||
|
ConnectorData: *a.ConnectorData,
|
||||||
|
Expiry: a.Expiry,
|
||||||
|
Claims: storage.Claims{
|
||||||
|
UserID: a.ClaimsUserID,
|
||||||
|
Username: a.ClaimsUsername,
|
||||||
|
PreferredUsername: a.ClaimsPreferredUsername,
|
||||||
|
Email: a.ClaimsEmail,
|
||||||
|
EmailVerified: a.ClaimsEmailVerified,
|
||||||
|
Groups: a.ClaimsGroups,
|
||||||
|
},
|
||||||
|
PKCE: storage.PKCE{
|
||||||
|
CodeChallenge: a.CodeChallenge,
|
||||||
|
CodeChallengeMethod: a.CodeChallengeMethod,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toStorageAuthCode(a *db.AuthCode) storage.AuthCode {
|
||||||
|
return storage.AuthCode{
|
||||||
|
ID: a.ID,
|
||||||
|
ClientID: a.ClientID,
|
||||||
|
Scopes: a.Scopes,
|
||||||
|
RedirectURI: a.RedirectURI,
|
||||||
|
Nonce: a.Nonce,
|
||||||
|
ConnectorID: a.ConnectorID,
|
||||||
|
ConnectorData: *a.ConnectorData,
|
||||||
|
Expiry: a.Expiry,
|
||||||
|
Claims: storage.Claims{
|
||||||
|
UserID: a.ClaimsUserID,
|
||||||
|
Username: a.ClaimsUsername,
|
||||||
|
PreferredUsername: a.ClaimsPreferredUsername,
|
||||||
|
Email: a.ClaimsEmail,
|
||||||
|
EmailVerified: a.ClaimsEmailVerified,
|
||||||
|
Groups: a.ClaimsGroups,
|
||||||
|
},
|
||||||
|
PKCE: storage.PKCE{
|
||||||
|
CodeChallenge: a.CodeChallenge,
|
||||||
|
CodeChallengeMethod: a.CodeChallengeMethod,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toStorageClient(c *db.OAuth2Client) storage.Client {
|
||||||
|
return storage.Client{
|
||||||
|
ID: c.ID,
|
||||||
|
Secret: c.Secret,
|
||||||
|
RedirectURIs: c.RedirectUris,
|
||||||
|
TrustedPeers: c.TrustedPeers,
|
||||||
|
Public: c.Public,
|
||||||
|
Name: c.Name,
|
||||||
|
LogoURL: c.LogoURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toStorageConnector(c *db.Connector) storage.Connector {
|
||||||
|
return storage.Connector{
|
||||||
|
ID: c.ID,
|
||||||
|
Type: c.Type,
|
||||||
|
Name: c.Name,
|
||||||
|
Config: c.Config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toStorageOfflineSession(o *db.OfflineSession) storage.OfflineSessions {
|
||||||
|
s := storage.OfflineSessions{
|
||||||
|
UserID: o.UserID,
|
||||||
|
ConnID: o.ConnID,
|
||||||
|
ConnectorData: *o.ConnectorData,
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.Refresh != nil {
|
||||||
|
if err := json.Unmarshal(o.Refresh, &s.Refresh); err != nil {
|
||||||
|
// Correctness of json structure if guaranteed on uploading
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Server code assumes this will be non-nil.
|
||||||
|
s.Refresh = make(map[string]*storage.RefreshTokenRef)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func toStorageRefreshToken(r *db.RefreshToken) storage.RefreshToken {
|
||||||
|
return storage.RefreshToken{
|
||||||
|
ID: r.ID,
|
||||||
|
Token: r.Token,
|
||||||
|
CreatedAt: r.CreatedAt,
|
||||||
|
LastUsed: r.LastUsed,
|
||||||
|
ClientID: r.ClientID,
|
||||||
|
ConnectorID: r.ConnectorID,
|
||||||
|
ConnectorData: *r.ConnectorData,
|
||||||
|
Scopes: r.Scopes,
|
||||||
|
Nonce: r.Nonce,
|
||||||
|
Claims: storage.Claims{
|
||||||
|
UserID: r.ClaimsUserID,
|
||||||
|
Username: r.ClaimsUsername,
|
||||||
|
PreferredUsername: r.ClaimsPreferredUsername,
|
||||||
|
Email: r.ClaimsEmail,
|
||||||
|
EmailVerified: r.ClaimsEmailVerified,
|
||||||
|
Groups: r.ClaimsGroups,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toStoragePassword(p *db.Password) storage.Password {
|
||||||
|
return storage.Password{
|
||||||
|
Email: p.Email,
|
||||||
|
Hash: p.Hash,
|
||||||
|
Username: p.Username,
|
||||||
|
UserID: p.UserID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toStorageDeviceRequest(r *db.DeviceRequest) storage.DeviceRequest {
|
||||||
|
return storage.DeviceRequest{
|
||||||
|
UserCode: strings.ToUpper(r.UserCode),
|
||||||
|
DeviceCode: r.DeviceCode,
|
||||||
|
ClientID: r.ClientID,
|
||||||
|
ClientSecret: r.ClientSecret,
|
||||||
|
Scopes: r.Scopes,
|
||||||
|
Expiry: r.Expiry,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toStorageDeviceToken(t *db.DeviceToken) storage.DeviceToken {
|
||||||
|
return storage.DeviceToken{
|
||||||
|
DeviceCode: t.DeviceCode,
|
||||||
|
Status: t.Status,
|
||||||
|
Token: string(*t.Token),
|
||||||
|
Expiry: t.Expiry,
|
||||||
|
LastRequestTime: t.LastRequest,
|
||||||
|
PollIntervalSeconds: t.PollInterval,
|
||||||
|
}
|
||||||
|
}
|
44
storage/ent/client/utils.go
Normal file
44
storage/ent/client/utils.go
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"hash"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
"github.com/dexidp/dex/storage/ent/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
func rollback(tx *db.Tx, t string, err error) error {
|
||||||
|
rerr := tx.Rollback()
|
||||||
|
err = convertDBError(t, err)
|
||||||
|
|
||||||
|
if rerr == nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return errors.Wrapf(err, "rolling back transaction: %v", rerr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertDBError(t string, err error) error {
|
||||||
|
if db.IsNotFound(err) {
|
||||||
|
return storage.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.IsConstraintError(err) {
|
||||||
|
return storage.ErrAlreadyExists
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// compose hashed id from user and connection id to use it as primary key
|
||||||
|
// ent doesn't support multi-key primary yet
|
||||||
|
// https://github.com/facebook/ent/issues/400
|
||||||
|
func offlineSessionID(userID string, connID string, hasher func() hash.Hash) string {
|
||||||
|
h := hasher()
|
||||||
|
|
||||||
|
h.Write([]byte(userID))
|
||||||
|
h.Write([]byte(connID))
|
||||||
|
return fmt.Sprintf("%x", h.Sum(nil))
|
||||||
|
}
|
3
storage/ent/generate.go
Normal file
3
storage/ent/generate.go
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
package ent
|
||||||
|
|
||||||
|
//go:generate go run github.com/facebook/ent/cmd/entc generate ./schema --target ./db
|
89
storage/ent/schema/authcode.go
Normal file
89
storage/ent/schema/authcode.go
Normal file
|
@ -0,0 +1,89 @@
|
||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/facebook/ent"
|
||||||
|
"github.com/facebook/ent/schema/field"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Original SQL table:
|
||||||
|
create table auth_code
|
||||||
|
(
|
||||||
|
id text not null primary key,
|
||||||
|
client_id text not null,
|
||||||
|
scopes blob not null,
|
||||||
|
nonce text not null,
|
||||||
|
redirect_uri text not null,
|
||||||
|
claims_user_id text not null,
|
||||||
|
claims_username text not null,
|
||||||
|
claims_email text not null,
|
||||||
|
claims_email_verified integer not null,
|
||||||
|
claims_groups blob not null,
|
||||||
|
connector_id text not null,
|
||||||
|
connector_data blob,
|
||||||
|
expiry timestamp not null,
|
||||||
|
claims_preferred_username text default '' not null,
|
||||||
|
code_challenge text default '' not null,
|
||||||
|
code_challenge_method text default '' not null
|
||||||
|
);
|
||||||
|
*/
|
||||||
|
|
||||||
|
// AuthCode holds the schema definition for the AuthCode entity.
|
||||||
|
type AuthCode struct {
|
||||||
|
ent.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fields of the AuthCode.
|
||||||
|
func (AuthCode) Fields() []ent.Field {
|
||||||
|
return []ent.Field{
|
||||||
|
field.Text("id").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty().
|
||||||
|
Unique(),
|
||||||
|
field.Text("client_id").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.JSON("scopes", []string{}).
|
||||||
|
Optional(),
|
||||||
|
field.Text("nonce").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.Text("redirect_uri").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
|
||||||
|
field.Text("claims_user_id").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.Text("claims_username").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.Text("claims_email").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.Bool("claims_email_verified"),
|
||||||
|
field.JSON("claims_groups", []string{}).
|
||||||
|
Optional(),
|
||||||
|
field.Text("claims_preferred_username").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
Default(""),
|
||||||
|
|
||||||
|
field.Text("connector_id").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.Bytes("connector_data").
|
||||||
|
Nillable().
|
||||||
|
Optional(),
|
||||||
|
field.Time("expiry"),
|
||||||
|
field.Text("code_challenge").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
Default(""),
|
||||||
|
field.Text("code_challenge_method").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
Default(""),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Edges of the AuthCode.
|
||||||
|
func (AuthCode) Edges() []ent.Edge {
|
||||||
|
return []ent.Edge{}
|
||||||
|
}
|
94
storage/ent/schema/authrequest.go
Normal file
94
storage/ent/schema/authrequest.go
Normal file
|
@ -0,0 +1,94 @@
|
||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/facebook/ent"
|
||||||
|
"github.com/facebook/ent/schema/field"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Original SQL table:
|
||||||
|
create table auth_request
|
||||||
|
(
|
||||||
|
id text not null primary key,
|
||||||
|
client_id text not null,
|
||||||
|
response_types blob not null,
|
||||||
|
scopes blob not null,
|
||||||
|
redirect_uri text not null,
|
||||||
|
nonce text not null,
|
||||||
|
state text not null,
|
||||||
|
force_approval_prompt integer not null,
|
||||||
|
logged_in integer not null,
|
||||||
|
claims_user_id text not null,
|
||||||
|
claims_username text not null,
|
||||||
|
claims_email text not null,
|
||||||
|
claims_email_verified integer not null,
|
||||||
|
claims_groups blob not null,
|
||||||
|
connector_id text not null,
|
||||||
|
connector_data blob,
|
||||||
|
expiry timestamp not null,
|
||||||
|
claims_preferred_username text default '' not null,
|
||||||
|
code_challenge text default '' not null,
|
||||||
|
code_challenge_method text default '' not null
|
||||||
|
);
|
||||||
|
*/
|
||||||
|
|
||||||
|
// AuthRequest holds the schema definition for the AuthRequest entity.
|
||||||
|
type AuthRequest struct {
|
||||||
|
ent.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fields of the AuthRequest.
|
||||||
|
func (AuthRequest) Fields() []ent.Field {
|
||||||
|
return []ent.Field{
|
||||||
|
field.Text("id").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty().
|
||||||
|
Unique(),
|
||||||
|
field.Text("client_id").
|
||||||
|
SchemaType(textSchema),
|
||||||
|
field.JSON("scopes", []string{}).
|
||||||
|
Optional(),
|
||||||
|
field.JSON("response_types", []string{}).
|
||||||
|
Optional(),
|
||||||
|
field.Text("redirect_uri").
|
||||||
|
SchemaType(textSchema),
|
||||||
|
field.Text("nonce").
|
||||||
|
SchemaType(textSchema),
|
||||||
|
field.Text("state").
|
||||||
|
SchemaType(textSchema),
|
||||||
|
|
||||||
|
field.Bool("force_approval_prompt"),
|
||||||
|
field.Bool("logged_in"),
|
||||||
|
|
||||||
|
field.Text("claims_user_id").
|
||||||
|
SchemaType(textSchema),
|
||||||
|
field.Text("claims_username").
|
||||||
|
SchemaType(textSchema),
|
||||||
|
field.Text("claims_email").
|
||||||
|
SchemaType(textSchema),
|
||||||
|
field.Bool("claims_email_verified"),
|
||||||
|
field.JSON("claims_groups", []string{}).
|
||||||
|
Optional(),
|
||||||
|
field.Text("claims_preferred_username").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
Default(""),
|
||||||
|
|
||||||
|
field.Text("connector_id").
|
||||||
|
SchemaType(textSchema),
|
||||||
|
field.Bytes("connector_data").
|
||||||
|
Nillable().
|
||||||
|
Optional(),
|
||||||
|
field.Time("expiry"),
|
||||||
|
|
||||||
|
field.Text("code_challenge").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
Default(""),
|
||||||
|
field.Text("code_challenge_method").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
Default(""),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Edges of the AuthRequest.
|
||||||
|
func (AuthRequest) Edges() []ent.Edge {
|
||||||
|
return []ent.Edge{}
|
||||||
|
}
|
53
storage/ent/schema/client.go
Normal file
53
storage/ent/schema/client.go
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/facebook/ent"
|
||||||
|
"github.com/facebook/ent/schema/field"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Original SQL table:
|
||||||
|
create table client
|
||||||
|
(
|
||||||
|
id text not null primary key,
|
||||||
|
secret text not null,
|
||||||
|
redirect_uris blob not null,
|
||||||
|
trusted_peers blob not null,
|
||||||
|
public integer not null,
|
||||||
|
name text not null,
|
||||||
|
logo_url text not null
|
||||||
|
);
|
||||||
|
*/
|
||||||
|
|
||||||
|
// OAuth2Client holds the schema definition for the Client entity.
|
||||||
|
type OAuth2Client struct {
|
||||||
|
ent.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fields of the OAuth2Client.
|
||||||
|
func (OAuth2Client) Fields() []ent.Field {
|
||||||
|
return []ent.Field{
|
||||||
|
field.Text("id").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty().
|
||||||
|
Unique(),
|
||||||
|
field.Text("secret").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.JSON("redirect_uris", []string{}).
|
||||||
|
Optional(),
|
||||||
|
field.JSON("trusted_peers", []string{}).
|
||||||
|
Optional(),
|
||||||
|
field.Bool("public"),
|
||||||
|
field.Text("name").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.Text("logo_url").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Edges of the OAuth2Client.
|
||||||
|
func (OAuth2Client) Edges() []ent.Edge {
|
||||||
|
return []ent.Edge{}
|
||||||
|
}
|
46
storage/ent/schema/connector.go
Normal file
46
storage/ent/schema/connector.go
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/facebook/ent"
|
||||||
|
"github.com/facebook/ent/schema/field"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Original SQL table:
|
||||||
|
create table connector
|
||||||
|
(
|
||||||
|
id text not null primary key,
|
||||||
|
type text not null,
|
||||||
|
name text not null,
|
||||||
|
resource_version text not null,
|
||||||
|
config blob
|
||||||
|
);
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Connector holds the schema definition for the Client entity.
|
||||||
|
type Connector struct {
|
||||||
|
ent.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fields of the Connector.
|
||||||
|
func (Connector) Fields() []ent.Field {
|
||||||
|
return []ent.Field{
|
||||||
|
field.Text("id").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty().
|
||||||
|
Unique(),
|
||||||
|
field.Text("type").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.Text("name").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.Text("resource_version").
|
||||||
|
SchemaType(textSchema),
|
||||||
|
field.Bytes("config"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Edges of the Connector.
|
||||||
|
func (Connector) Edges() []ent.Edge {
|
||||||
|
return []ent.Edge{}
|
||||||
|
}
|
50
storage/ent/schema/devicerequest.go
Normal file
50
storage/ent/schema/devicerequest.go
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/facebook/ent"
|
||||||
|
"github.com/facebook/ent/schema/field"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Original SQL table:
|
||||||
|
create table device_request
|
||||||
|
(
|
||||||
|
user_code text not null primary key,
|
||||||
|
device_code text not null,
|
||||||
|
client_id text not null,
|
||||||
|
client_secret text,
|
||||||
|
scopes blob not null,
|
||||||
|
expiry timestamp not null
|
||||||
|
);
|
||||||
|
*/
|
||||||
|
|
||||||
|
// DeviceRequest holds the schema definition for the DeviceRequest entity.
|
||||||
|
type DeviceRequest struct {
|
||||||
|
ent.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fields of the DeviceRequest.
|
||||||
|
func (DeviceRequest) Fields() []ent.Field {
|
||||||
|
return []ent.Field{
|
||||||
|
field.Text("user_code").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty().
|
||||||
|
Unique(),
|
||||||
|
field.Text("device_code").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.Text("client_id").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.Text("client_secret").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.JSON("scopes", []string{}).
|
||||||
|
Optional(),
|
||||||
|
field.Time("expiry"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Edges of the DeviceRequest.
|
||||||
|
func (DeviceRequest) Edges() []ent.Edge {
|
||||||
|
return []ent.Edge{}
|
||||||
|
}
|
45
storage/ent/schema/devicetoken.go
Normal file
45
storage/ent/schema/devicetoken.go
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/facebook/ent"
|
||||||
|
"github.com/facebook/ent/schema/field"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Original SQL table:
|
||||||
|
create table device_token
|
||||||
|
(
|
||||||
|
device_code text not null primary key,
|
||||||
|
status text not null,
|
||||||
|
token blob,
|
||||||
|
expiry timestamp not null,
|
||||||
|
last_request timestamp not null,
|
||||||
|
poll_interval integer not null
|
||||||
|
);
|
||||||
|
*/
|
||||||
|
|
||||||
|
// DeviceToken holds the schema definition for the DeviceToken entity.
|
||||||
|
type DeviceToken struct {
|
||||||
|
ent.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fields of the DeviceToken.
|
||||||
|
func (DeviceToken) Fields() []ent.Field {
|
||||||
|
return []ent.Field{
|
||||||
|
field.Text("device_code").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty().
|
||||||
|
Unique(),
|
||||||
|
field.Text("status").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.Bytes("token").Nillable().Optional(),
|
||||||
|
field.Time("expiry"),
|
||||||
|
field.Time("last_request"),
|
||||||
|
field.Int("poll_interval"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Edges of the DeviceToken.
|
||||||
|
func (DeviceToken) Edges() []ent.Edge {
|
||||||
|
return []ent.Edge{}
|
||||||
|
}
|
44
storage/ent/schema/keys.go
Normal file
44
storage/ent/schema/keys.go
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/facebook/ent"
|
||||||
|
"github.com/facebook/ent/schema/field"
|
||||||
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Original SQL table:
|
||||||
|
create table keys
|
||||||
|
(
|
||||||
|
id text not null primary key,
|
||||||
|
verification_keys blob not null,
|
||||||
|
signing_key blob not null,
|
||||||
|
signing_key_pub blob not null,
|
||||||
|
next_rotation timestamp not null
|
||||||
|
);
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Keys holds the schema definition for the Keys entity.
|
||||||
|
type Keys struct {
|
||||||
|
ent.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fields of the Keys.
|
||||||
|
func (Keys) Fields() []ent.Field {
|
||||||
|
return []ent.Field{
|
||||||
|
field.Text("id").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty().
|
||||||
|
Unique(),
|
||||||
|
field.JSON("verification_keys", []storage.VerificationKey{}),
|
||||||
|
field.JSON("signing_key", jose.JSONWebKey{}),
|
||||||
|
field.JSON("signing_key_pub", jose.JSONWebKey{}),
|
||||||
|
field.Time("next_rotation"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Edges of the Keys.
|
||||||
|
func (Keys) Edges() []ent.Edge {
|
||||||
|
return []ent.Edge{}
|
||||||
|
}
|
46
storage/ent/schema/offlinesession.go
Normal file
46
storage/ent/schema/offlinesession.go
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/facebook/ent"
|
||||||
|
"github.com/facebook/ent/schema/field"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Original SQL table:
|
||||||
|
create table offline_session
|
||||||
|
(
|
||||||
|
user_id text not null,
|
||||||
|
conn_id text not null,
|
||||||
|
refresh blob not null,
|
||||||
|
connector_data blob,
|
||||||
|
primary key (user_id, conn_id)
|
||||||
|
);
|
||||||
|
*/
|
||||||
|
|
||||||
|
// OfflineSession holds the schema definition for the OfflineSession entity.
|
||||||
|
type OfflineSession struct {
|
||||||
|
ent.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fields of the OfflineSession.
|
||||||
|
func (OfflineSession) Fields() []ent.Field {
|
||||||
|
return []ent.Field{
|
||||||
|
// Using id field here because it's impossible to create multi-key primary yet
|
||||||
|
field.Text("id").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty().
|
||||||
|
Unique(),
|
||||||
|
field.Text("user_id").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.Text("conn_id").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.Bytes("refresh"),
|
||||||
|
field.Bytes("connector_data").Nillable().Optional(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Edges of the OfflineSession.
|
||||||
|
func (OfflineSession) Edges() []ent.Edge {
|
||||||
|
return []ent.Edge{}
|
||||||
|
}
|
44
storage/ent/schema/password.go
Normal file
44
storage/ent/schema/password.go
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/facebook/ent"
|
||||||
|
"github.com/facebook/ent/schema/field"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Original SQL table:
|
||||||
|
create table password
|
||||||
|
(
|
||||||
|
email text not null primary key,
|
||||||
|
hash blob not null,
|
||||||
|
username text not null,
|
||||||
|
user_id text not null
|
||||||
|
);
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Password holds the schema definition for the Password entity.
|
||||||
|
type Password struct {
|
||||||
|
ent.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fields of the Password.
|
||||||
|
func (Password) Fields() []ent.Field {
|
||||||
|
return []ent.Field{
|
||||||
|
field.Text("email").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
StorageKey("email"). // use email as ID field to make querying easier
|
||||||
|
NotEmpty().
|
||||||
|
Unique(),
|
||||||
|
field.Bytes("hash"),
|
||||||
|
field.Text("username").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.Text("user_id").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Edges of the Password.
|
||||||
|
func (Password) Edges() []ent.Edge {
|
||||||
|
return []ent.Edge{}
|
||||||
|
}
|
89
storage/ent/schema/refreshtoken.go
Normal file
89
storage/ent/schema/refreshtoken.go
Normal file
|
@ -0,0 +1,89 @@
|
||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/facebook/ent"
|
||||||
|
"github.com/facebook/ent/schema/field"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Original SQL table:
|
||||||
|
create table refresh_token
|
||||||
|
(
|
||||||
|
id text not null primary key,
|
||||||
|
client_id text not null,
|
||||||
|
scopes blob not null,
|
||||||
|
nonce text not null,
|
||||||
|
claims_user_id text not null,
|
||||||
|
claims_username text not null,
|
||||||
|
claims_email text not null,
|
||||||
|
claims_email_verified integer not null,
|
||||||
|
claims_groups blob not null,
|
||||||
|
connector_id text not null,
|
||||||
|
connector_data blob,
|
||||||
|
token text default '' not null,
|
||||||
|
created_at timestamp default '0001-01-01 00:00:00 UTC' not null,
|
||||||
|
last_used timestamp default '0001-01-01 00:00:00 UTC' not null,
|
||||||
|
claims_preferred_username text default '' not null
|
||||||
|
);
|
||||||
|
*/
|
||||||
|
|
||||||
|
// RefreshToken holds the schema definition for the RefreshToken entity.
|
||||||
|
type RefreshToken struct {
|
||||||
|
ent.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fields of the RefreshToken.
|
||||||
|
func (RefreshToken) Fields() []ent.Field {
|
||||||
|
return []ent.Field{
|
||||||
|
field.Text("id").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty().
|
||||||
|
Unique(),
|
||||||
|
field.Text("client_id").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.JSON("scopes", []string{}).
|
||||||
|
Optional(),
|
||||||
|
field.Text("nonce").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
|
||||||
|
field.Text("claims_user_id").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.Text("claims_username").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.Text("claims_email").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.Bool("claims_email_verified"),
|
||||||
|
field.JSON("claims_groups", []string{}).
|
||||||
|
Optional(),
|
||||||
|
field.Text("claims_preferred_username").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
Default(""),
|
||||||
|
|
||||||
|
field.Text("connector_id").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
NotEmpty(),
|
||||||
|
field.Bytes("connector_data").
|
||||||
|
Nillable().
|
||||||
|
Optional(),
|
||||||
|
|
||||||
|
field.Text("token").
|
||||||
|
SchemaType(textSchema).
|
||||||
|
Default(""),
|
||||||
|
|
||||||
|
field.Time("created_at").
|
||||||
|
Default(time.Now),
|
||||||
|
field.Time("last_used").
|
||||||
|
Default(time.Now),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Edges of the RefreshToken.
|
||||||
|
func (RefreshToken) Edges() []ent.Edge {
|
||||||
|
return []ent.Edge{}
|
||||||
|
}
|
9
storage/ent/schema/types.go
Normal file
9
storage/ent/schema/types.go
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/facebook/ent/dialect"
|
||||||
|
)
|
||||||
|
|
||||||
|
var textSchema = map[string]string{
|
||||||
|
dialect.SQLite: "text",
|
||||||
|
}
|
65
storage/ent/sqlite.go
Normal file
65
storage/ent/sqlite.go
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
package ent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/facebook/ent/dialect/sql"
|
||||||
|
|
||||||
|
// Register sqlite driver.
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/pkg/log"
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
"github.com/dexidp/dex/storage/ent/client"
|
||||||
|
"github.com/dexidp/dex/storage/ent/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SQLite3 options for creating an SQL db.
|
||||||
|
type SQLite3 struct {
|
||||||
|
File string `json:"file"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open always returns a new in sqlite3 storage.
|
||||||
|
func (s *SQLite3) Open(logger log.Logger) (storage.Storage, error) {
|
||||||
|
logger.Debug("experimental ent-based storage driver is enabled")
|
||||||
|
|
||||||
|
// Implicitly set foreign_keys pragma to "on" because it is required by ent
|
||||||
|
s.File = addFK(s.File)
|
||||||
|
|
||||||
|
drv, err := sql.Open("sqlite3", s.File)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pool := drv.DB()
|
||||||
|
if s.File == ":memory:" {
|
||||||
|
// sqlite3 uses file locks to coordinate concurrent access. In memory
|
||||||
|
// doesn't support this, so limit the number of connections to 1.
|
||||||
|
pool.SetMaxOpenConns(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
databaseClient := client.NewDatabase(
|
||||||
|
client.WithClient(db.NewClient(db.Driver(drv))),
|
||||||
|
client.WithHasher(sha256.New),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := databaseClient.Schema().Create(context.TODO()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return databaseClient, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func addFK(dsn string) string {
|
||||||
|
if strings.Contains(dsn, "_fk") {
|
||||||
|
return dsn
|
||||||
|
}
|
||||||
|
|
||||||
|
delim := "?"
|
||||||
|
if strings.Contains(dsn, "?") {
|
||||||
|
delim = "&"
|
||||||
|
}
|
||||||
|
return dsn + delim + "_fk=1"
|
||||||
|
}
|
31
storage/ent/sqlite_test.go
Normal file
31
storage/ent/sqlite_test.go
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
package ent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
"github.com/dexidp/dex/storage/conformance"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newStorage() storage.Storage {
|
||||||
|
logger := &logrus.Logger{
|
||||||
|
Out: os.Stderr,
|
||||||
|
Formatter: &logrus.TextFormatter{DisableColors: true},
|
||||||
|
Level: logrus.DebugLevel,
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := SQLite3{File: ":memory:"}
|
||||||
|
s, err := cfg.Open(logger)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLite3(t *testing.T) {
|
||||||
|
conformance.RunTests(t, newStorage)
|
||||||
|
conformance.RunTransactionTests(t, newStorage)
|
||||||
|
}
|
Loading…
Reference in a new issue