From 11859166d0d50782563f29b2453560b550cab79c Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Thu, 31 Dec 2020 02:07:32 +0400 Subject: [PATCH] feat: Add ent-based sqlite3 storage Signed-off-by: m.nabokikh --- Makefile | 7 +- cmd/dex/config.go | 22 +++- go.mod | 3 +- go.sum | 2 + storage/ent/client/authcode.go | 52 +++++++++ storage/ent/client/authrequest.go | 107 +++++++++++++++++ storage/ent/client/client.go | 92 +++++++++++++++ storage/ent/client/connector.go | 88 ++++++++++++++ storage/ent/client/devicerequest.go | 36 ++++++ storage/ent/client/devicetoken.go | 76 ++++++++++++ storage/ent/client/keys.go | 81 +++++++++++++ storage/ent/client/main.go | 95 +++++++++++++++ storage/ent/client/offlinesession.go | 93 +++++++++++++++ storage/ent/client/password.go | 100 ++++++++++++++++ storage/ent/client/refreshtoken.go | 109 +++++++++++++++++ storage/ent/client/types.go | 167 +++++++++++++++++++++++++++ storage/ent/client/utils.go | 44 +++++++ storage/ent/generate.go | 3 + storage/ent/schema/authcode.go | 89 ++++++++++++++ storage/ent/schema/authrequest.go | 94 +++++++++++++++ storage/ent/schema/client.go | 53 +++++++++ storage/ent/schema/connector.go | 46 ++++++++ storage/ent/schema/devicerequest.go | 50 ++++++++ storage/ent/schema/devicetoken.go | 45 ++++++++ storage/ent/schema/keys.go | 44 +++++++ storage/ent/schema/offlinesession.go | 46 ++++++++ storage/ent/schema/password.go | 44 +++++++ storage/ent/schema/refreshtoken.go | 89 ++++++++++++++ storage/ent/schema/types.go | 9 ++ storage/ent/sqlite.go | 65 +++++++++++ storage/ent/sqlite_test.go | 31 +++++ 31 files changed, 1878 insertions(+), 4 deletions(-) create mode 100644 storage/ent/client/authcode.go create mode 100644 storage/ent/client/authrequest.go create mode 100644 storage/ent/client/client.go create mode 100644 storage/ent/client/connector.go create mode 100644 storage/ent/client/devicerequest.go create mode 100644 storage/ent/client/devicetoken.go create mode 100644 storage/ent/client/keys.go create mode 100644 storage/ent/client/main.go create mode 100644 storage/ent/client/offlinesession.go create mode 100644 storage/ent/client/password.go create mode 100644 storage/ent/client/refreshtoken.go create mode 100644 storage/ent/client/types.go create mode 100644 storage/ent/client/utils.go create mode 100644 storage/ent/generate.go create mode 100644 storage/ent/schema/authcode.go create mode 100644 storage/ent/schema/authrequest.go create mode 100644 storage/ent/schema/client.go create mode 100644 storage/ent/schema/connector.go create mode 100644 storage/ent/schema/devicerequest.go create mode 100644 storage/ent/schema/devicetoken.go create mode 100644 storage/ent/schema/keys.go create mode 100644 storage/ent/schema/offlinesession.go create mode 100644 storage/ent/schema/password.go create mode 100644 storage/ent/schema/refreshtoken.go create mode 100644 storage/ent/schema/types.go create mode 100644 storage/ent/sqlite.go create mode 100644 storage/ent/sqlite_test.go diff --git a/Makefile b/Makefile index 25de27c2..391311de 100644 --- a/Makefile +++ b/Makefile @@ -26,7 +26,10 @@ PROTOC_VERSION = 3.15.6 PROTOC_GEN_GO_VERSION = 1.26.0 PROTOC_GEN_GO_GRPC_VERSION = 1.1.0 -build: bin/dex +generate: + @go generate $(REPO_PATH)/storage/ent/ + +build: generate bin/dex bin/dex: @mkdir -p bin/ @@ -42,7 +45,7 @@ bin/example-app: @mkdir -p bin/ @cd examples/ && go install -v -ldflags $(LD_FLAGS) $(REPO_PATH)/examples/example-app -.PHONY: release-binary +.PHONY: generate release-binary release-binary: @go build -o /go/bin/dex -v -ldflags $(LD_FLAGS) $(REPO_PATH)/cmd/dex diff --git a/cmd/dex/config.go b/cmd/dex/config.go index f218879d..bec6f620 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -13,6 +13,7 @@ import ( "github.com/dexidp/dex/pkg/log" "github.com/dexidp/dex/server" "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/ent" "github.com/dexidp/dex/storage/etcd" "github.com/dexidp/dex/storage/kubernetes" "github.com/dexidp/dex/storage/memory" @@ -173,13 +174,32 @@ type StorageConfig interface { Open(logger log.Logger) (storage.Storage, error) } +var ( + _ StorageConfig = (*etcd.Etcd)(nil) + _ StorageConfig = (*kubernetes.Config)(nil) + _ StorageConfig = (*memory.Config)(nil) + _ StorageConfig = (*sql.SQLite3)(nil) + _ StorageConfig = (*sql.Postgres)(nil) + _ StorageConfig = (*sql.MySQL)(nil) + _ StorageConfig = (*ent.SQLite3)(nil) +) + +func getORMBasedSQLiteStorage() StorageConfig { + switch os.Getenv("DEX_ENT_ENABLED") { + case "true", "yes": + return new(ent.SQLite3) + default: + return new(sql.SQLite3) + } +} + var storages = map[string]func() StorageConfig{ "etcd": func() StorageConfig { return new(etcd.Etcd) }, "kubernetes": func() StorageConfig { return new(kubernetes.Config) }, "memory": func() StorageConfig { return new(memory.Config) }, - "sqlite3": func() StorageConfig { return new(sql.SQLite3) }, "postgres": func() StorageConfig { return new(sql.Postgres) }, "mysql": func() StorageConfig { return new(sql.MySQL) }, + "sqlite3": getORMBasedSQLiteStorage, } // isExpandEnvEnabled returns if os.ExpandEnv should be used for each storage and connector config. diff --git a/go.mod b/go.mod index f850be78..c16152a1 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,8 @@ require ( github.com/beevik/etree v1.1.0 github.com/coreos/go-oidc/v3 v3.0.0 github.com/dexidp/dex/api/v2 v2.0.0 - github.com/felixge/httpsnoop v1.0.2 + github.com/facebook/ent v0.5.3 + github.com/felixge/httpsnoop v1.0.1 github.com/ghodss/yaml v1.0.0 github.com/go-ldap/ldap/v3 v3.3.0 github.com/go-sql-driver/mysql v1.6.0 diff --git a/go.sum b/go.sum index 0d885ff5..baf9c97d 100644 --- a/go.sum +++ b/go.sum @@ -128,6 +128,8 @@ github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5y github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/facebook/ent v0.5.3 h1:YT3Sl28n7gGGOkQeYgeJsZmizJ1Iiy7psgkOtEk0aq4= +github.com/facebook/ent v0.5.3/go.mod h1:tlWP+qCd3x2EeO7B/EqlJQ4dWu/2IeYFhP/szzDKAi8= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/felixge/httpsnoop v1.0.2 h1:+nS9g82KMXccJ/wp0zyRW9ZBHFETmMGtkk+2CTTrW4o= diff --git a/storage/ent/client/authcode.go b/storage/ent/client/authcode.go new file mode 100644 index 00000000..b6b263bf --- /dev/null +++ b/storage/ent/client/authcode.go @@ -0,0 +1,52 @@ +package client + +import ( + "context" + + "github.com/dexidp/dex/storage" +) + +// CreateAuthCode saves provided auth code into the database. +func (d *Database) CreateAuthCode(code storage.AuthCode) error { + _, err := d.client.AuthCode.Create(). + SetID(code.ID). + SetClientID(code.ClientID). + SetScopes(code.Scopes). + SetRedirectURI(code.RedirectURI). + SetNonce(code.Nonce). + SetClaimsUserID(code.Claims.UserID). + SetClaimsEmail(code.Claims.Email). + SetClaimsEmailVerified(code.Claims.EmailVerified). + SetClaimsUsername(code.Claims.Username). + SetClaimsPreferredUsername(code.Claims.PreferredUsername). + SetClaimsGroups(code.Claims.Groups). + SetCodeChallenge(code.PKCE.CodeChallenge). + SetCodeChallengeMethod(code.PKCE.CodeChallengeMethod). + // Save utc time into database because ent doesn't support comparing dates with different timezones + SetExpiry(code.Expiry.UTC()). + SetConnectorID(code.ConnectorID). + SetConnectorData(code.ConnectorData). + Save(context.TODO()) + if err != nil { + return convertDBError("create auth code: %w", err) + } + return nil +} + +// GetAuthCode extracts an auth code from the database by id. +func (d *Database) GetAuthCode(id string) (storage.AuthCode, error) { + authCode, err := d.client.AuthCode.Get(context.TODO(), id) + if err != nil { + return storage.AuthCode{}, convertDBError("get auth code: %w", err) + } + return toStorageAuthCode(authCode), nil +} + +// DeleteAuthCode deletes an auth code from the database by id. +func (d *Database) DeleteAuthCode(id string) error { + err := d.client.AuthCode.DeleteOneID(id).Exec(context.TODO()) + if err != nil { + return convertDBError("delete auth code: %w", err) + } + return nil +} diff --git a/storage/ent/client/authrequest.go b/storage/ent/client/authrequest.go new file mode 100644 index 00000000..4cbb8b4e --- /dev/null +++ b/storage/ent/client/authrequest.go @@ -0,0 +1,107 @@ +package client + +import ( + "context" + "fmt" + + "github.com/dexidp/dex/storage" +) + +// CreateAuthRequest saves provided auth request into the database. +func (d *Database) CreateAuthRequest(authRequest storage.AuthRequest) error { + _, err := d.client.AuthRequest.Create(). + SetID(authRequest.ID). + SetClientID(authRequest.ClientID). + SetScopes(authRequest.Scopes). + SetResponseTypes(authRequest.ResponseTypes). + SetRedirectURI(authRequest.RedirectURI). + SetState(authRequest.State). + SetNonce(authRequest.Nonce). + SetForceApprovalPrompt(authRequest.ForceApprovalPrompt). + SetLoggedIn(authRequest.LoggedIn). + SetClaimsUserID(authRequest.Claims.UserID). + SetClaimsEmail(authRequest.Claims.Email). + SetClaimsEmailVerified(authRequest.Claims.EmailVerified). + SetClaimsUsername(authRequest.Claims.Username). + SetClaimsPreferredUsername(authRequest.Claims.PreferredUsername). + SetClaimsGroups(authRequest.Claims.Groups). + SetCodeChallenge(authRequest.PKCE.CodeChallenge). + SetCodeChallengeMethod(authRequest.PKCE.CodeChallengeMethod). + // Save utc time into database because ent doesn't support comparing dates with different timezones + SetExpiry(authRequest.Expiry.UTC()). + SetConnectorID(authRequest.ConnectorID). + SetConnectorData(authRequest.ConnectorData). + Save(context.TODO()) + if err != nil { + return convertDBError("create auth request: %w", err) + } + return nil +} + +// GetAuthRequest extracts an auth request from the database by id. +func (d *Database) GetAuthRequest(id string) (storage.AuthRequest, error) { + authRequest, err := d.client.AuthRequest.Get(context.TODO(), id) + if err != nil { + return storage.AuthRequest{}, convertDBError("get auth request: %w", err) + } + return toStorageAuthRequest(authRequest), nil +} + +// DeleteAuthRequest deletes an auth request from the database by id. +func (d *Database) DeleteAuthRequest(id string) error { + err := d.client.AuthRequest.DeleteOneID(id).Exec(context.TODO()) + if err != nil { + return convertDBError("delete auth request: %w", err) + } + return nil +} + +// UpdateAuthRequest changes an auth request by id using an updater function and saves it to the database. +func (d *Database) UpdateAuthRequest(id string, updater func(old storage.AuthRequest) (storage.AuthRequest, error)) error { + tx, err := d.client.Tx(context.TODO()) + if err != nil { + return fmt.Errorf("update auth request tx: %w", err) + } + + authRequest, err := tx.AuthRequest.Get(context.TODO(), id) + if err != nil { + return rollback(tx, "update auth request database: %w", err) + } + + newAuthRequest, err := updater(toStorageAuthRequest(authRequest)) + if err != nil { + return rollback(tx, "update auth request updating: %w", err) + } + + _, err = tx.AuthRequest.UpdateOneID(newAuthRequest.ID). + SetClientID(newAuthRequest.ClientID). + SetScopes(newAuthRequest.Scopes). + SetResponseTypes(newAuthRequest.ResponseTypes). + SetRedirectURI(newAuthRequest.RedirectURI). + SetState(newAuthRequest.State). + SetNonce(newAuthRequest.Nonce). + SetForceApprovalPrompt(newAuthRequest.ForceApprovalPrompt). + SetLoggedIn(newAuthRequest.LoggedIn). + SetClaimsUserID(newAuthRequest.Claims.UserID). + SetClaimsEmail(newAuthRequest.Claims.Email). + SetClaimsEmailVerified(newAuthRequest.Claims.EmailVerified). + SetClaimsUsername(newAuthRequest.Claims.Username). + SetClaimsPreferredUsername(newAuthRequest.Claims.PreferredUsername). + SetClaimsGroups(newAuthRequest.Claims.Groups). + SetCodeChallenge(newAuthRequest.PKCE.CodeChallenge). + SetCodeChallengeMethod(newAuthRequest.PKCE.CodeChallengeMethod). + // Save utc time into database because ent doesn't support comparing dates with different timezones + SetExpiry(newAuthRequest.Expiry.UTC()). + SetConnectorID(newAuthRequest.ConnectorID). + SetConnectorData(newAuthRequest.ConnectorData). + Save(context.TODO()) + if err != nil { + return rollback(tx, "update auth request uploading: %w", err) + } + + if err = tx.Commit(); err != nil { + return rollback(tx, "update auth request commit: %w", err) + } + + return nil +} diff --git a/storage/ent/client/client.go b/storage/ent/client/client.go new file mode 100644 index 00000000..577508d6 --- /dev/null +++ b/storage/ent/client/client.go @@ -0,0 +1,92 @@ +package client + +import ( + "context" + + "github.com/dexidp/dex/storage" +) + +// CreateClient saves provided oauth2 client settings into the database. +func (d *Database) CreateClient(client storage.Client) error { + _, err := d.client.OAuth2Client.Create(). + SetID(client.ID). + SetName(client.Name). + SetSecret(client.Secret). + SetPublic(client.Public). + SetLogoURL(client.LogoURL). + SetRedirectUris(client.RedirectURIs). + SetTrustedPeers(client.TrustedPeers). + Save(context.TODO()) + if err != nil { + return convertDBError("create oauth2 client: %w", err) + } + return nil +} + +// ListClients extracts an array of oauth2 clients from the database. +func (d *Database) ListClients() ([]storage.Client, error) { + clients, err := d.client.OAuth2Client.Query().All(context.TODO()) + if err != nil { + return nil, convertDBError("list clients: %w", err) + } + + storageClients := make([]storage.Client, 0, len(clients)) + for _, c := range clients { + storageClients = append(storageClients, toStorageClient(c)) + } + return storageClients, nil +} + +// GetClient extracts an oauth2 client from the database by id. +func (d *Database) GetClient(id string) (storage.Client, error) { + client, err := d.client.OAuth2Client.Get(context.TODO(), id) + if err != nil { + return storage.Client{}, convertDBError("get client: %w", err) + } + return toStorageClient(client), nil +} + +// DeleteClient deletes an oauth2 client from the database by id. +func (d *Database) DeleteClient(id string) error { + err := d.client.OAuth2Client.DeleteOneID(id).Exec(context.TODO()) + if err != nil { + return convertDBError("delete client: %w", err) + } + return nil +} + +// UpdateClient changes an oauth2 client by id using an updater function and saves it to the database. +func (d *Database) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { + tx, err := d.client.Tx(context.TODO()) + if err != nil { + return convertDBError("update client tx: %w", err) + } + + client, err := tx.OAuth2Client.Get(context.TODO(), id) + if err != nil { + return rollback(tx, "update client database: %w", err) + } + + newClient, err := updater(toStorageClient(client)) + if err != nil { + return rollback(tx, "update client updating: %w", err) + } + + _, err = tx.OAuth2Client.UpdateOneID(newClient.ID). + SetName(newClient.Name). + SetSecret(newClient.Secret). + SetPublic(newClient.Public). + SetLogoURL(newClient.LogoURL). + SetRedirectUris(newClient.RedirectURIs). + SetTrustedPeers(newClient.TrustedPeers). + Save(context.TODO()) + if err != nil { + return rollback(tx, "update client uploading: %w", err) + } + + if err = tx.Commit(); err != nil { + return rollback(tx, "update auth request commit: %w", err) + } + + return nil +} diff --git a/storage/ent/client/connector.go b/storage/ent/client/connector.go new file mode 100644 index 00000000..ebba3f58 --- /dev/null +++ b/storage/ent/client/connector.go @@ -0,0 +1,88 @@ +package client + +import ( + "context" + + "github.com/dexidp/dex/storage" +) + +// CreateConnector saves a connector into the database. +func (d *Database) CreateConnector(connector storage.Connector) error { + _, err := d.client.Connector.Create(). + SetID(connector.ID). + SetName(connector.Name). + SetType(connector.Type). + SetResourceVersion(connector.ResourceVersion). + SetConfig(connector.Config). + Save(context.TODO()) + if err != nil { + return convertDBError("create connector: %w", err) + } + return nil +} + +// ListConnectors extracts an array of connectors from the database. +func (d *Database) ListConnectors() ([]storage.Connector, error) { + connectors, err := d.client.Connector.Query().All(context.TODO()) + if err != nil { + return nil, convertDBError("list connectors: %w", err) + } + + storageConnectors := make([]storage.Connector, 0, len(connectors)) + for _, c := range connectors { + storageConnectors = append(storageConnectors, toStorageConnector(c)) + } + return storageConnectors, nil +} + +// GetConnector extracts a connector from the database by id. +func (d *Database) GetConnector(id string) (storage.Connector, error) { + connector, err := d.client.Connector.Get(context.TODO(), id) + if err != nil { + return storage.Connector{}, convertDBError("get connector: %w", err) + } + return toStorageConnector(connector), nil +} + +// DeleteConnector deletes a connector from the database by id. +func (d *Database) DeleteConnector(id string) error { + err := d.client.Connector.DeleteOneID(id).Exec(context.TODO()) + if err != nil { + return convertDBError("delete connector: %w", err) + } + return nil +} + +// UpdateConnector changes a connector by id using an updater function and saves it to the database. +func (d *Database) UpdateConnector(id string, updater func(old storage.Connector) (storage.Connector, error)) error { + tx, err := d.client.Tx(context.TODO()) + if err != nil { + return convertDBError("update connector tx: %w", err) + } + + connector, err := tx.Connector.Get(context.TODO(), id) + if err != nil { + return rollback(tx, "update connector database: %w", err) + } + + newConnector, err := updater(toStorageConnector(connector)) + if err != nil { + return rollback(tx, "update connector updating: %w", err) + } + + _, err = tx.Connector.UpdateOneID(newConnector.ID). + SetName(newConnector.Name). + SetType(newConnector.Type). + SetResourceVersion(newConnector.ResourceVersion). + SetConfig(newConnector.Config). + Save(context.TODO()) + if err != nil { + return rollback(tx, "update connector uploading: %w", err) + } + + if err = tx.Commit(); err != nil { + return rollback(tx, "update connector commit: %w", err) + } + + return nil +} diff --git a/storage/ent/client/devicerequest.go b/storage/ent/client/devicerequest.go new file mode 100644 index 00000000..6e9c2500 --- /dev/null +++ b/storage/ent/client/devicerequest.go @@ -0,0 +1,36 @@ +package client + +import ( + "context" + + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/ent/db/devicerequest" +) + +// CreateDeviceRequest saves provided device request into the database. +func (d *Database) CreateDeviceRequest(request storage.DeviceRequest) error { + _, err := d.client.DeviceRequest.Create(). + SetClientID(request.ClientID). + SetClientSecret(request.ClientSecret). + SetScopes(request.Scopes). + SetUserCode(request.UserCode). + SetDeviceCode(request.DeviceCode). + // Save utc time into database because ent doesn't support comparing dates with different timezones + SetExpiry(request.Expiry.UTC()). + Save(context.TODO()) + if err != nil { + return convertDBError("create device request: %w", err) + } + return nil +} + +// GetDeviceRequest extracts a device request from the database by user code. +func (d *Database) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) { + deviceRequest, err := d.client.DeviceRequest.Query(). + Where(devicerequest.UserCode(userCode)). + Only(context.TODO()) + if err != nil { + return storage.DeviceRequest{}, convertDBError("get device request: %w", err) + } + return toStorageDeviceRequest(deviceRequest), nil +} diff --git a/storage/ent/client/devicetoken.go b/storage/ent/client/devicetoken.go new file mode 100644 index 00000000..89de1cb3 --- /dev/null +++ b/storage/ent/client/devicetoken.go @@ -0,0 +1,76 @@ +package client + +import ( + "context" + + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/ent/db/devicetoken" +) + +// CreateDeviceToken saves provided token into the database. +func (d *Database) CreateDeviceToken(token storage.DeviceToken) error { + _, err := d.client.DeviceToken.Create(). + SetDeviceCode(token.DeviceCode). + SetToken([]byte(token.Token)). + SetPollInterval(token.PollIntervalSeconds). + // Save utc time into database because ent doesn't support comparing dates with different timezones + SetExpiry(token.Expiry.UTC()). + SetLastRequest(token.LastRequestTime.UTC()). + SetStatus(token.Status). + Save(context.TODO()) + if err != nil { + return convertDBError("create device token: %w", err) + } + return nil +} + +// GetDeviceToken extracts a token from the database by device code. +func (d *Database) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) { + deviceToken, err := d.client.DeviceToken.Query(). + Where(devicetoken.DeviceCode(deviceCode)). + Only(context.TODO()) + if err != nil { + return storage.DeviceToken{}, convertDBError("get device token: %w", err) + } + return toStorageDeviceToken(deviceToken), nil +} + +// UpdateDeviceToken changes a token by device code using an updater function and saves it to the database. +func (d *Database) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error { + tx, err := d.client.Tx(context.TODO()) + if err != nil { + return convertDBError("update device token tx: %w", err) + } + + token, err := tx.DeviceToken.Query(). + Where(devicetoken.DeviceCode(deviceCode)). + Only(context.TODO()) + if err != nil { + return rollback(tx, "update device token database: %w", err) + } + + newToken, err := updater(toStorageDeviceToken(token)) + if err != nil { + return rollback(tx, "update device token updating: %w", err) + } + + _, err = tx.DeviceToken.Update(). + Where(devicetoken.DeviceCode(newToken.DeviceCode)). + SetDeviceCode(newToken.DeviceCode). + SetToken([]byte(newToken.Token)). + SetPollInterval(newToken.PollIntervalSeconds). + // Save utc time into database because ent doesn't support comparing dates with different timezones + SetExpiry(newToken.Expiry.UTC()). + SetLastRequest(newToken.LastRequestTime.UTC()). + SetStatus(newToken.Status). + Save(context.TODO()) + if err != nil { + return rollback(tx, "update device token uploading: %w", err) + } + + if err = tx.Commit(); err != nil { + return rollback(tx, "update device token commit: %w", err) + } + + return nil +} diff --git a/storage/ent/client/keys.go b/storage/ent/client/keys.go new file mode 100644 index 00000000..d9f32048 --- /dev/null +++ b/storage/ent/client/keys.go @@ -0,0 +1,81 @@ +package client + +import ( + "context" + "errors" + + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/ent/db" +) + +func getKeys(client *db.KeysClient) (storage.Keys, error) { + rawKeys, err := client.Get(context.TODO(), keysRowID) + if err != nil { + return storage.Keys{}, convertDBError("get keys: %w", err) + } + + return toStorageKeys(rawKeys), nil +} + +// GetKeys returns signing keys, public keys and verification keys from the database. +func (d *Database) GetKeys() (storage.Keys, error) { + return getKeys(d.client.Keys) +} + +// UpdateKeys rotates keys using updater function. +func (d *Database) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error { + firstUpdate := false + + tx, err := d.client.Tx(context.TODO()) + if err != nil { + return convertDBError("update keys tx: %w", err) + } + + storageKeys, err := getKeys(tx.Keys) + if err != nil { + if !errors.Is(err, storage.ErrNotFound) { + return rollback(tx, "update keys get: %w", err) + } + firstUpdate = true + } + + newKeys, err := updater(storageKeys) + if err != nil { + return rollback(tx, "update keys updating: %w", err) + } + + // ent doesn't have an upsert support yet + // https://github.com/facebook/ent/issues/139 + if firstUpdate { + _, err = tx.Keys.Create(). + SetID(keysRowID). + SetNextRotation(newKeys.NextRotation). + SetSigningKey(*newKeys.SigningKey). + SetSigningKeyPub(*newKeys.SigningKeyPub). + SetVerificationKeys(newKeys.VerificationKeys). + Save(context.TODO()) + if err != nil { + return rollback(tx, "create keys: %w", err) + } + if err = tx.Commit(); err != nil { + return rollback(tx, "update keys commit: %w", err) + } + return nil + } + + err = tx.Keys.UpdateOneID(keysRowID). + SetNextRotation(newKeys.NextRotation.UTC()). + SetSigningKey(*newKeys.SigningKey). + SetSigningKeyPub(*newKeys.SigningKeyPub). + SetVerificationKeys(newKeys.VerificationKeys). + Exec(context.TODO()) + if err != nil { + return rollback(tx, "update keys uploading: %w", err) + } + + if err = tx.Commit(); err != nil { + return rollback(tx, "update keys commit: %w", err) + } + + return nil +} diff --git a/storage/ent/client/main.go b/storage/ent/client/main.go new file mode 100644 index 00000000..84dc7d97 --- /dev/null +++ b/storage/ent/client/main.go @@ -0,0 +1,95 @@ +package client + +import ( + "context" + "hash" + "time" + + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/ent/db" + "github.com/dexidp/dex/storage/ent/db/authcode" + "github.com/dexidp/dex/storage/ent/db/authrequest" + "github.com/dexidp/dex/storage/ent/db/devicerequest" + "github.com/dexidp/dex/storage/ent/db/devicetoken" + "github.com/dexidp/dex/storage/ent/db/migrate" +) + +var _ storage.Storage = (*Database)(nil) + +type Database struct { + client *db.Client + hasher func() hash.Hash +} + +// NewDatabase returns new database client with set options. +func NewDatabase(opts ...func(*Database)) *Database { + database := &Database{} + for _, f := range opts { + f(database) + } + return database +} + +// WithClient sets client option of a Database object. +func WithClient(c *db.Client) func(*Database) { + return func(s *Database) { + s.client = c + } +} + +// WithHasher sets client option of a Database object. +func WithHasher(h func() hash.Hash) func(*Database) { + return func(s *Database) { + s.hasher = h + } +} + +// Schema exposes migration schema to perform migrations. +func (d *Database) Schema() *migrate.Schema { + return d.client.Schema +} + +// Close calls the corresponding method of the ent database client. +func (d *Database) Close() error { + return d.client.Close() +} + +// GarbageCollect removes expired entities from the database. +func (d *Database) GarbageCollect(now time.Time) (storage.GCResult, error) { + result := storage.GCResult{} + utcNow := now.UTC() + + q, err := d.client.AuthRequest.Delete(). + Where(authrequest.ExpiryLT(utcNow)). + Exec(context.TODO()) + if err != nil { + return result, convertDBError("gc auth request: %w", err) + } + result.AuthRequests = int64(q) + + q, err = d.client.AuthCode.Delete(). + Where(authcode.ExpiryLT(utcNow)). + Exec(context.TODO()) + if err != nil { + return result, convertDBError("gc auth code: %w", err) + } + result.AuthCodes = int64(q) + + q, err = d.client.DeviceRequest.Delete(). + Where(devicerequest.ExpiryLT(utcNow)). + Exec(context.TODO()) + if err != nil { + return result, convertDBError("gc device request: %w", err) + } + result.DeviceRequests = int64(q) + + q, err = d.client.DeviceToken.Delete(). + Where(devicetoken.ExpiryLT(utcNow)). + Exec(context.TODO()) + if err != nil { + return result, convertDBError("gc device token: %w", err) + } + result.DeviceTokens = int64(q) + + return result, err +} diff --git a/storage/ent/client/offlinesession.go b/storage/ent/client/offlinesession.go new file mode 100644 index 00000000..cee415b6 --- /dev/null +++ b/storage/ent/client/offlinesession.go @@ -0,0 +1,93 @@ +package client + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/dexidp/dex/storage" +) + +// CreateOfflineSessions saves provided offline session into the database. +func (d *Database) CreateOfflineSessions(session storage.OfflineSessions) error { + encodedRefresh, err := json.Marshal(session.Refresh) + if err != nil { + return fmt.Errorf("encode refresh offline session: %w", err) + } + + id := offlineSessionID(session.UserID, session.ConnID, d.hasher) + _, err = d.client.OfflineSession.Create(). + SetID(id). + SetUserID(session.UserID). + SetConnID(session.ConnID). + SetConnectorData(session.ConnectorData). + SetRefresh(encodedRefresh). + Save(context.TODO()) + if err != nil { + return convertDBError("create offline session: %w", err) + } + return nil +} + +// GetOfflineSessions extracts an offline session from the database by user id and connector id. +func (d *Database) GetOfflineSessions(userID, connID string) (storage.OfflineSessions, error) { + id := offlineSessionID(userID, connID, d.hasher) + + offlineSession, err := d.client.OfflineSession.Get(context.TODO(), id) + if err != nil { + return storage.OfflineSessions{}, convertDBError("get offline session: %w", err) + } + return toStorageOfflineSession(offlineSession), nil +} + +// DeleteOfflineSessions deletes an offline session from the database by user id and connector id. +func (d *Database) DeleteOfflineSessions(userID, connID string) error { + id := offlineSessionID(userID, connID, d.hasher) + + err := d.client.OfflineSession.DeleteOneID(id).Exec(context.TODO()) + if err != nil { + return convertDBError("delete offline session: %w", err) + } + return nil +} + +// UpdatePassword changes an offline session by user id and connector id using an updater function. +func (d *Database) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { + id := offlineSessionID(userID, connID, d.hasher) + + tx, err := d.client.Tx(context.TODO()) + if err != nil { + return convertDBError("update offline session tx: %w", err) + } + + offlineSession, err := tx.OfflineSession.Get(context.TODO(), id) + if err != nil { + return rollback(tx, "update offline session database: %w", err) + } + + newOfflineSession, err := updater(toStorageOfflineSession(offlineSession)) + if err != nil { + return rollback(tx, "update offline session updating: %w", err) + } + + encodedRefresh, err := json.Marshal(newOfflineSession.Refresh) + if err != nil { + return rollback(tx, "encode refresh offline session: %w", err) + } + + _, err = tx.OfflineSession.UpdateOneID(id). + SetUserID(newOfflineSession.UserID). + SetConnID(newOfflineSession.ConnID). + SetConnectorData(newOfflineSession.ConnectorData). + SetRefresh(encodedRefresh). + Save(context.TODO()) + if err != nil { + return rollback(tx, "update offline session uploading: %w", err) + } + + if err = tx.Commit(); err != nil { + return rollback(tx, "update password commit: %w", err) + } + + return nil +} diff --git a/storage/ent/client/password.go b/storage/ent/client/password.go new file mode 100644 index 00000000..003cbd1a --- /dev/null +++ b/storage/ent/client/password.go @@ -0,0 +1,100 @@ +package client + +import ( + "context" + "strings" + + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/ent/db/password" +) + +// CreatePassword saves provided password into the database. +func (d *Database) CreatePassword(password storage.Password) error { + _, err := d.client.Password.Create(). + SetEmail(password.Email). + SetHash(password.Hash). + SetUsername(password.Username). + SetUserID(password.UserID). + Save(context.TODO()) + if err != nil { + return convertDBError("create password: %w", err) + } + return nil +} + +// ListPasswords extracts an array of passwords from the database. +func (d *Database) ListPasswords() ([]storage.Password, error) { + passwords, err := d.client.Password.Query().All(context.TODO()) + if err != nil { + return nil, convertDBError("list passwords: %w", err) + } + + storagePasswords := make([]storage.Password, 0, len(passwords)) + for _, p := range passwords { + storagePasswords = append(storagePasswords, toStoragePassword(p)) + } + return storagePasswords, nil +} + +// GetPassword extracts a password from the database by email. +func (d *Database) GetPassword(email string) (storage.Password, error) { + email = strings.ToLower(email) + passwordFromStorage, err := d.client.Password.Query(). + Where(password.Email(email)). + Only(context.TODO()) + if err != nil { + return storage.Password{}, convertDBError("get password: %w", err) + } + return toStoragePassword(passwordFromStorage), nil +} + +// DeletePassword deletes a password from the database by email. +func (d *Database) DeletePassword(email string) error { + email = strings.ToLower(email) + _, err := d.client.Password.Delete(). + Where(password.Email(email)). + Exec(context.TODO()) + if err != nil { + return convertDBError("delete password: %w", err) + } + return nil +} + +// UpdatePassword changes a password by email using an updater function and saves it to the database. +func (d *Database) UpdatePassword(email string, updater func(old storage.Password) (storage.Password, error)) error { + email = strings.ToLower(email) + + tx, err := d.client.Tx(context.TODO()) + if err != nil { + return convertDBError("update connector tx: %w", err) + } + + passwordToUpdate, err := tx.Password.Query(). + Where(password.Email(email)). + Only(context.TODO()) + if err != nil { + return rollback(tx, "update password database: %w", err) + } + + newPassword, err := updater(toStoragePassword(passwordToUpdate)) + if err != nil { + return rollback(tx, "update password updating: %w", err) + } + + _, err = tx.Password.Update(). + Where(password.Email(newPassword.Email)). + SetEmail(newPassword.Email). + SetHash(newPassword.Hash). + SetUsername(newPassword.Username). + SetUserID(newPassword.UserID). + Save(context.TODO()) + if err != nil { + return rollback(tx, "update password uploading: %w", err) + } + + if err = tx.Commit(); err != nil { + return rollback(tx, "update password commit: %w", err) + } + + return nil +} diff --git a/storage/ent/client/refreshtoken.go b/storage/ent/client/refreshtoken.go new file mode 100644 index 00000000..90f3c6ae --- /dev/null +++ b/storage/ent/client/refreshtoken.go @@ -0,0 +1,109 @@ +package client + +import ( + "context" + + "github.com/dexidp/dex/storage" +) + +// CreateRefresh saves provided refresh token into the database. +func (d *Database) CreateRefresh(refresh storage.RefreshToken) error { + _, err := d.client.RefreshToken.Create(). + SetID(refresh.ID). + SetClientID(refresh.ClientID). + SetScopes(refresh.Scopes). + SetNonce(refresh.Nonce). + SetClaimsUserID(refresh.Claims.UserID). + SetClaimsEmail(refresh.Claims.Email). + SetClaimsEmailVerified(refresh.Claims.EmailVerified). + SetClaimsUsername(refresh.Claims.Username). + SetClaimsPreferredUsername(refresh.Claims.PreferredUsername). + SetClaimsGroups(refresh.Claims.Groups). + SetConnectorID(refresh.ConnectorID). + SetConnectorData(refresh.ConnectorData). + SetToken(refresh.Token). + // Save utc time into database because ent doesn't support comparing dates with different timezones + SetLastUsed(refresh.LastUsed.UTC()). + SetCreatedAt(refresh.CreatedAt.UTC()). + Save(context.TODO()) + if err != nil { + return convertDBError("create refresh token: %w", err) + } + return nil +} + +// ListRefreshTokens extracts an array of refresh tokens from the database. +func (d *Database) ListRefreshTokens() ([]storage.RefreshToken, error) { + refreshTokens, err := d.client.RefreshToken.Query().All(context.TODO()) + if err != nil { + return nil, convertDBError("list refresh tokens: %w", err) + } + + storageRefreshTokens := make([]storage.RefreshToken, 0, len(refreshTokens)) + for _, r := range refreshTokens { + storageRefreshTokens = append(storageRefreshTokens, toStorageRefreshToken(r)) + } + return storageRefreshTokens, nil +} + +// GetRefresh extracts a refresh token from the database by id. +func (d *Database) GetRefresh(id string) (storage.RefreshToken, error) { + refreshToken, err := d.client.RefreshToken.Get(context.TODO(), id) + if err != nil { + return storage.RefreshToken{}, convertDBError("get refresh token: %w", err) + } + return toStorageRefreshToken(refreshToken), nil +} + +// DeleteRefresh deletes a refresh token from the database by id. +func (d *Database) DeleteRefresh(id string) error { + err := d.client.RefreshToken.DeleteOneID(id).Exec(context.TODO()) + if err != nil { + return convertDBError("delete refresh token: %w", err) + } + return nil +} + +// UpdateRefreshToken changes a refresh token by id using an updater function and saves it to the database. +func (d *Database) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { + tx, err := d.client.Tx(context.TODO()) + if err != nil { + return convertDBError("update refresh token tx: %w", err) + } + + token, err := tx.RefreshToken.Get(context.TODO(), id) + if err != nil { + return rollback(tx, "update refresh token database: %w", err) + } + + newtToken, err := updater(toStorageRefreshToken(token)) + if err != nil { + return rollback(tx, "update refresh token updating: %w", err) + } + + _, err = tx.RefreshToken.UpdateOneID(newtToken.ID). + SetClientID(newtToken.ClientID). + SetScopes(newtToken.Scopes). + SetNonce(newtToken.Nonce). + SetClaimsUserID(newtToken.Claims.UserID). + SetClaimsEmail(newtToken.Claims.Email). + SetClaimsEmailVerified(newtToken.Claims.EmailVerified). + SetClaimsUsername(newtToken.Claims.Username). + SetClaimsPreferredUsername(newtToken.Claims.PreferredUsername). + SetClaimsGroups(newtToken.Claims.Groups). + SetConnectorID(newtToken.ConnectorID). + SetConnectorData(newtToken.ConnectorData). + SetToken(newtToken.Token). + // Save utc time into database because ent doesn't support comparing dates with different timezones + SetLastUsed(newtToken.LastUsed.UTC()). + SetCreatedAt(newtToken.CreatedAt.UTC()). + Save(context.TODO()) + if err != nil { + return rollback(tx, "update refresh token uploading: %w", err) + } + + if err = tx.Commit(); err != nil { + return rollback(tx, "update refresh token commit: %w", err) + } + return nil +} diff --git a/storage/ent/client/types.go b/storage/ent/client/types.go new file mode 100644 index 00000000..388ef3e5 --- /dev/null +++ b/storage/ent/client/types.go @@ -0,0 +1,167 @@ +package client + +import ( + "encoding/json" + "strings" + + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/ent/db" +) + +const keysRowID = "keys" + +func toStorageKeys(keys *db.Keys) storage.Keys { + return storage.Keys{ + SigningKey: &keys.SigningKey, + SigningKeyPub: &keys.SigningKeyPub, + VerificationKeys: keys.VerificationKeys, + NextRotation: keys.NextRotation, + } +} + +func toStorageAuthRequest(a *db.AuthRequest) storage.AuthRequest { + return storage.AuthRequest{ + ID: a.ID, + ClientID: a.ClientID, + ResponseTypes: a.ResponseTypes, + Scopes: a.Scopes, + RedirectURI: a.RedirectURI, + Nonce: a.Nonce, + State: a.State, + ForceApprovalPrompt: a.ForceApprovalPrompt, + LoggedIn: a.LoggedIn, + ConnectorID: a.ConnectorID, + ConnectorData: *a.ConnectorData, + Expiry: a.Expiry, + Claims: storage.Claims{ + UserID: a.ClaimsUserID, + Username: a.ClaimsUsername, + PreferredUsername: a.ClaimsPreferredUsername, + Email: a.ClaimsEmail, + EmailVerified: a.ClaimsEmailVerified, + Groups: a.ClaimsGroups, + }, + PKCE: storage.PKCE{ + CodeChallenge: a.CodeChallenge, + CodeChallengeMethod: a.CodeChallengeMethod, + }, + } +} + +func toStorageAuthCode(a *db.AuthCode) storage.AuthCode { + return storage.AuthCode{ + ID: a.ID, + ClientID: a.ClientID, + Scopes: a.Scopes, + RedirectURI: a.RedirectURI, + Nonce: a.Nonce, + ConnectorID: a.ConnectorID, + ConnectorData: *a.ConnectorData, + Expiry: a.Expiry, + Claims: storage.Claims{ + UserID: a.ClaimsUserID, + Username: a.ClaimsUsername, + PreferredUsername: a.ClaimsPreferredUsername, + Email: a.ClaimsEmail, + EmailVerified: a.ClaimsEmailVerified, + Groups: a.ClaimsGroups, + }, + PKCE: storage.PKCE{ + CodeChallenge: a.CodeChallenge, + CodeChallengeMethod: a.CodeChallengeMethod, + }, + } +} + +func toStorageClient(c *db.OAuth2Client) storage.Client { + return storage.Client{ + ID: c.ID, + Secret: c.Secret, + RedirectURIs: c.RedirectUris, + TrustedPeers: c.TrustedPeers, + Public: c.Public, + Name: c.Name, + LogoURL: c.LogoURL, + } +} + +func toStorageConnector(c *db.Connector) storage.Connector { + return storage.Connector{ + ID: c.ID, + Type: c.Type, + Name: c.Name, + Config: c.Config, + } +} + +func toStorageOfflineSession(o *db.OfflineSession) storage.OfflineSessions { + s := storage.OfflineSessions{ + UserID: o.UserID, + ConnID: o.ConnID, + ConnectorData: *o.ConnectorData, + } + + if o.Refresh != nil { + if err := json.Unmarshal(o.Refresh, &s.Refresh); err != nil { + // Correctness of json structure if guaranteed on uploading + panic(err) + } + } else { + // Server code assumes this will be non-nil. + s.Refresh = make(map[string]*storage.RefreshTokenRef) + } + return s +} + +func toStorageRefreshToken(r *db.RefreshToken) storage.RefreshToken { + return storage.RefreshToken{ + ID: r.ID, + Token: r.Token, + CreatedAt: r.CreatedAt, + LastUsed: r.LastUsed, + ClientID: r.ClientID, + ConnectorID: r.ConnectorID, + ConnectorData: *r.ConnectorData, + Scopes: r.Scopes, + Nonce: r.Nonce, + Claims: storage.Claims{ + UserID: r.ClaimsUserID, + Username: r.ClaimsUsername, + PreferredUsername: r.ClaimsPreferredUsername, + Email: r.ClaimsEmail, + EmailVerified: r.ClaimsEmailVerified, + Groups: r.ClaimsGroups, + }, + } +} + +func toStoragePassword(p *db.Password) storage.Password { + return storage.Password{ + Email: p.Email, + Hash: p.Hash, + Username: p.Username, + UserID: p.UserID, + } +} + +func toStorageDeviceRequest(r *db.DeviceRequest) storage.DeviceRequest { + return storage.DeviceRequest{ + UserCode: strings.ToUpper(r.UserCode), + DeviceCode: r.DeviceCode, + ClientID: r.ClientID, + ClientSecret: r.ClientSecret, + Scopes: r.Scopes, + Expiry: r.Expiry, + } +} + +func toStorageDeviceToken(t *db.DeviceToken) storage.DeviceToken { + return storage.DeviceToken{ + DeviceCode: t.DeviceCode, + Status: t.Status, + Token: string(*t.Token), + Expiry: t.Expiry, + LastRequestTime: t.LastRequest, + PollIntervalSeconds: t.PollInterval, + } +} diff --git a/storage/ent/client/utils.go b/storage/ent/client/utils.go new file mode 100644 index 00000000..65c037ac --- /dev/null +++ b/storage/ent/client/utils.go @@ -0,0 +1,44 @@ +package client + +import ( + "fmt" + "hash" + + "github.com/pkg/errors" + + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/ent/db" +) + +func rollback(tx *db.Tx, t string, err error) error { + rerr := tx.Rollback() + err = convertDBError(t, err) + + if rerr == nil { + return err + } + return errors.Wrapf(err, "rolling back transaction: %v", rerr) +} + +func convertDBError(t string, err error) error { + if db.IsNotFound(err) { + return storage.ErrNotFound + } + + if db.IsConstraintError(err) { + return storage.ErrAlreadyExists + } + + return fmt.Errorf(t, err) +} + +// compose hashed id from user and connection id to use it as primary key +// ent doesn't support multi-key primary yet +// https://github.com/facebook/ent/issues/400 +func offlineSessionID(userID string, connID string, hasher func() hash.Hash) string { + h := hasher() + + h.Write([]byte(userID)) + h.Write([]byte(connID)) + return fmt.Sprintf("%x", h.Sum(nil)) +} diff --git a/storage/ent/generate.go b/storage/ent/generate.go new file mode 100644 index 00000000..4813a6da --- /dev/null +++ b/storage/ent/generate.go @@ -0,0 +1,3 @@ +package ent + +//go:generate go run github.com/facebook/ent/cmd/entc generate ./schema --target ./db diff --git a/storage/ent/schema/authcode.go b/storage/ent/schema/authcode.go new file mode 100644 index 00000000..fea075c9 --- /dev/null +++ b/storage/ent/schema/authcode.go @@ -0,0 +1,89 @@ +package schema + +import ( + "github.com/facebook/ent" + "github.com/facebook/ent/schema/field" +) + +/* Original SQL table: +create table auth_code +( + id text not null primary key, + client_id text not null, + scopes blob not null, + nonce text not null, + redirect_uri text not null, + claims_user_id text not null, + claims_username text not null, + claims_email text not null, + claims_email_verified integer not null, + claims_groups blob not null, + connector_id text not null, + connector_data blob, + expiry timestamp not null, + claims_preferred_username text default '' not null, + code_challenge text default '' not null, + code_challenge_method text default '' not null +); +*/ + +// AuthCode holds the schema definition for the AuthCode entity. +type AuthCode struct { + ent.Schema +} + +// Fields of the AuthCode. +func (AuthCode) Fields() []ent.Field { + return []ent.Field{ + field.Text("id"). + SchemaType(textSchema). + NotEmpty(). + Unique(), + field.Text("client_id"). + SchemaType(textSchema). + NotEmpty(), + field.JSON("scopes", []string{}). + Optional(), + field.Text("nonce"). + SchemaType(textSchema). + NotEmpty(), + field.Text("redirect_uri"). + SchemaType(textSchema). + NotEmpty(), + + field.Text("claims_user_id"). + SchemaType(textSchema). + NotEmpty(), + field.Text("claims_username"). + SchemaType(textSchema). + NotEmpty(), + field.Text("claims_email"). + SchemaType(textSchema). + NotEmpty(), + field.Bool("claims_email_verified"), + field.JSON("claims_groups", []string{}). + Optional(), + field.Text("claims_preferred_username"). + SchemaType(textSchema). + Default(""), + + field.Text("connector_id"). + SchemaType(textSchema). + NotEmpty(), + field.Bytes("connector_data"). + Nillable(). + Optional(), + field.Time("expiry"), + field.Text("code_challenge"). + SchemaType(textSchema). + Default(""), + field.Text("code_challenge_method"). + SchemaType(textSchema). + Default(""), + } +} + +// Edges of the AuthCode. +func (AuthCode) Edges() []ent.Edge { + return []ent.Edge{} +} diff --git a/storage/ent/schema/authrequest.go b/storage/ent/schema/authrequest.go new file mode 100644 index 00000000..f027c1a5 --- /dev/null +++ b/storage/ent/schema/authrequest.go @@ -0,0 +1,94 @@ +package schema + +import ( + "github.com/facebook/ent" + "github.com/facebook/ent/schema/field" +) + +/* Original SQL table: +create table auth_request +( + id text not null primary key, + client_id text not null, + response_types blob not null, + scopes blob not null, + redirect_uri text not null, + nonce text not null, + state text not null, + force_approval_prompt integer not null, + logged_in integer not null, + claims_user_id text not null, + claims_username text not null, + claims_email text not null, + claims_email_verified integer not null, + claims_groups blob not null, + connector_id text not null, + connector_data blob, + expiry timestamp not null, + claims_preferred_username text default '' not null, + code_challenge text default '' not null, + code_challenge_method text default '' not null +); +*/ + +// AuthRequest holds the schema definition for the AuthRequest entity. +type AuthRequest struct { + ent.Schema +} + +// Fields of the AuthRequest. +func (AuthRequest) Fields() []ent.Field { + return []ent.Field{ + field.Text("id"). + SchemaType(textSchema). + NotEmpty(). + Unique(), + field.Text("client_id"). + SchemaType(textSchema), + field.JSON("scopes", []string{}). + Optional(), + field.JSON("response_types", []string{}). + Optional(), + field.Text("redirect_uri"). + SchemaType(textSchema), + field.Text("nonce"). + SchemaType(textSchema), + field.Text("state"). + SchemaType(textSchema), + + field.Bool("force_approval_prompt"), + field.Bool("logged_in"), + + field.Text("claims_user_id"). + SchemaType(textSchema), + field.Text("claims_username"). + SchemaType(textSchema), + field.Text("claims_email"). + SchemaType(textSchema), + field.Bool("claims_email_verified"), + field.JSON("claims_groups", []string{}). + Optional(), + field.Text("claims_preferred_username"). + SchemaType(textSchema). + Default(""), + + field.Text("connector_id"). + SchemaType(textSchema), + field.Bytes("connector_data"). + Nillable(). + Optional(), + field.Time("expiry"), + + field.Text("code_challenge"). + SchemaType(textSchema). + Default(""), + field.Text("code_challenge_method"). + SchemaType(textSchema). + Default(""), + } +} + +// Edges of the AuthRequest. +func (AuthRequest) Edges() []ent.Edge { + return []ent.Edge{} +} diff --git a/storage/ent/schema/client.go b/storage/ent/schema/client.go new file mode 100644 index 00000000..85ea57b6 --- /dev/null +++ b/storage/ent/schema/client.go @@ -0,0 +1,53 @@ +package schema + +import ( + "github.com/facebook/ent" + "github.com/facebook/ent/schema/field" +) + +/* Original SQL table: +create table client +( + id text not null primary key, + secret text not null, + redirect_uris blob not null, + trusted_peers blob not null, + public integer not null, + name text not null, + logo_url text not null +); +*/ + +// OAuth2Client holds the schema definition for the Client entity. +type OAuth2Client struct { + ent.Schema +} + +// Fields of the OAuth2Client. +func (OAuth2Client) Fields() []ent.Field { + return []ent.Field{ + field.Text("id"). + SchemaType(textSchema). + NotEmpty(). + Unique(), + field.Text("secret"). + SchemaType(textSchema). + NotEmpty(), + field.JSON("redirect_uris", []string{}). + Optional(), + field.JSON("trusted_peers", []string{}). + Optional(), + field.Bool("public"), + field.Text("name"). + SchemaType(textSchema). + NotEmpty(), + field.Text("logo_url"). + SchemaType(textSchema). + NotEmpty(), + } +} + +// Edges of the OAuth2Client. +func (OAuth2Client) Edges() []ent.Edge { + return []ent.Edge{} +} diff --git a/storage/ent/schema/connector.go b/storage/ent/schema/connector.go new file mode 100644 index 00000000..4b5a6fb3 --- /dev/null +++ b/storage/ent/schema/connector.go @@ -0,0 +1,46 @@ +package schema + +import ( + "github.com/facebook/ent" + "github.com/facebook/ent/schema/field" +) + +/* Original SQL table: +create table connector +( + id text not null primary key, + type text not null, + name text not null, + resource_version text not null, + config blob +); +*/ + +// Connector holds the schema definition for the Client entity. +type Connector struct { + ent.Schema +} + +// Fields of the Connector. +func (Connector) Fields() []ent.Field { + return []ent.Field{ + field.Text("id"). + SchemaType(textSchema). + NotEmpty(). + Unique(), + field.Text("type"). + SchemaType(textSchema). + NotEmpty(), + field.Text("name"). + SchemaType(textSchema). + NotEmpty(), + field.Text("resource_version"). + SchemaType(textSchema), + field.Bytes("config"), + } +} + +// Edges of the Connector. +func (Connector) Edges() []ent.Edge { + return []ent.Edge{} +} diff --git a/storage/ent/schema/devicerequest.go b/storage/ent/schema/devicerequest.go new file mode 100644 index 00000000..71701e7f --- /dev/null +++ b/storage/ent/schema/devicerequest.go @@ -0,0 +1,50 @@ +package schema + +import ( + "github.com/facebook/ent" + "github.com/facebook/ent/schema/field" +) + +/* Original SQL table: +create table device_request +( + user_code text not null primary key, + device_code text not null, + client_id text not null, + client_secret text, + scopes blob not null, + expiry timestamp not null +); +*/ + +// DeviceRequest holds the schema definition for the DeviceRequest entity. +type DeviceRequest struct { + ent.Schema +} + +// Fields of the DeviceRequest. +func (DeviceRequest) Fields() []ent.Field { + return []ent.Field{ + field.Text("user_code"). + SchemaType(textSchema). + NotEmpty(). + Unique(), + field.Text("device_code"). + SchemaType(textSchema). + NotEmpty(), + field.Text("client_id"). + SchemaType(textSchema). + NotEmpty(), + field.Text("client_secret"). + SchemaType(textSchema). + NotEmpty(), + field.JSON("scopes", []string{}). + Optional(), + field.Time("expiry"), + } +} + +// Edges of the DeviceRequest. +func (DeviceRequest) Edges() []ent.Edge { + return []ent.Edge{} +} diff --git a/storage/ent/schema/devicetoken.go b/storage/ent/schema/devicetoken.go new file mode 100644 index 00000000..1b6eadaf --- /dev/null +++ b/storage/ent/schema/devicetoken.go @@ -0,0 +1,45 @@ +package schema + +import ( + "github.com/facebook/ent" + "github.com/facebook/ent/schema/field" +) + +/* Original SQL table: +create table device_token +( + device_code text not null primary key, + status text not null, + token blob, + expiry timestamp not null, + last_request timestamp not null, + poll_interval integer not null +); +*/ + +// DeviceToken holds the schema definition for the DeviceToken entity. +type DeviceToken struct { + ent.Schema +} + +// Fields of the DeviceToken. +func (DeviceToken) Fields() []ent.Field { + return []ent.Field{ + field.Text("device_code"). + SchemaType(textSchema). + NotEmpty(). + Unique(), + field.Text("status"). + SchemaType(textSchema). + NotEmpty(), + field.Bytes("token").Nillable().Optional(), + field.Time("expiry"), + field.Time("last_request"), + field.Int("poll_interval"), + } +} + +// Edges of the DeviceToken. +func (DeviceToken) Edges() []ent.Edge { + return []ent.Edge{} +} diff --git a/storage/ent/schema/keys.go b/storage/ent/schema/keys.go new file mode 100644 index 00000000..3d9e7ff6 --- /dev/null +++ b/storage/ent/schema/keys.go @@ -0,0 +1,44 @@ +package schema + +import ( + "github.com/facebook/ent" + "github.com/facebook/ent/schema/field" + "gopkg.in/square/go-jose.v2" + + "github.com/dexidp/dex/storage" +) + +/* Original SQL table: +create table keys +( + id text not null primary key, + verification_keys blob not null, + signing_key blob not null, + signing_key_pub blob not null, + next_rotation timestamp not null +); +*/ + +// Keys holds the schema definition for the Keys entity. +type Keys struct { + ent.Schema +} + +// Fields of the Keys. +func (Keys) Fields() []ent.Field { + return []ent.Field{ + field.Text("id"). + SchemaType(textSchema). + NotEmpty(). + Unique(), + field.JSON("verification_keys", []storage.VerificationKey{}), + field.JSON("signing_key", jose.JSONWebKey{}), + field.JSON("signing_key_pub", jose.JSONWebKey{}), + field.Time("next_rotation"), + } +} + +// Edges of the Keys. +func (Keys) Edges() []ent.Edge { + return []ent.Edge{} +} diff --git a/storage/ent/schema/offlinesession.go b/storage/ent/schema/offlinesession.go new file mode 100644 index 00000000..16b764d5 --- /dev/null +++ b/storage/ent/schema/offlinesession.go @@ -0,0 +1,46 @@ +package schema + +import ( + "github.com/facebook/ent" + "github.com/facebook/ent/schema/field" +) + +/* Original SQL table: +create table offline_session +( + user_id text not null, + conn_id text not null, + refresh blob not null, + connector_data blob, + primary key (user_id, conn_id) +); +*/ + +// OfflineSession holds the schema definition for the OfflineSession entity. +type OfflineSession struct { + ent.Schema +} + +// Fields of the OfflineSession. +func (OfflineSession) Fields() []ent.Field { + return []ent.Field{ + // Using id field here because it's impossible to create multi-key primary yet + field.Text("id"). + SchemaType(textSchema). + NotEmpty(). + Unique(), + field.Text("user_id"). + SchemaType(textSchema). + NotEmpty(), + field.Text("conn_id"). + SchemaType(textSchema). + NotEmpty(), + field.Bytes("refresh"), + field.Bytes("connector_data").Nillable().Optional(), + } +} + +// Edges of the OfflineSession. +func (OfflineSession) Edges() []ent.Edge { + return []ent.Edge{} +} diff --git a/storage/ent/schema/password.go b/storage/ent/schema/password.go new file mode 100644 index 00000000..378d88d3 --- /dev/null +++ b/storage/ent/schema/password.go @@ -0,0 +1,44 @@ +package schema + +import ( + "github.com/facebook/ent" + "github.com/facebook/ent/schema/field" +) + +/* Original SQL table: +create table password +( + email text not null primary key, + hash blob not null, + username text not null, + user_id text not null +); +*/ + +// Password holds the schema definition for the Password entity. +type Password struct { + ent.Schema +} + +// Fields of the Password. +func (Password) Fields() []ent.Field { + return []ent.Field{ + field.Text("email"). + SchemaType(textSchema). + StorageKey("email"). // use email as ID field to make querying easier + NotEmpty(). + Unique(), + field.Bytes("hash"), + field.Text("username"). + SchemaType(textSchema). + NotEmpty(), + field.Text("user_id"). + SchemaType(textSchema). + NotEmpty(), + } +} + +// Edges of the Password. +func (Password) Edges() []ent.Edge { + return []ent.Edge{} +} diff --git a/storage/ent/schema/refreshtoken.go b/storage/ent/schema/refreshtoken.go new file mode 100644 index 00000000..36804ac9 --- /dev/null +++ b/storage/ent/schema/refreshtoken.go @@ -0,0 +1,89 @@ +package schema + +import ( + "time" + + "github.com/facebook/ent" + "github.com/facebook/ent/schema/field" +) + +/* Original SQL table: +create table refresh_token +( + id text not null primary key, + client_id text not null, + scopes blob not null, + nonce text not null, + claims_user_id text not null, + claims_username text not null, + claims_email text not null, + claims_email_verified integer not null, + claims_groups blob not null, + connector_id text not null, + connector_data blob, + token text default '' not null, + created_at timestamp default '0001-01-01 00:00:00 UTC' not null, + last_used timestamp default '0001-01-01 00:00:00 UTC' not null, + claims_preferred_username text default '' not null +); +*/ + +// RefreshToken holds the schema definition for the RefreshToken entity. +type RefreshToken struct { + ent.Schema +} + +// Fields of the RefreshToken. +func (RefreshToken) Fields() []ent.Field { + return []ent.Field{ + field.Text("id"). + SchemaType(textSchema). + NotEmpty(). + Unique(), + field.Text("client_id"). + SchemaType(textSchema). + NotEmpty(), + field.JSON("scopes", []string{}). + Optional(), + field.Text("nonce"). + SchemaType(textSchema). + NotEmpty(), + + field.Text("claims_user_id"). + SchemaType(textSchema). + NotEmpty(), + field.Text("claims_username"). + SchemaType(textSchema). + NotEmpty(), + field.Text("claims_email"). + SchemaType(textSchema). + NotEmpty(), + field.Bool("claims_email_verified"), + field.JSON("claims_groups", []string{}). + Optional(), + field.Text("claims_preferred_username"). + SchemaType(textSchema). + Default(""), + + field.Text("connector_id"). + SchemaType(textSchema). + NotEmpty(), + field.Bytes("connector_data"). + Nillable(). + Optional(), + + field.Text("token"). + SchemaType(textSchema). + Default(""), + + field.Time("created_at"). + Default(time.Now), + field.Time("last_used"). + Default(time.Now), + } +} + +// Edges of the RefreshToken. +func (RefreshToken) Edges() []ent.Edge { + return []ent.Edge{} +} diff --git a/storage/ent/schema/types.go b/storage/ent/schema/types.go new file mode 100644 index 00000000..2b0378d8 --- /dev/null +++ b/storage/ent/schema/types.go @@ -0,0 +1,9 @@ +package schema + +import ( + "github.com/facebook/ent/dialect" +) + +var textSchema = map[string]string{ + dialect.SQLite: "text", +} diff --git a/storage/ent/sqlite.go b/storage/ent/sqlite.go new file mode 100644 index 00000000..68601a1a --- /dev/null +++ b/storage/ent/sqlite.go @@ -0,0 +1,65 @@ +package ent + +import ( + "context" + "crypto/sha256" + "strings" + + "github.com/facebook/ent/dialect/sql" + + // Register sqlite driver. + _ "github.com/mattn/go-sqlite3" + + "github.com/dexidp/dex/pkg/log" + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/ent/client" + "github.com/dexidp/dex/storage/ent/db" +) + +// SQLite3 options for creating an SQL db. +type SQLite3 struct { + File string `json:"file"` +} + +// Open always returns a new in sqlite3 storage. +func (s *SQLite3) Open(logger log.Logger) (storage.Storage, error) { + logger.Debug("experimental ent-based storage driver is enabled") + + // Implicitly set foreign_keys pragma to "on" because it is required by ent + s.File = addFK(s.File) + + drv, err := sql.Open("sqlite3", s.File) + if err != nil { + return nil, err + } + + pool := drv.DB() + if s.File == ":memory:" { + // sqlite3 uses file locks to coordinate concurrent access. In memory + // doesn't support this, so limit the number of connections to 1. + pool.SetMaxOpenConns(1) + } + + databaseClient := client.NewDatabase( + client.WithClient(db.NewClient(db.Driver(drv))), + client.WithHasher(sha256.New), + ) + + if err := databaseClient.Schema().Create(context.TODO()); err != nil { + return nil, err + } + + return databaseClient, nil +} + +func addFK(dsn string) string { + if strings.Contains(dsn, "_fk") { + return dsn + } + + delim := "?" + if strings.Contains(dsn, "?") { + delim = "&" + } + return dsn + delim + "_fk=1" +} diff --git a/storage/ent/sqlite_test.go b/storage/ent/sqlite_test.go new file mode 100644 index 00000000..053c827d --- /dev/null +++ b/storage/ent/sqlite_test.go @@ -0,0 +1,31 @@ +package ent + +import ( + "os" + "testing" + + "github.com/sirupsen/logrus" + + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/conformance" +) + +func newStorage() storage.Storage { + logger := &logrus.Logger{ + Out: os.Stderr, + Formatter: &logrus.TextFormatter{DisableColors: true}, + Level: logrus.DebugLevel, + } + + cfg := SQLite3{File: ":memory:"} + s, err := cfg.Open(logger) + if err != nil { + panic(err) + } + return s +} + +func TestSQLite3(t *testing.T) { + conformance.RunTests(t, newStorage) + conformance.RunTransactionTests(t, newStorage) +}