From eae3219e4d4b8781f1ffc0188512f88030f65a13 Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Mon, 13 Sep 2021 14:25:17 +0400 Subject: [PATCH 1/5] feat: Add MySQL ent-based storage driver Signed-off-by: m.nabokikh --- .github/workflows/ci.yaml | 14 +++ cmd/dex/config.go | 3 +- storage/ent/mysql.go | 162 +++++++++++++++++++++++++++++++ storage/ent/mysql_test.go | 183 +++++++++++++++++++++++++++++++++++ storage/ent/postgres_test.go | 7 -- storage/ent/sqlite.go | 8 +- storage/ent/utils.go | 10 ++ 7 files changed, 374 insertions(+), 13 deletions(-) create mode 100644 storage/ent/mysql.go create mode 100644 storage/ent/mysql_test.go create mode 100644 storage/ent/utils.go diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index fe3a0359..55deda32 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -35,6 +35,15 @@ jobs: - 3306 options: --health-cmd "mysql -proot -e \"show databases;\"" --health-interval 10s --health-timeout 5s --health-retries 5 + mysql-ent: + image: mysql:5.7 + env: + MYSQL_ROOT_PASSWORD: root + MYSQL_DATABASE: dex + ports: + - 3306 + options: --health-cmd "mysql -proot -e \"show databases;\"" --health-interval 10s --health-timeout 5s --health-retries 5 + etcd: image: gcr.io/etcd-development/etcd:v3.5.0 ports: @@ -77,6 +86,11 @@ jobs: DEX_MYSQL_PASSWORD: root DEX_MYSQL_HOST: 127.0.0.1 DEX_MYSQL_PORT: ${{ job.services.mysql.ports[3306] }} + DEX_MYSQL_ENT_DATABASE: dex + DEX_MYSQL_ENT_USER: root + DEX_MYSQL_ENT_PASSWORD: root + DEX_MYSQL_ENT_HOST: 127.0.0.1 + DEX_MYSQL_ENT_PORT: ${{ job.services.mysql-ent.ports[3306] }} DEX_POSTGRES_DATABASE: postgres DEX_POSTGRES_USER: postgres DEX_POSTGRES_PASSWORD: postgres diff --git a/cmd/dex/config.go b/cmd/dex/config.go index 309fc52c..37167bb0 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -183,6 +183,7 @@ var ( _ StorageConfig = (*sql.MySQL)(nil) _ StorageConfig = (*ent.SQLite3)(nil) _ StorageConfig = (*ent.Postgres)(nil) + _ StorageConfig = (*ent.MySQL)(nil) ) func getORMBasedSQLStorage(normal, entBased StorageConfig) func() StorageConfig { @@ -200,9 +201,9 @@ var storages = map[string]func() StorageConfig{ "etcd": func() StorageConfig { return new(etcd.Etcd) }, "kubernetes": func() StorageConfig { return new(kubernetes.Config) }, "memory": func() StorageConfig { return new(memory.Config) }, - "mysql": func() StorageConfig { return new(sql.MySQL) }, "sqlite3": getORMBasedSQLStorage(&sql.SQLite3{}, &ent.SQLite3{}), "postgres": getORMBasedSQLStorage(&sql.Postgres{}, &ent.Postgres{}), + "mysql": getORMBasedSQLStorage(&sql.MySQL{}, &ent.MySQL{}), } // isExpandEnvEnabled returns if os.ExpandEnv should be used for each storage and connector config. diff --git a/storage/ent/mysql.go b/storage/ent/mysql.go new file mode 100644 index 00000000..7caa91ff --- /dev/null +++ b/storage/ent/mysql.go @@ -0,0 +1,162 @@ +package ent + +import ( + "context" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "database/sql" + "fmt" + "io/ioutil" + "net" + "strconv" + "time" + + entSQL "entgo.io/ent/dialect/sql" + "github.com/go-sql-driver/mysql" + + // Register postgres driver. + _ "github.com/lib/pq" + + "github.com/dexidp/dex/pkg/log" + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/ent/client" + "github.com/dexidp/dex/storage/ent/db" +) + +// nolint +const ( + // MySQL SSL modes + mysqlSSLTrue = "true" + mysqlSSLFalse = "false" + mysqlSSLSkipVerify = "skip-verify" + mysqlSSLCustom = "custom" +) + +// MySQL options for creating an SQL db. +type MySQL struct { + NetworkDB + + SSL SSL `json:"ssl"` + + params map[string]string +} + +// Open always returns a new in sqlite3 storage. +func (m *MySQL) Open(logger log.Logger) (storage.Storage, error) { + logger.Debug("experimental ent-based storage driver is enabled") + drv, err := m.driver() + if err != nil { + return nil, err + } + + databaseClient := client.NewDatabase( + client.WithClient(db.NewClient(db.Driver(drv))), + client.WithHasher(sha256.New), + // Set tx isolation leve for each transaction as dex does for postgres + client.WithTxIsolationLevel(sql.LevelSerializable), + ) + + if err := databaseClient.Schema().Create(context.TODO()); err != nil { + return nil, err + } + + return databaseClient, nil +} + +func (m *MySQL) driver() (*entSQL.Driver, error) { + var tlsConfig string + + switch { + case m.SSL.CAFile != "" || m.SSL.CertFile != "" || m.SSL.KeyFile != "": + if err := m.makeTLSConfig(); err != nil { + return nil, fmt.Errorf("failed to make TLS config: %v", err) + } + tlsConfig = mysqlSSLCustom + case m.SSL.Mode == "": + tlsConfig = mysqlSSLTrue + default: + tlsConfig = m.SSL.Mode + } + + drv, err := entSQL.Open("mysql", m.dsn(tlsConfig)) + if err != nil { + return nil, err + } + + if m.MaxIdleConns == 0 { + /* Override default behaviour to fix https://github.com/dexidp/dex/issues/1608 */ + drv.DB().SetMaxIdleConns(0) + } else { + drv.DB().SetMaxIdleConns(m.MaxIdleConns) + } + + return drv, nil +} + +func (m *MySQL) dsn(tlsConfig string) string { + cfg := mysql.Config{ + User: m.User, + Passwd: m.Password, + DBName: m.Database, + AllowNativePasswords: true, + + Timeout: time.Second * time.Duration(m.ConnectionTimeout), + + TLSConfig: tlsConfig, + + ParseTime: true, + Params: make(map[string]string), + } + + if m.Host != "" { + if m.Host[0] != '/' { + cfg.Net = "tcp" + cfg.Addr = m.Host + + if m.Port != 0 { + cfg.Addr = net.JoinHostPort(m.Host, strconv.Itoa(int(m.Port))) + } + } else { + cfg.Net = "unix" + cfg.Addr = m.Host + } + } + + for k, v := range m.params { + cfg.Params[k] = v + } + + return cfg.FormatDSN() +} + +func (m *MySQL) makeTLSConfig() error { + cfg := &tls.Config{} + + if m.SSL.CAFile != "" { + rootCertPool := x509.NewCertPool() + + pem, err := ioutil.ReadFile(m.SSL.CAFile) + if err != nil { + return err + } + + if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { + return fmt.Errorf("failed to append PEM") + } + cfg.RootCAs = rootCertPool + } + + if m.SSL.CertFile != "" && m.SSL.KeyFile != "" { + clientCert := make([]tls.Certificate, 0, 1) + certs, err := tls.LoadX509KeyPair(m.SSL.CertFile, m.SSL.KeyFile) + if err != nil { + return err + } + clientCert = append(clientCert, certs) + cfg.Certificates = clientCert + } + + mysql.RegisterTLSConfig(mysqlSSLCustom, cfg) + return nil +} diff --git a/storage/ent/mysql_test.go b/storage/ent/mysql_test.go new file mode 100644 index 00000000..fdb2fda1 --- /dev/null +++ b/storage/ent/mysql_test.go @@ -0,0 +1,183 @@ +package ent + +import ( + "os" + "strconv" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/conformance" +) + +const ( + MySQLEntHostEnv = "DEX_MYSQL_ENT_HOST" + MySQLEntPortEnv = "DEX_MYSQL_ENT_PORT" + MySQLEntDatabaseEnv = "DEX_MYSQL_ENT_DATABASE" + MySQLEntUserEnv = "DEX_MYSQL_ENT_USER" + MySQLEntPasswordEnv = "DEX_MYSQL_ENT_PASSWORD" +) + +func mysqlTestConfig(host string, port uint64) *MySQL { + return &MySQL{ + NetworkDB: NetworkDB{ + Database: getenv(MySQLEntDatabaseEnv, "mysql"), + User: getenv(MySQLEntUserEnv, "mysql"), + Password: getenv(MySQLEntPasswordEnv, "mysql"), + Host: host, + Port: uint16(port), + }, + SSL: SSL{ + Mode: mysqlSSLSkipVerify, + }, + } +} + +func newMySQLStorage(host string, port uint64) storage.Storage { + logger := &logrus.Logger{ + Out: os.Stderr, + Formatter: &logrus.TextFormatter{DisableColors: true}, + Level: logrus.DebugLevel, + } + + cfg := mysqlTestConfig(host, port) + s, err := cfg.Open(logger) + if err != nil { + panic(err) + } + return s +} + +func TestMySQL(t *testing.T) { + host := os.Getenv(MySQLEntHostEnv) + if host == "" { + t.Skipf("test environment variable %s not set, skipping", MySQLEntHostEnv) + } + + port := uint64(3306) + if rawPort := os.Getenv(MySQLEntPortEnv); rawPort != "" { + var err error + + port, err = strconv.ParseUint(rawPort, 10, 32) + require.NoError(t, err, "invalid mysql port %q: %s", rawPort, err) + } + + newStorage := func() storage.Storage { + return newMySQLStorage(host, port) + } + conformance.RunTests(t, newStorage) + conformance.RunTransactionTests(t, newStorage) +} + +func TestMySQLDSN(t *testing.T) { + tests := []struct { + name string + cfg *MySQL + desiredDSN string + }{ + { + name: "Host port", + cfg: &MySQL{ + NetworkDB: NetworkDB{ + Host: "localhost", + Port: uint16(3306), + }, + }, + desiredDSN: "tcp(localhost:3306)/?checkConnLiveness=false&parseTime=true&tls=false&maxAllowedPacket=0", + }, + { + name: "Host with port", + cfg: &MySQL{ + NetworkDB: NetworkDB{ + Host: "localhost:3306", + }, + }, + desiredDSN: "tcp(localhost:3306)/?checkConnLiveness=false&parseTime=true&tls=false&maxAllowedPacket=0", + }, + { + name: "Host ipv6 with port", + cfg: &MySQL{ + NetworkDB: NetworkDB{ + Host: "[a:b:c:d]:3306", + }, + }, + desiredDSN: "tcp([a:b:c:d]:3306)/?checkConnLiveness=false&parseTime=true&tls=false&maxAllowedPacket=0", + }, + { + name: "Credentials and timeout", + cfg: &MySQL{ + NetworkDB: NetworkDB{ + Database: "test", + User: "test", + Password: "test", + ConnectionTimeout: 5, + }, + }, + desiredDSN: "test:test@/test?checkConnLiveness=false&parseTime=true&timeout=5s&tls=false&maxAllowedPacket=0", + }, + { + name: "SSL", + cfg: &MySQL{ + SSL: SSL{ + CAFile: "/ca.crt", + KeyFile: "/cert.crt", + CertFile: "/cert.key", + }, + }, + desiredDSN: "/?checkConnLiveness=false&parseTime=true&tls=false&maxAllowedPacket=0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.desiredDSN, tt.cfg.dsn(mysqlSSLFalse)) + }) + } +} + +func TestMySQLDriver(t *testing.T) { + host := os.Getenv(MySQLEntHostEnv) + if host == "" { + t.Skipf("test environment variable %s not set, skipping", MySQLEntHostEnv) + } + + port := uint64(3306) + if rawPort := os.Getenv(MySQLEntPortEnv); rawPort != "" { + var err error + + port, err = strconv.ParseUint(rawPort, 10, 32) + require.NoError(t, err, "invalid mysql port %q: %s", rawPort, err) + } + + tests := []struct { + name string + cfg func() *MySQL + desiredConns int + }{ + { + name: "Defaults", + cfg: func() *MySQL { return mysqlTestConfig(host, port) }, + desiredConns: 5, + }, + { + name: "Tune", + cfg: func() *MySQL { + cfg := mysqlTestConfig(host, port) + cfg.MaxOpenConns = 101 + return cfg + }, + desiredConns: 101, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + drv, err := tt.cfg().driver() + require.NoError(t, err) + + require.Equal(t, tt.desiredConns, drv.DB().Stats().MaxOpenConnections) + }) + } +} diff --git a/storage/ent/postgres_test.go b/storage/ent/postgres_test.go index 8021e3a1..c8e3a54d 100644 --- a/storage/ent/postgres_test.go +++ b/storage/ent/postgres_test.go @@ -20,13 +20,6 @@ const ( PostgresEntPasswordEnv = "DEX_POSTGRES_ENT_PASSWORD" ) -func getenv(key, defaultVal string) string { - if val := os.Getenv(key); val != "" { - return val - } - return defaultVal -} - func postgresTestConfig(host string, port uint64) *Postgres { return &Postgres{ NetworkDB: NetworkDB{ diff --git a/storage/ent/sqlite.go b/storage/ent/sqlite.go index e6c43cd9..22866b6f 100644 --- a/storage/ent/sqlite.go +++ b/storage/ent/sqlite.go @@ -33,12 +33,10 @@ func (s *SQLite3) Open(logger log.Logger) (storage.Storage, error) { return nil, err } + // always allow only one connection to sqlite3, any other thread/go-routine + // attempting concurrent access will have to wait pool := drv.DB() - if s.File == ":memory:" { - // sqlite3 uses file locks to coordinate concurrent access. In memory - // doesn't support this, so limit the number of connections to 1. - pool.SetMaxOpenConns(1) - } + pool.SetMaxOpenConns(1) databaseClient := client.NewDatabase( client.WithClient(db.NewClient(db.Driver(drv))), diff --git a/storage/ent/utils.go b/storage/ent/utils.go new file mode 100644 index 00000000..6f51e065 --- /dev/null +++ b/storage/ent/utils.go @@ -0,0 +1,10 @@ +package ent + +import "os" + +func getenv(key, defaultVal string) string { + if val := os.Getenv(key); val != "" { + return val + } + return defaultVal +} From fb38e1235d51118c5113cefc78bd7370d906cf26 Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Mon, 13 Sep 2021 17:48:02 +0400 Subject: [PATCH 2/5] Add dialects Signed-off-by: m.nabokikh --- .github/workflows/ci.yaml | 7 + storage/ent/db/authcode.go | 2 - storage/ent/db/authcode_create.go | 88 +++++++---- storage/ent/db/authcode_delete.go | 7 +- storage/ent/db/authcode_query.go | 41 +++-- storage/ent/db/authcode_update.go | 18 ++- storage/ent/db/authrequest.go | 3 - storage/ent/db/authrequest_create.go | 80 +++++++--- storage/ent/db/authrequest_delete.go | 7 +- storage/ent/db/authrequest_query.go | 41 +++-- storage/ent/db/authrequest_update.go | 18 ++- storage/ent/db/connector_create.go | 62 ++++++-- storage/ent/db/connector_delete.go | 7 +- storage/ent/db/connector_query.go | 41 +++-- storage/ent/db/connector_update.go | 18 ++- storage/ent/db/devicerequest.go | 1 - storage/ent/db/devicerequest_create.go | 72 ++++++--- storage/ent/db/devicerequest_delete.go | 7 +- storage/ent/db/devicerequest_query.go | 41 +++-- storage/ent/db/devicerequest_update.go | 18 ++- storage/ent/db/devicetoken_create.go | 68 +++++++-- storage/ent/db/devicetoken_delete.go | 7 +- storage/ent/db/devicetoken_query.go | 41 +++-- storage/ent/db/devicetoken_update.go | 18 ++- storage/ent/db/ent.go | 22 +-- storage/ent/db/keys.go | 3 - storage/ent/db/keys_create.go | 58 +++++-- storage/ent/db/keys_delete.go | 7 +- storage/ent/db/keys_query.go | 41 +++-- storage/ent/db/keys_update.go | 18 ++- storage/ent/db/migrate/schema.go | 194 +++++++++++------------- storage/ent/db/mutation.go | 90 ++++++++--- storage/ent/db/oauth2client.go | 2 - storage/ent/db/oauth2client_create.go | 64 ++++++-- storage/ent/db/oauth2client_delete.go | 7 +- storage/ent/db/oauth2client_query.go | 41 +++-- storage/ent/db/oauth2client_update.go | 18 ++- storage/ent/db/offlinesession_create.go | 60 ++++++-- storage/ent/db/offlinesession_delete.go | 7 +- storage/ent/db/offlinesession_query.go | 41 +++-- storage/ent/db/offlinesession_update.go | 18 ++- storage/ent/db/password_create.go | 68 +++++++-- storage/ent/db/password_delete.go | 7 +- storage/ent/db/password_query.go | 41 +++-- storage/ent/db/password_update.go | 18 ++- storage/ent/db/refreshtoken.go | 2 - storage/ent/db/refreshtoken_create.go | 86 +++++++---- storage/ent/db/refreshtoken_delete.go | 7 +- storage/ent/db/refreshtoken_query.go | 41 +++-- storage/ent/db/refreshtoken_update.go | 18 ++- storage/ent/db/runtime.go | 32 +++- storage/ent/db/runtime/runtime.go | 4 +- storage/ent/mysql.go | 2 +- storage/ent/mysql_test.go | 12 ++ storage/ent/schema/authcode.go | 3 +- storage/ent/schema/authrequest.go | 3 +- storage/ent/schema/client.go | 1 + storage/ent/schema/connector.go | 1 + storage/ent/schema/devicerequest.go | 3 +- storage/ent/schema/devicetoken.go | 6 +- storage/ent/schema/dialects.go | 21 +++ storage/ent/schema/keys.go | 3 +- storage/ent/schema/refreshtoken.go | 2 + storage/ent/schema/types.go | 9 -- 64 files changed, 1198 insertions(+), 596 deletions(-) create mode 100644 storage/ent/schema/dialects.go delete mode 100644 storage/ent/schema/types.go diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 55deda32..4fa7081c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -86,29 +86,36 @@ jobs: DEX_MYSQL_PASSWORD: root DEX_MYSQL_HOST: 127.0.0.1 DEX_MYSQL_PORT: ${{ job.services.mysql.ports[3306] }} + DEX_MYSQL_ENT_DATABASE: dex DEX_MYSQL_ENT_USER: root DEX_MYSQL_ENT_PASSWORD: root DEX_MYSQL_ENT_HOST: 127.0.0.1 DEX_MYSQL_ENT_PORT: ${{ job.services.mysql-ent.ports[3306] }} + DEX_POSTGRES_DATABASE: postgres DEX_POSTGRES_USER: postgres DEX_POSTGRES_PASSWORD: postgres DEX_POSTGRES_HOST: localhost DEX_POSTGRES_PORT: ${{ job.services.postgres.ports[5432] }} + DEX_POSTGRES_ENT_DATABASE: postgres DEX_POSTGRES_ENT_USER: postgres DEX_POSTGRES_ENT_PASSWORD: postgres DEX_POSTGRES_ENT_HOST: localhost DEX_POSTGRES_ENT_PORT: ${{ job.services.postgres-ent.ports[5432] }} + DEX_ETCD_ENDPOINTS: http://localhost:${{ job.services.etcd.ports[2379] }} + DEX_LDAP_HOST: localhost DEX_LDAP_PORT: 389 DEX_LDAP_TLS_PORT: 636 + DEX_KEYSTONE_URL: http://localhost:${{ job.services.keystone.ports[5000] }} DEX_KEYSTONE_ADMIN_URL: http://localhost:${{ job.services.keystone.ports[35357] }} DEX_KEYSTONE_ADMIN_USER: demo DEX_KEYSTONE_ADMIN_PASS: DEMO_PASS + DEX_KUBERNETES_CONFIG_PATH: ~/.kube/config - name: Lint diff --git a/storage/ent/db/authcode.go b/storage/ent/db/authcode.go index 29b5e4f5..9af0ee3b 100644 --- a/storage/ent/db/authcode.go +++ b/storage/ent/db/authcode.go @@ -90,7 +90,6 @@ func (ac *AuthCode) assignValues(columns []string, values []interface{}) error { ac.ClientID = value.String } case authcode.FieldScopes: - if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field scopes", values[i]) } else if value != nil && len(*value) > 0 { @@ -135,7 +134,6 @@ func (ac *AuthCode) assignValues(columns []string, values []interface{}) error { ac.ClaimsEmailVerified = value.Bool } case authcode.FieldClaimsGroups: - if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field claims_groups", values[i]) } else if value != nil && len(*value) > 0 { diff --git a/storage/ent/db/authcode_create.go b/storage/ent/db/authcode_create.go index a15e682d..eaf2f8ac 100644 --- a/storage/ent/db/authcode_create.go +++ b/storage/ent/db/authcode_create.go @@ -167,11 +167,17 @@ func (acc *AuthCodeCreate) Save(ctx context.Context) (*AuthCode, error) { return nil, err } acc.mutation = mutation - node, err = acc.sqlSave(ctx) + if node, err = acc.sqlSave(ctx); err != nil { + return nil, err + } + mutation.id = &node.ID mutation.done = true return node, err }) for i := len(acc.hooks) - 1; i >= 0; i-- { + if acc.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = acc.hooks[i](mut) } if _, err := mut.Mutate(ctx, acc.mutation); err != nil { @@ -190,6 +196,19 @@ func (acc *AuthCodeCreate) SaveX(ctx context.Context) *AuthCode { return v } +// Exec executes the query. +func (acc *AuthCodeCreate) Exec(ctx context.Context) error { + _, err := acc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (acc *AuthCodeCreate) ExecX(ctx context.Context) { + if err := acc.Exec(ctx); err != nil { + panic(err) + } +} + // defaults sets the default values of the builder before save. func (acc *AuthCodeCreate) defaults() { if _, ok := acc.mutation.ClaimsPreferredUsername(); !ok { @@ -209,79 +228,79 @@ func (acc *AuthCodeCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (acc *AuthCodeCreate) check() error { if _, ok := acc.mutation.ClientID(); !ok { - return &ValidationError{Name: "client_id", err: errors.New("db: missing required field \"client_id\"")} + return &ValidationError{Name: "client_id", err: errors.New(`db: missing required field "client_id"`)} } if v, ok := acc.mutation.ClientID(); ok { if err := authcode.ClientIDValidator(v); err != nil { - return &ValidationError{Name: "client_id", err: fmt.Errorf("db: validator failed for field \"client_id\": %w", err)} + return &ValidationError{Name: "client_id", err: fmt.Errorf(`db: validator failed for field "client_id": %w`, err)} } } if _, ok := acc.mutation.Nonce(); !ok { - return &ValidationError{Name: "nonce", err: errors.New("db: missing required field \"nonce\"")} + return &ValidationError{Name: "nonce", err: errors.New(`db: missing required field "nonce"`)} } if v, ok := acc.mutation.Nonce(); ok { if err := authcode.NonceValidator(v); err != nil { - return &ValidationError{Name: "nonce", err: fmt.Errorf("db: validator failed for field \"nonce\": %w", err)} + return &ValidationError{Name: "nonce", err: fmt.Errorf(`db: validator failed for field "nonce": %w`, err)} } } if _, ok := acc.mutation.RedirectURI(); !ok { - return &ValidationError{Name: "redirect_uri", err: errors.New("db: missing required field \"redirect_uri\"")} + return &ValidationError{Name: "redirect_uri", err: errors.New(`db: missing required field "redirect_uri"`)} } if v, ok := acc.mutation.RedirectURI(); ok { if err := authcode.RedirectURIValidator(v); err != nil { - return &ValidationError{Name: "redirect_uri", err: fmt.Errorf("db: validator failed for field \"redirect_uri\": %w", err)} + return &ValidationError{Name: "redirect_uri", err: fmt.Errorf(`db: validator failed for field "redirect_uri": %w`, err)} } } if _, ok := acc.mutation.ClaimsUserID(); !ok { - return &ValidationError{Name: "claims_user_id", err: errors.New("db: missing required field \"claims_user_id\"")} + return &ValidationError{Name: "claims_user_id", err: errors.New(`db: missing required field "claims_user_id"`)} } if v, ok := acc.mutation.ClaimsUserID(); ok { if err := authcode.ClaimsUserIDValidator(v); err != nil { - return &ValidationError{Name: "claims_user_id", err: fmt.Errorf("db: validator failed for field \"claims_user_id\": %w", err)} + return &ValidationError{Name: "claims_user_id", err: fmt.Errorf(`db: validator failed for field "claims_user_id": %w`, err)} } } if _, ok := acc.mutation.ClaimsUsername(); !ok { - return &ValidationError{Name: "claims_username", err: errors.New("db: missing required field \"claims_username\"")} + return &ValidationError{Name: "claims_username", err: errors.New(`db: missing required field "claims_username"`)} } if v, ok := acc.mutation.ClaimsUsername(); ok { if err := authcode.ClaimsUsernameValidator(v); err != nil { - return &ValidationError{Name: "claims_username", err: fmt.Errorf("db: validator failed for field \"claims_username\": %w", err)} + return &ValidationError{Name: "claims_username", err: fmt.Errorf(`db: validator failed for field "claims_username": %w`, err)} } } if _, ok := acc.mutation.ClaimsEmail(); !ok { - return &ValidationError{Name: "claims_email", err: errors.New("db: missing required field \"claims_email\"")} + return &ValidationError{Name: "claims_email", err: errors.New(`db: missing required field "claims_email"`)} } if v, ok := acc.mutation.ClaimsEmail(); ok { if err := authcode.ClaimsEmailValidator(v); err != nil { - return &ValidationError{Name: "claims_email", err: fmt.Errorf("db: validator failed for field \"claims_email\": %w", err)} + return &ValidationError{Name: "claims_email", err: fmt.Errorf(`db: validator failed for field "claims_email": %w`, err)} } } if _, ok := acc.mutation.ClaimsEmailVerified(); !ok { - return &ValidationError{Name: "claims_email_verified", err: errors.New("db: missing required field \"claims_email_verified\"")} + return &ValidationError{Name: "claims_email_verified", err: errors.New(`db: missing required field "claims_email_verified"`)} } if _, ok := acc.mutation.ClaimsPreferredUsername(); !ok { - return &ValidationError{Name: "claims_preferred_username", err: errors.New("db: missing required field \"claims_preferred_username\"")} + return &ValidationError{Name: "claims_preferred_username", err: errors.New(`db: missing required field "claims_preferred_username"`)} } if _, ok := acc.mutation.ConnectorID(); !ok { - return &ValidationError{Name: "connector_id", err: errors.New("db: missing required field \"connector_id\"")} + return &ValidationError{Name: "connector_id", err: errors.New(`db: missing required field "connector_id"`)} } if v, ok := acc.mutation.ConnectorID(); ok { if err := authcode.ConnectorIDValidator(v); err != nil { - return &ValidationError{Name: "connector_id", err: fmt.Errorf("db: validator failed for field \"connector_id\": %w", err)} + return &ValidationError{Name: "connector_id", err: fmt.Errorf(`db: validator failed for field "connector_id": %w`, err)} } } if _, ok := acc.mutation.Expiry(); !ok { - return &ValidationError{Name: "expiry", err: errors.New("db: missing required field \"expiry\"")} + return &ValidationError{Name: "expiry", err: errors.New(`db: missing required field "expiry"`)} } if _, ok := acc.mutation.CodeChallenge(); !ok { - return &ValidationError{Name: "code_challenge", err: errors.New("db: missing required field \"code_challenge\"")} + return &ValidationError{Name: "code_challenge", err: errors.New(`db: missing required field "code_challenge"`)} } if _, ok := acc.mutation.CodeChallengeMethod(); !ok { - return &ValidationError{Name: "code_challenge_method", err: errors.New("db: missing required field \"code_challenge_method\"")} + return &ValidationError{Name: "code_challenge_method", err: errors.New(`db: missing required field "code_challenge_method"`)} } if v, ok := acc.mutation.ID(); ok { if err := authcode.IDValidator(v); err != nil { - return &ValidationError{Name: "id", err: fmt.Errorf("db: validator failed for field \"id\": %w", err)} + return &ValidationError{Name: "id", err: fmt.Errorf(`db: validator failed for field "id": %w`, err)} } } return nil @@ -290,8 +309,8 @@ func (acc *AuthCodeCreate) check() error { func (acc *AuthCodeCreate) sqlSave(ctx context.Context) (*AuthCode, error) { _node, _spec := acc.createSpec() if err := sqlgraph.CreateNode(ctx, acc.driver, _spec); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } @@ -465,17 +484,19 @@ func (accb *AuthCodeCreateBulk) Save(ctx context.Context) ([]*AuthCode, error) { if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, accb.builders[i+1].mutation) } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} // Invoke the actual operation on the latest mutation in the chain. - if err = sqlgraph.BatchCreate(ctx, accb.driver, &sqlgraph.BatchCreateSpec{Nodes: specs}); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if err = sqlgraph.BatchCreate(ctx, accb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } } } - mutation.done = true if err != nil { return nil, err } + mutation.id = &nodes[i].ID + mutation.done = true return nodes[i], nil }) for i := len(builder.hooks) - 1; i >= 0; i-- { @@ -500,3 +521,16 @@ func (accb *AuthCodeCreateBulk) SaveX(ctx context.Context) []*AuthCode { } return v } + +// Exec executes the query. +func (accb *AuthCodeCreateBulk) Exec(ctx context.Context) error { + _, err := accb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (accb *AuthCodeCreateBulk) ExecX(ctx context.Context) { + if err := accb.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/storage/ent/db/authcode_delete.go b/storage/ent/db/authcode_delete.go index c76007b3..9b657236 100644 --- a/storage/ent/db/authcode_delete.go +++ b/storage/ent/db/authcode_delete.go @@ -20,9 +20,9 @@ type AuthCodeDelete struct { mutation *AuthCodeMutation } -// Where adds a new predicate to the AuthCodeDelete builder. +// Where appends a list predicates to the AuthCodeDelete builder. func (acd *AuthCodeDelete) Where(ps ...predicate.AuthCode) *AuthCodeDelete { - acd.mutation.predicates = append(acd.mutation.predicates, ps...) + acd.mutation.Where(ps...) return acd } @@ -46,6 +46,9 @@ func (acd *AuthCodeDelete) Exec(ctx context.Context) (int, error) { return affected, err }) for i := len(acd.hooks) - 1; i >= 0; i-- { + if acd.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = acd.hooks[i](mut) } if _, err := mut.Mutate(ctx, acd.mutation); err != nil { diff --git a/storage/ent/db/authcode_query.go b/storage/ent/db/authcode_query.go index 96b6a485..89116dd9 100644 --- a/storage/ent/db/authcode_query.go +++ b/storage/ent/db/authcode_query.go @@ -287,8 +287,8 @@ func (acq *AuthCodeQuery) GroupBy(field string, fields ...string) *AuthCodeGroup // Select(authcode.FieldClientID). // Scan(ctx, &v) // -func (acq *AuthCodeQuery) Select(field string, fields ...string) *AuthCodeSelect { - acq.fields = append([]string{field}, fields...) +func (acq *AuthCodeQuery) Select(fields ...string) *AuthCodeSelect { + acq.fields = append(acq.fields, fields...) return &AuthCodeSelect{AuthCodeQuery: acq} } @@ -398,10 +398,14 @@ func (acq *AuthCodeQuery) querySpec() *sqlgraph.QuerySpec { func (acq *AuthCodeQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(acq.driver.Dialect()) t1 := builder.Table(authcode.Table) - selector := builder.Select(t1.Columns(authcode.Columns...)...).From(t1) + columns := acq.fields + if len(columns) == 0 { + columns = authcode.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) if acq.sql != nil { selector = acq.sql - selector.Select(selector.Columns(authcode.Columns...)...) + selector.Select(selector.Columns(columns...)...) } for _, p := range acq.predicates { p(selector) @@ -669,13 +673,24 @@ func (acgb *AuthCodeGroupBy) sqlScan(ctx context.Context, v interface{}) error { } func (acgb *AuthCodeGroupBy) sqlQuery() *sql.Selector { - selector := acgb.sql - columns := make([]string, 0, len(acgb.fields)+len(acgb.fns)) - columns = append(columns, acgb.fields...) + selector := acgb.sql.Select() + aggregation := make([]string, 0, len(acgb.fns)) for _, fn := range acgb.fns { - columns = append(columns, fn(selector)) + aggregation = append(aggregation, fn(selector)) } - return selector.Select(columns...).GroupBy(acgb.fields...) + // If no columns were selected in a custom aggregation function, the default + // selection is the fields used for "group-by", and the aggregation functions. + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(acgb.fields)+len(acgb.fns)) + for _, f := range acgb.fields { + columns = append(columns, selector.C(f)) + } + for _, c := range aggregation { + columns = append(columns, c) + } + selector.Select(columns...) + } + return selector.GroupBy(selector.Columns(acgb.fields...)...) } // AuthCodeSelect is the builder for selecting fields of AuthCode entities. @@ -891,16 +906,10 @@ func (acs *AuthCodeSelect) BoolX(ctx context.Context) bool { func (acs *AuthCodeSelect) sqlScan(ctx context.Context, v interface{}) error { rows := &sql.Rows{} - query, args := acs.sqlQuery().Query() + query, args := acs.sql.Query() if err := acs.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } - -func (acs *AuthCodeSelect) sqlQuery() sql.Querier { - selector := acs.sql - selector.Select(selector.Columns(acs.fields...)...) - return selector -} diff --git a/storage/ent/db/authcode_update.go b/storage/ent/db/authcode_update.go index 08374bd3..f86d04c6 100644 --- a/storage/ent/db/authcode_update.go +++ b/storage/ent/db/authcode_update.go @@ -21,9 +21,9 @@ type AuthCodeUpdate struct { mutation *AuthCodeMutation } -// Where adds a new predicate for the AuthCodeUpdate builder. +// Where appends a list predicates to the AuthCodeUpdate builder. func (acu *AuthCodeUpdate) Where(ps ...predicate.AuthCode) *AuthCodeUpdate { - acu.mutation.predicates = append(acu.mutation.predicates, ps...) + acu.mutation.Where(ps...) return acu } @@ -190,6 +190,9 @@ func (acu *AuthCodeUpdate) Save(ctx context.Context) (int, error) { return affected, err }) for i := len(acu.hooks) - 1; i >= 0; i-- { + if acu.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = acu.hooks[i](mut) } if _, err := mut.Mutate(ctx, acu.mutation); err != nil { @@ -405,8 +408,8 @@ func (acu *AuthCodeUpdate) sqlSave(ctx context.Context) (n int, err error) { if n, err = sqlgraph.UpdateNodes(ctx, acu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{authcode.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return 0, err } @@ -591,6 +594,9 @@ func (acuo *AuthCodeUpdateOne) Save(ctx context.Context) (*AuthCode, error) { return node, err }) for i := len(acuo.hooks) - 1; i >= 0; i-- { + if acuo.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = acuo.hooks[i](mut) } if _, err := mut.Mutate(ctx, acuo.mutation); err != nil { @@ -826,8 +832,8 @@ func (acuo *AuthCodeUpdateOne) sqlSave(ctx context.Context) (_node *AuthCode, er if err = sqlgraph.UpdateNode(ctx, acuo.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{authcode.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } diff --git a/storage/ent/db/authrequest.go b/storage/ent/db/authrequest.go index ed64d9f6..ee3d0dd4 100644 --- a/storage/ent/db/authrequest.go +++ b/storage/ent/db/authrequest.go @@ -98,7 +98,6 @@ func (ar *AuthRequest) assignValues(columns []string, values []interface{}) erro ar.ClientID = value.String } case authrequest.FieldScopes: - if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field scopes", values[i]) } else if value != nil && len(*value) > 0 { @@ -107,7 +106,6 @@ func (ar *AuthRequest) assignValues(columns []string, values []interface{}) erro } } case authrequest.FieldResponseTypes: - if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field response_types", values[i]) } else if value != nil && len(*value) > 0 { @@ -170,7 +168,6 @@ func (ar *AuthRequest) assignValues(columns []string, values []interface{}) erro ar.ClaimsEmailVerified = value.Bool } case authrequest.FieldClaimsGroups: - if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field claims_groups", values[i]) } else if value != nil && len(*value) > 0 { diff --git a/storage/ent/db/authrequest_create.go b/storage/ent/db/authrequest_create.go index e7a2c8ce..3e96b284 100644 --- a/storage/ent/db/authrequest_create.go +++ b/storage/ent/db/authrequest_create.go @@ -191,11 +191,17 @@ func (arc *AuthRequestCreate) Save(ctx context.Context) (*AuthRequest, error) { return nil, err } arc.mutation = mutation - node, err = arc.sqlSave(ctx) + if node, err = arc.sqlSave(ctx); err != nil { + return nil, err + } + mutation.id = &node.ID mutation.done = true return node, err }) for i := len(arc.hooks) - 1; i >= 0; i-- { + if arc.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = arc.hooks[i](mut) } if _, err := mut.Mutate(ctx, arc.mutation); err != nil { @@ -214,6 +220,19 @@ func (arc *AuthRequestCreate) SaveX(ctx context.Context) *AuthRequest { return v } +// Exec executes the query. +func (arc *AuthRequestCreate) Exec(ctx context.Context) error { + _, err := arc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (arc *AuthRequestCreate) ExecX(ctx context.Context) { + if err := arc.Exec(ctx); err != nil { + panic(err) + } +} + // defaults sets the default values of the builder before save. func (arc *AuthRequestCreate) defaults() { if _, ok := arc.mutation.ClaimsPreferredUsername(); !ok { @@ -233,53 +252,53 @@ func (arc *AuthRequestCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (arc *AuthRequestCreate) check() error { if _, ok := arc.mutation.ClientID(); !ok { - return &ValidationError{Name: "client_id", err: errors.New("db: missing required field \"client_id\"")} + return &ValidationError{Name: "client_id", err: errors.New(`db: missing required field "client_id"`)} } if _, ok := arc.mutation.RedirectURI(); !ok { - return &ValidationError{Name: "redirect_uri", err: errors.New("db: missing required field \"redirect_uri\"")} + return &ValidationError{Name: "redirect_uri", err: errors.New(`db: missing required field "redirect_uri"`)} } if _, ok := arc.mutation.Nonce(); !ok { - return &ValidationError{Name: "nonce", err: errors.New("db: missing required field \"nonce\"")} + return &ValidationError{Name: "nonce", err: errors.New(`db: missing required field "nonce"`)} } if _, ok := arc.mutation.State(); !ok { - return &ValidationError{Name: "state", err: errors.New("db: missing required field \"state\"")} + return &ValidationError{Name: "state", err: errors.New(`db: missing required field "state"`)} } if _, ok := arc.mutation.ForceApprovalPrompt(); !ok { - return &ValidationError{Name: "force_approval_prompt", err: errors.New("db: missing required field \"force_approval_prompt\"")} + return &ValidationError{Name: "force_approval_prompt", err: errors.New(`db: missing required field "force_approval_prompt"`)} } if _, ok := arc.mutation.LoggedIn(); !ok { - return &ValidationError{Name: "logged_in", err: errors.New("db: missing required field \"logged_in\"")} + return &ValidationError{Name: "logged_in", err: errors.New(`db: missing required field "logged_in"`)} } if _, ok := arc.mutation.ClaimsUserID(); !ok { - return &ValidationError{Name: "claims_user_id", err: errors.New("db: missing required field \"claims_user_id\"")} + return &ValidationError{Name: "claims_user_id", err: errors.New(`db: missing required field "claims_user_id"`)} } if _, ok := arc.mutation.ClaimsUsername(); !ok { - return &ValidationError{Name: "claims_username", err: errors.New("db: missing required field \"claims_username\"")} + return &ValidationError{Name: "claims_username", err: errors.New(`db: missing required field "claims_username"`)} } if _, ok := arc.mutation.ClaimsEmail(); !ok { - return &ValidationError{Name: "claims_email", err: errors.New("db: missing required field \"claims_email\"")} + return &ValidationError{Name: "claims_email", err: errors.New(`db: missing required field "claims_email"`)} } if _, ok := arc.mutation.ClaimsEmailVerified(); !ok { - return &ValidationError{Name: "claims_email_verified", err: errors.New("db: missing required field \"claims_email_verified\"")} + return &ValidationError{Name: "claims_email_verified", err: errors.New(`db: missing required field "claims_email_verified"`)} } if _, ok := arc.mutation.ClaimsPreferredUsername(); !ok { - return &ValidationError{Name: "claims_preferred_username", err: errors.New("db: missing required field \"claims_preferred_username\"")} + return &ValidationError{Name: "claims_preferred_username", err: errors.New(`db: missing required field "claims_preferred_username"`)} } if _, ok := arc.mutation.ConnectorID(); !ok { - return &ValidationError{Name: "connector_id", err: errors.New("db: missing required field \"connector_id\"")} + return &ValidationError{Name: "connector_id", err: errors.New(`db: missing required field "connector_id"`)} } if _, ok := arc.mutation.Expiry(); !ok { - return &ValidationError{Name: "expiry", err: errors.New("db: missing required field \"expiry\"")} + return &ValidationError{Name: "expiry", err: errors.New(`db: missing required field "expiry"`)} } if _, ok := arc.mutation.CodeChallenge(); !ok { - return &ValidationError{Name: "code_challenge", err: errors.New("db: missing required field \"code_challenge\"")} + return &ValidationError{Name: "code_challenge", err: errors.New(`db: missing required field "code_challenge"`)} } if _, ok := arc.mutation.CodeChallengeMethod(); !ok { - return &ValidationError{Name: "code_challenge_method", err: errors.New("db: missing required field \"code_challenge_method\"")} + return &ValidationError{Name: "code_challenge_method", err: errors.New(`db: missing required field "code_challenge_method"`)} } if v, ok := arc.mutation.ID(); ok { if err := authrequest.IDValidator(v); err != nil { - return &ValidationError{Name: "id", err: fmt.Errorf("db: validator failed for field \"id\": %w", err)} + return &ValidationError{Name: "id", err: fmt.Errorf(`db: validator failed for field "id": %w`, err)} } } return nil @@ -288,8 +307,8 @@ func (arc *AuthRequestCreate) check() error { func (arc *AuthRequestCreate) sqlSave(ctx context.Context) (*AuthRequest, error) { _node, _spec := arc.createSpec() if err := sqlgraph.CreateNode(ctx, arc.driver, _spec); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } @@ -495,17 +514,19 @@ func (arcb *AuthRequestCreateBulk) Save(ctx context.Context) ([]*AuthRequest, er if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, arcb.builders[i+1].mutation) } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} // Invoke the actual operation on the latest mutation in the chain. - if err = sqlgraph.BatchCreate(ctx, arcb.driver, &sqlgraph.BatchCreateSpec{Nodes: specs}); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if err = sqlgraph.BatchCreate(ctx, arcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } } } - mutation.done = true if err != nil { return nil, err } + mutation.id = &nodes[i].ID + mutation.done = true return nodes[i], nil }) for i := len(builder.hooks) - 1; i >= 0; i-- { @@ -530,3 +551,16 @@ func (arcb *AuthRequestCreateBulk) SaveX(ctx context.Context) []*AuthRequest { } return v } + +// Exec executes the query. +func (arcb *AuthRequestCreateBulk) Exec(ctx context.Context) error { + _, err := arcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (arcb *AuthRequestCreateBulk) ExecX(ctx context.Context) { + if err := arcb.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/storage/ent/db/authrequest_delete.go b/storage/ent/db/authrequest_delete.go index e0d0ba66..1e6ff754 100644 --- a/storage/ent/db/authrequest_delete.go +++ b/storage/ent/db/authrequest_delete.go @@ -20,9 +20,9 @@ type AuthRequestDelete struct { mutation *AuthRequestMutation } -// Where adds a new predicate to the AuthRequestDelete builder. +// Where appends a list predicates to the AuthRequestDelete builder. func (ard *AuthRequestDelete) Where(ps ...predicate.AuthRequest) *AuthRequestDelete { - ard.mutation.predicates = append(ard.mutation.predicates, ps...) + ard.mutation.Where(ps...) return ard } @@ -46,6 +46,9 @@ func (ard *AuthRequestDelete) Exec(ctx context.Context) (int, error) { return affected, err }) for i := len(ard.hooks) - 1; i >= 0; i-- { + if ard.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = ard.hooks[i](mut) } if _, err := mut.Mutate(ctx, ard.mutation); err != nil { diff --git a/storage/ent/db/authrequest_query.go b/storage/ent/db/authrequest_query.go index b55861cf..e71f8546 100644 --- a/storage/ent/db/authrequest_query.go +++ b/storage/ent/db/authrequest_query.go @@ -287,8 +287,8 @@ func (arq *AuthRequestQuery) GroupBy(field string, fields ...string) *AuthReques // Select(authrequest.FieldClientID). // Scan(ctx, &v) // -func (arq *AuthRequestQuery) Select(field string, fields ...string) *AuthRequestSelect { - arq.fields = append([]string{field}, fields...) +func (arq *AuthRequestQuery) Select(fields ...string) *AuthRequestSelect { + arq.fields = append(arq.fields, fields...) return &AuthRequestSelect{AuthRequestQuery: arq} } @@ -398,10 +398,14 @@ func (arq *AuthRequestQuery) querySpec() *sqlgraph.QuerySpec { func (arq *AuthRequestQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(arq.driver.Dialect()) t1 := builder.Table(authrequest.Table) - selector := builder.Select(t1.Columns(authrequest.Columns...)...).From(t1) + columns := arq.fields + if len(columns) == 0 { + columns = authrequest.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) if arq.sql != nil { selector = arq.sql - selector.Select(selector.Columns(authrequest.Columns...)...) + selector.Select(selector.Columns(columns...)...) } for _, p := range arq.predicates { p(selector) @@ -669,13 +673,24 @@ func (argb *AuthRequestGroupBy) sqlScan(ctx context.Context, v interface{}) erro } func (argb *AuthRequestGroupBy) sqlQuery() *sql.Selector { - selector := argb.sql - columns := make([]string, 0, len(argb.fields)+len(argb.fns)) - columns = append(columns, argb.fields...) + selector := argb.sql.Select() + aggregation := make([]string, 0, len(argb.fns)) for _, fn := range argb.fns { - columns = append(columns, fn(selector)) + aggregation = append(aggregation, fn(selector)) } - return selector.Select(columns...).GroupBy(argb.fields...) + // If no columns were selected in a custom aggregation function, the default + // selection is the fields used for "group-by", and the aggregation functions. + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(argb.fields)+len(argb.fns)) + for _, f := range argb.fields { + columns = append(columns, selector.C(f)) + } + for _, c := range aggregation { + columns = append(columns, c) + } + selector.Select(columns...) + } + return selector.GroupBy(selector.Columns(argb.fields...)...) } // AuthRequestSelect is the builder for selecting fields of AuthRequest entities. @@ -891,16 +906,10 @@ func (ars *AuthRequestSelect) BoolX(ctx context.Context) bool { func (ars *AuthRequestSelect) sqlScan(ctx context.Context, v interface{}) error { rows := &sql.Rows{} - query, args := ars.sqlQuery().Query() + query, args := ars.sql.Query() if err := ars.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } - -func (ars *AuthRequestSelect) sqlQuery() sql.Querier { - selector := ars.sql - selector.Select(selector.Columns(ars.fields...)...) - return selector -} diff --git a/storage/ent/db/authrequest_update.go b/storage/ent/db/authrequest_update.go index 2d3f8594..2306d2e6 100644 --- a/storage/ent/db/authrequest_update.go +++ b/storage/ent/db/authrequest_update.go @@ -21,9 +21,9 @@ type AuthRequestUpdate struct { mutation *AuthRequestMutation } -// Where adds a new predicate for the AuthRequestUpdate builder. +// Where appends a list predicates to the AuthRequestUpdate builder. func (aru *AuthRequestUpdate) Where(ps ...predicate.AuthRequest) *AuthRequestUpdate { - aru.mutation.predicates = append(aru.mutation.predicates, ps...) + aru.mutation.Where(ps...) return aru } @@ -214,6 +214,9 @@ func (aru *AuthRequestUpdate) Save(ctx context.Context) (int, error) { return affected, err }) for i := len(aru.hooks) - 1; i >= 0; i-- { + if aru.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = aru.hooks[i](mut) } if _, err := mut.Mutate(ctx, aru.mutation); err != nil { @@ -423,8 +426,8 @@ func (aru *AuthRequestUpdate) sqlSave(ctx context.Context) (n int, err error) { if n, err = sqlgraph.UpdateNodes(ctx, aru.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{authrequest.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return 0, err } @@ -633,6 +636,9 @@ func (aruo *AuthRequestUpdateOne) Save(ctx context.Context) (*AuthRequest, error return node, err }) for i := len(aruo.hooks) - 1; i >= 0; i-- { + if aruo.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = aruo.hooks[i](mut) } if _, err := mut.Mutate(ctx, aruo.mutation); err != nil { @@ -862,8 +868,8 @@ func (aruo *AuthRequestUpdateOne) sqlSave(ctx context.Context) (_node *AuthReque if err = sqlgraph.UpdateNode(ctx, aruo.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{authrequest.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } diff --git a/storage/ent/db/connector_create.go b/storage/ent/db/connector_create.go index eebe4d05..387231b9 100644 --- a/storage/ent/db/connector_create.go +++ b/storage/ent/db/connector_create.go @@ -75,11 +75,17 @@ func (cc *ConnectorCreate) Save(ctx context.Context) (*Connector, error) { return nil, err } cc.mutation = mutation - node, err = cc.sqlSave(ctx) + if node, err = cc.sqlSave(ctx); err != nil { + return nil, err + } + mutation.id = &node.ID mutation.done = true return node, err }) for i := len(cc.hooks) - 1; i >= 0; i-- { + if cc.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = cc.hooks[i](mut) } if _, err := mut.Mutate(ctx, cc.mutation); err != nil { @@ -98,33 +104,46 @@ func (cc *ConnectorCreate) SaveX(ctx context.Context) *Connector { return v } +// Exec executes the query. +func (cc *ConnectorCreate) Exec(ctx context.Context) error { + _, err := cc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (cc *ConnectorCreate) ExecX(ctx context.Context) { + if err := cc.Exec(ctx); err != nil { + panic(err) + } +} + // check runs all checks and user-defined validators on the builder. func (cc *ConnectorCreate) check() error { if _, ok := cc.mutation.GetType(); !ok { - return &ValidationError{Name: "type", err: errors.New("db: missing required field \"type\"")} + return &ValidationError{Name: "type", err: errors.New(`db: missing required field "type"`)} } if v, ok := cc.mutation.GetType(); ok { if err := connector.TypeValidator(v); err != nil { - return &ValidationError{Name: "type", err: fmt.Errorf("db: validator failed for field \"type\": %w", err)} + return &ValidationError{Name: "type", err: fmt.Errorf(`db: validator failed for field "type": %w`, err)} } } if _, ok := cc.mutation.Name(); !ok { - return &ValidationError{Name: "name", err: errors.New("db: missing required field \"name\"")} + return &ValidationError{Name: "name", err: errors.New(`db: missing required field "name"`)} } if v, ok := cc.mutation.Name(); ok { if err := connector.NameValidator(v); err != nil { - return &ValidationError{Name: "name", err: fmt.Errorf("db: validator failed for field \"name\": %w", err)} + return &ValidationError{Name: "name", err: fmt.Errorf(`db: validator failed for field "name": %w`, err)} } } if _, ok := cc.mutation.ResourceVersion(); !ok { - return &ValidationError{Name: "resource_version", err: errors.New("db: missing required field \"resource_version\"")} + return &ValidationError{Name: "resource_version", err: errors.New(`db: missing required field "resource_version"`)} } if _, ok := cc.mutation.Config(); !ok { - return &ValidationError{Name: "config", err: errors.New("db: missing required field \"config\"")} + return &ValidationError{Name: "config", err: errors.New(`db: missing required field "config"`)} } if v, ok := cc.mutation.ID(); ok { if err := connector.IDValidator(v); err != nil { - return &ValidationError{Name: "id", err: fmt.Errorf("db: validator failed for field \"id\": %w", err)} + return &ValidationError{Name: "id", err: fmt.Errorf(`db: validator failed for field "id": %w`, err)} } } return nil @@ -133,8 +152,8 @@ func (cc *ConnectorCreate) check() error { func (cc *ConnectorCreate) sqlSave(ctx context.Context) (*Connector, error) { _node, _spec := cc.createSpec() if err := sqlgraph.CreateNode(ctx, cc.driver, _spec); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } @@ -219,17 +238,19 @@ func (ccb *ConnectorCreateBulk) Save(ctx context.Context) ([]*Connector, error) if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, ccb.builders[i+1].mutation) } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} // Invoke the actual operation on the latest mutation in the chain. - if err = sqlgraph.BatchCreate(ctx, ccb.driver, &sqlgraph.BatchCreateSpec{Nodes: specs}); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if err = sqlgraph.BatchCreate(ctx, ccb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } } } - mutation.done = true if err != nil { return nil, err } + mutation.id = &nodes[i].ID + mutation.done = true return nodes[i], nil }) for i := len(builder.hooks) - 1; i >= 0; i-- { @@ -254,3 +275,16 @@ func (ccb *ConnectorCreateBulk) SaveX(ctx context.Context) []*Connector { } return v } + +// Exec executes the query. +func (ccb *ConnectorCreateBulk) Exec(ctx context.Context) error { + _, err := ccb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (ccb *ConnectorCreateBulk) ExecX(ctx context.Context) { + if err := ccb.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/storage/ent/db/connector_delete.go b/storage/ent/db/connector_delete.go index 1368fcc1..1cad771f 100644 --- a/storage/ent/db/connector_delete.go +++ b/storage/ent/db/connector_delete.go @@ -20,9 +20,9 @@ type ConnectorDelete struct { mutation *ConnectorMutation } -// Where adds a new predicate to the ConnectorDelete builder. +// Where appends a list predicates to the ConnectorDelete builder. func (cd *ConnectorDelete) Where(ps ...predicate.Connector) *ConnectorDelete { - cd.mutation.predicates = append(cd.mutation.predicates, ps...) + cd.mutation.Where(ps...) return cd } @@ -46,6 +46,9 @@ func (cd *ConnectorDelete) Exec(ctx context.Context) (int, error) { return affected, err }) for i := len(cd.hooks) - 1; i >= 0; i-- { + if cd.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = cd.hooks[i](mut) } if _, err := mut.Mutate(ctx, cd.mutation); err != nil { diff --git a/storage/ent/db/connector_query.go b/storage/ent/db/connector_query.go index 2b4c7872..9bcf368a 100644 --- a/storage/ent/db/connector_query.go +++ b/storage/ent/db/connector_query.go @@ -287,8 +287,8 @@ func (cq *ConnectorQuery) GroupBy(field string, fields ...string) *ConnectorGrou // Select(connector.FieldType). // Scan(ctx, &v) // -func (cq *ConnectorQuery) Select(field string, fields ...string) *ConnectorSelect { - cq.fields = append([]string{field}, fields...) +func (cq *ConnectorQuery) Select(fields ...string) *ConnectorSelect { + cq.fields = append(cq.fields, fields...) return &ConnectorSelect{ConnectorQuery: cq} } @@ -398,10 +398,14 @@ func (cq *ConnectorQuery) querySpec() *sqlgraph.QuerySpec { func (cq *ConnectorQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(cq.driver.Dialect()) t1 := builder.Table(connector.Table) - selector := builder.Select(t1.Columns(connector.Columns...)...).From(t1) + columns := cq.fields + if len(columns) == 0 { + columns = connector.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) if cq.sql != nil { selector = cq.sql - selector.Select(selector.Columns(connector.Columns...)...) + selector.Select(selector.Columns(columns...)...) } for _, p := range cq.predicates { p(selector) @@ -669,13 +673,24 @@ func (cgb *ConnectorGroupBy) sqlScan(ctx context.Context, v interface{}) error { } func (cgb *ConnectorGroupBy) sqlQuery() *sql.Selector { - selector := cgb.sql - columns := make([]string, 0, len(cgb.fields)+len(cgb.fns)) - columns = append(columns, cgb.fields...) + selector := cgb.sql.Select() + aggregation := make([]string, 0, len(cgb.fns)) for _, fn := range cgb.fns { - columns = append(columns, fn(selector)) + aggregation = append(aggregation, fn(selector)) } - return selector.Select(columns...).GroupBy(cgb.fields...) + // If no columns were selected in a custom aggregation function, the default + // selection is the fields used for "group-by", and the aggregation functions. + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(cgb.fields)+len(cgb.fns)) + for _, f := range cgb.fields { + columns = append(columns, selector.C(f)) + } + for _, c := range aggregation { + columns = append(columns, c) + } + selector.Select(columns...) + } + return selector.GroupBy(selector.Columns(cgb.fields...)...) } // ConnectorSelect is the builder for selecting fields of Connector entities. @@ -891,16 +906,10 @@ func (cs *ConnectorSelect) BoolX(ctx context.Context) bool { func (cs *ConnectorSelect) sqlScan(ctx context.Context, v interface{}) error { rows := &sql.Rows{} - query, args := cs.sqlQuery().Query() + query, args := cs.sql.Query() if err := cs.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } - -func (cs *ConnectorSelect) sqlQuery() sql.Querier { - selector := cs.sql - selector.Select(selector.Columns(cs.fields...)...) - return selector -} diff --git a/storage/ent/db/connector_update.go b/storage/ent/db/connector_update.go index 90c972e4..1bc9ffb5 100644 --- a/storage/ent/db/connector_update.go +++ b/storage/ent/db/connector_update.go @@ -20,9 +20,9 @@ type ConnectorUpdate struct { mutation *ConnectorMutation } -// Where adds a new predicate for the ConnectorUpdate builder. +// Where appends a list predicates to the ConnectorUpdate builder. func (cu *ConnectorUpdate) Where(ps ...predicate.Connector) *ConnectorUpdate { - cu.mutation.predicates = append(cu.mutation.predicates, ps...) + cu.mutation.Where(ps...) return cu } @@ -81,6 +81,9 @@ func (cu *ConnectorUpdate) Save(ctx context.Context) (int, error) { return affected, err }) for i := len(cu.hooks) - 1; i >= 0; i-- { + if cu.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = cu.hooks[i](mut) } if _, err := mut.Mutate(ctx, cu.mutation); err != nil { @@ -176,8 +179,8 @@ func (cu *ConnectorUpdate) sqlSave(ctx context.Context) (n int, err error) { if n, err = sqlgraph.UpdateNodes(ctx, cu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{connector.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return 0, err } @@ -254,6 +257,9 @@ func (cuo *ConnectorUpdateOne) Save(ctx context.Context) (*Connector, error) { return node, err }) for i := len(cuo.hooks) - 1; i >= 0; i-- { + if cuo.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = cuo.hooks[i](mut) } if _, err := mut.Mutate(ctx, cuo.mutation); err != nil { @@ -369,8 +375,8 @@ func (cuo *ConnectorUpdateOne) sqlSave(ctx context.Context) (_node *Connector, e if err = sqlgraph.UpdateNode(ctx, cuo.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{connector.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } diff --git a/storage/ent/db/devicerequest.go b/storage/ent/db/devicerequest.go index d50a7c83..24da4a46 100644 --- a/storage/ent/db/devicerequest.go +++ b/storage/ent/db/devicerequest.go @@ -90,7 +90,6 @@ func (dr *DeviceRequest) assignValues(columns []string, values []interface{}) er dr.ClientSecret = value.String } case devicerequest.FieldScopes: - if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field scopes", values[i]) } else if value != nil && len(*value) > 0 { diff --git a/storage/ent/db/devicerequest_create.go b/storage/ent/db/devicerequest_create.go index 70599fed..102d028e 100644 --- a/storage/ent/db/devicerequest_create.go +++ b/storage/ent/db/devicerequest_create.go @@ -82,11 +82,17 @@ func (drc *DeviceRequestCreate) Save(ctx context.Context) (*DeviceRequest, error return nil, err } drc.mutation = mutation - node, err = drc.sqlSave(ctx) + if node, err = drc.sqlSave(ctx); err != nil { + return nil, err + } + mutation.id = &node.ID mutation.done = true return node, err }) for i := len(drc.hooks) - 1; i >= 0; i-- { + if drc.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = drc.hooks[i](mut) } if _, err := mut.Mutate(ctx, drc.mutation); err != nil { @@ -105,42 +111,55 @@ func (drc *DeviceRequestCreate) SaveX(ctx context.Context) *DeviceRequest { return v } +// Exec executes the query. +func (drc *DeviceRequestCreate) Exec(ctx context.Context) error { + _, err := drc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (drc *DeviceRequestCreate) ExecX(ctx context.Context) { + if err := drc.Exec(ctx); err != nil { + panic(err) + } +} + // check runs all checks and user-defined validators on the builder. func (drc *DeviceRequestCreate) check() error { if _, ok := drc.mutation.UserCode(); !ok { - return &ValidationError{Name: "user_code", err: errors.New("db: missing required field \"user_code\"")} + return &ValidationError{Name: "user_code", err: errors.New(`db: missing required field "user_code"`)} } if v, ok := drc.mutation.UserCode(); ok { if err := devicerequest.UserCodeValidator(v); err != nil { - return &ValidationError{Name: "user_code", err: fmt.Errorf("db: validator failed for field \"user_code\": %w", err)} + return &ValidationError{Name: "user_code", err: fmt.Errorf(`db: validator failed for field "user_code": %w`, err)} } } if _, ok := drc.mutation.DeviceCode(); !ok { - return &ValidationError{Name: "device_code", err: errors.New("db: missing required field \"device_code\"")} + return &ValidationError{Name: "device_code", err: errors.New(`db: missing required field "device_code"`)} } if v, ok := drc.mutation.DeviceCode(); ok { if err := devicerequest.DeviceCodeValidator(v); err != nil { - return &ValidationError{Name: "device_code", err: fmt.Errorf("db: validator failed for field \"device_code\": %w", err)} + return &ValidationError{Name: "device_code", err: fmt.Errorf(`db: validator failed for field "device_code": %w`, err)} } } if _, ok := drc.mutation.ClientID(); !ok { - return &ValidationError{Name: "client_id", err: errors.New("db: missing required field \"client_id\"")} + return &ValidationError{Name: "client_id", err: errors.New(`db: missing required field "client_id"`)} } if v, ok := drc.mutation.ClientID(); ok { if err := devicerequest.ClientIDValidator(v); err != nil { - return &ValidationError{Name: "client_id", err: fmt.Errorf("db: validator failed for field \"client_id\": %w", err)} + return &ValidationError{Name: "client_id", err: fmt.Errorf(`db: validator failed for field "client_id": %w`, err)} } } if _, ok := drc.mutation.ClientSecret(); !ok { - return &ValidationError{Name: "client_secret", err: errors.New("db: missing required field \"client_secret\"")} + return &ValidationError{Name: "client_secret", err: errors.New(`db: missing required field "client_secret"`)} } if v, ok := drc.mutation.ClientSecret(); ok { if err := devicerequest.ClientSecretValidator(v); err != nil { - return &ValidationError{Name: "client_secret", err: fmt.Errorf("db: validator failed for field \"client_secret\": %w", err)} + return &ValidationError{Name: "client_secret", err: fmt.Errorf(`db: validator failed for field "client_secret": %w`, err)} } } if _, ok := drc.mutation.Expiry(); !ok { - return &ValidationError{Name: "expiry", err: errors.New("db: missing required field \"expiry\"")} + return &ValidationError{Name: "expiry", err: errors.New(`db: missing required field "expiry"`)} } return nil } @@ -148,8 +167,8 @@ func (drc *DeviceRequestCreate) check() error { func (drc *DeviceRequestCreate) sqlSave(ctx context.Context) (*DeviceRequest, error) { _node, _spec := drc.createSpec() if err := sqlgraph.CreateNode(ctx, drc.driver, _spec); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } @@ -248,19 +267,23 @@ func (drcb *DeviceRequestCreateBulk) Save(ctx context.Context) ([]*DeviceRequest if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, drcb.builders[i+1].mutation) } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} // Invoke the actual operation on the latest mutation in the chain. - if err = sqlgraph.BatchCreate(ctx, drcb.driver, &sqlgraph.BatchCreateSpec{Nodes: specs}); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if err = sqlgraph.BatchCreate(ctx, drcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } } } - mutation.done = true if err != nil { return nil, err } - id := specs[i].ID.Value.(int64) - nodes[i].ID = int(id) + mutation.id = &nodes[i].ID + mutation.done = true + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } return nodes[i], nil }) for i := len(builder.hooks) - 1; i >= 0; i-- { @@ -285,3 +308,16 @@ func (drcb *DeviceRequestCreateBulk) SaveX(ctx context.Context) []*DeviceRequest } return v } + +// Exec executes the query. +func (drcb *DeviceRequestCreateBulk) Exec(ctx context.Context) error { + _, err := drcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (drcb *DeviceRequestCreateBulk) ExecX(ctx context.Context) { + if err := drcb.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/storage/ent/db/devicerequest_delete.go b/storage/ent/db/devicerequest_delete.go index 34c0b890..1a642622 100644 --- a/storage/ent/db/devicerequest_delete.go +++ b/storage/ent/db/devicerequest_delete.go @@ -20,9 +20,9 @@ type DeviceRequestDelete struct { mutation *DeviceRequestMutation } -// Where adds a new predicate to the DeviceRequestDelete builder. +// Where appends a list predicates to the DeviceRequestDelete builder. func (drd *DeviceRequestDelete) Where(ps ...predicate.DeviceRequest) *DeviceRequestDelete { - drd.mutation.predicates = append(drd.mutation.predicates, ps...) + drd.mutation.Where(ps...) return drd } @@ -46,6 +46,9 @@ func (drd *DeviceRequestDelete) Exec(ctx context.Context) (int, error) { return affected, err }) for i := len(drd.hooks) - 1; i >= 0; i-- { + if drd.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = drd.hooks[i](mut) } if _, err := mut.Mutate(ctx, drd.mutation); err != nil { diff --git a/storage/ent/db/devicerequest_query.go b/storage/ent/db/devicerequest_query.go index 08c76871..350b7ae6 100644 --- a/storage/ent/db/devicerequest_query.go +++ b/storage/ent/db/devicerequest_query.go @@ -287,8 +287,8 @@ func (drq *DeviceRequestQuery) GroupBy(field string, fields ...string) *DeviceRe // Select(devicerequest.FieldUserCode). // Scan(ctx, &v) // -func (drq *DeviceRequestQuery) Select(field string, fields ...string) *DeviceRequestSelect { - drq.fields = append([]string{field}, fields...) +func (drq *DeviceRequestQuery) Select(fields ...string) *DeviceRequestSelect { + drq.fields = append(drq.fields, fields...) return &DeviceRequestSelect{DeviceRequestQuery: drq} } @@ -398,10 +398,14 @@ func (drq *DeviceRequestQuery) querySpec() *sqlgraph.QuerySpec { func (drq *DeviceRequestQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(drq.driver.Dialect()) t1 := builder.Table(devicerequest.Table) - selector := builder.Select(t1.Columns(devicerequest.Columns...)...).From(t1) + columns := drq.fields + if len(columns) == 0 { + columns = devicerequest.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) if drq.sql != nil { selector = drq.sql - selector.Select(selector.Columns(devicerequest.Columns...)...) + selector.Select(selector.Columns(columns...)...) } for _, p := range drq.predicates { p(selector) @@ -669,13 +673,24 @@ func (drgb *DeviceRequestGroupBy) sqlScan(ctx context.Context, v interface{}) er } func (drgb *DeviceRequestGroupBy) sqlQuery() *sql.Selector { - selector := drgb.sql - columns := make([]string, 0, len(drgb.fields)+len(drgb.fns)) - columns = append(columns, drgb.fields...) + selector := drgb.sql.Select() + aggregation := make([]string, 0, len(drgb.fns)) for _, fn := range drgb.fns { - columns = append(columns, fn(selector)) + aggregation = append(aggregation, fn(selector)) } - return selector.Select(columns...).GroupBy(drgb.fields...) + // If no columns were selected in a custom aggregation function, the default + // selection is the fields used for "group-by", and the aggregation functions. + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(drgb.fields)+len(drgb.fns)) + for _, f := range drgb.fields { + columns = append(columns, selector.C(f)) + } + for _, c := range aggregation { + columns = append(columns, c) + } + selector.Select(columns...) + } + return selector.GroupBy(selector.Columns(drgb.fields...)...) } // DeviceRequestSelect is the builder for selecting fields of DeviceRequest entities. @@ -891,16 +906,10 @@ func (drs *DeviceRequestSelect) BoolX(ctx context.Context) bool { func (drs *DeviceRequestSelect) sqlScan(ctx context.Context, v interface{}) error { rows := &sql.Rows{} - query, args := drs.sqlQuery().Query() + query, args := drs.sql.Query() if err := drs.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } - -func (drs *DeviceRequestSelect) sqlQuery() sql.Querier { - selector := drs.sql - selector.Select(selector.Columns(drs.fields...)...) - return selector -} diff --git a/storage/ent/db/devicerequest_update.go b/storage/ent/db/devicerequest_update.go index d71ca0ed..21c2f300 100644 --- a/storage/ent/db/devicerequest_update.go +++ b/storage/ent/db/devicerequest_update.go @@ -21,9 +21,9 @@ type DeviceRequestUpdate struct { mutation *DeviceRequestMutation } -// Where adds a new predicate for the DeviceRequestUpdate builder. +// Where appends a list predicates to the DeviceRequestUpdate builder. func (dru *DeviceRequestUpdate) Where(ps ...predicate.DeviceRequest) *DeviceRequestUpdate { - dru.mutation.predicates = append(dru.mutation.predicates, ps...) + dru.mutation.Where(ps...) return dru } @@ -100,6 +100,9 @@ func (dru *DeviceRequestUpdate) Save(ctx context.Context) (int, error) { return affected, err }) for i := len(dru.hooks) - 1; i >= 0; i-- { + if dru.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = dru.hooks[i](mut) } if _, err := mut.Mutate(ctx, dru.mutation); err != nil { @@ -225,8 +228,8 @@ func (dru *DeviceRequestUpdate) sqlSave(ctx context.Context) (n int, err error) if n, err = sqlgraph.UpdateNodes(ctx, dru.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{devicerequest.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return 0, err } @@ -321,6 +324,9 @@ func (druo *DeviceRequestUpdateOne) Save(ctx context.Context) (*DeviceRequest, e return node, err }) for i := len(druo.hooks) - 1; i >= 0; i-- { + if druo.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = druo.hooks[i](mut) } if _, err := mut.Mutate(ctx, druo.mutation); err != nil { @@ -466,8 +472,8 @@ func (druo *DeviceRequestUpdateOne) sqlSave(ctx context.Context) (_node *DeviceR if err = sqlgraph.UpdateNode(ctx, druo.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{devicerequest.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } diff --git a/storage/ent/db/devicetoken_create.go b/storage/ent/db/devicetoken_create.go index 50ed1aad..db73ad4d 100644 --- a/storage/ent/db/devicetoken_create.go +++ b/storage/ent/db/devicetoken_create.go @@ -82,11 +82,17 @@ func (dtc *DeviceTokenCreate) Save(ctx context.Context) (*DeviceToken, error) { return nil, err } dtc.mutation = mutation - node, err = dtc.sqlSave(ctx) + if node, err = dtc.sqlSave(ctx); err != nil { + return nil, err + } + mutation.id = &node.ID mutation.done = true return node, err }) for i := len(dtc.hooks) - 1; i >= 0; i-- { + if dtc.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = dtc.hooks[i](mut) } if _, err := mut.Mutate(ctx, dtc.mutation); err != nil { @@ -105,32 +111,45 @@ func (dtc *DeviceTokenCreate) SaveX(ctx context.Context) *DeviceToken { return v } +// Exec executes the query. +func (dtc *DeviceTokenCreate) Exec(ctx context.Context) error { + _, err := dtc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (dtc *DeviceTokenCreate) ExecX(ctx context.Context) { + if err := dtc.Exec(ctx); err != nil { + panic(err) + } +} + // check runs all checks and user-defined validators on the builder. func (dtc *DeviceTokenCreate) check() error { if _, ok := dtc.mutation.DeviceCode(); !ok { - return &ValidationError{Name: "device_code", err: errors.New("db: missing required field \"device_code\"")} + return &ValidationError{Name: "device_code", err: errors.New(`db: missing required field "device_code"`)} } if v, ok := dtc.mutation.DeviceCode(); ok { if err := devicetoken.DeviceCodeValidator(v); err != nil { - return &ValidationError{Name: "device_code", err: fmt.Errorf("db: validator failed for field \"device_code\": %w", err)} + return &ValidationError{Name: "device_code", err: fmt.Errorf(`db: validator failed for field "device_code": %w`, err)} } } if _, ok := dtc.mutation.Status(); !ok { - return &ValidationError{Name: "status", err: errors.New("db: missing required field \"status\"")} + return &ValidationError{Name: "status", err: errors.New(`db: missing required field "status"`)} } if v, ok := dtc.mutation.Status(); ok { if err := devicetoken.StatusValidator(v); err != nil { - return &ValidationError{Name: "status", err: fmt.Errorf("db: validator failed for field \"status\": %w", err)} + return &ValidationError{Name: "status", err: fmt.Errorf(`db: validator failed for field "status": %w`, err)} } } if _, ok := dtc.mutation.Expiry(); !ok { - return &ValidationError{Name: "expiry", err: errors.New("db: missing required field \"expiry\"")} + return &ValidationError{Name: "expiry", err: errors.New(`db: missing required field "expiry"`)} } if _, ok := dtc.mutation.LastRequest(); !ok { - return &ValidationError{Name: "last_request", err: errors.New("db: missing required field \"last_request\"")} + return &ValidationError{Name: "last_request", err: errors.New(`db: missing required field "last_request"`)} } if _, ok := dtc.mutation.PollInterval(); !ok { - return &ValidationError{Name: "poll_interval", err: errors.New("db: missing required field \"poll_interval\"")} + return &ValidationError{Name: "poll_interval", err: errors.New(`db: missing required field "poll_interval"`)} } return nil } @@ -138,8 +157,8 @@ func (dtc *DeviceTokenCreate) check() error { func (dtc *DeviceTokenCreate) sqlSave(ctx context.Context) (*DeviceToken, error) { _node, _spec := dtc.createSpec() if err := sqlgraph.CreateNode(ctx, dtc.driver, _spec); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } @@ -238,19 +257,23 @@ func (dtcb *DeviceTokenCreateBulk) Save(ctx context.Context) ([]*DeviceToken, er if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, dtcb.builders[i+1].mutation) } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} // Invoke the actual operation on the latest mutation in the chain. - if err = sqlgraph.BatchCreate(ctx, dtcb.driver, &sqlgraph.BatchCreateSpec{Nodes: specs}); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if err = sqlgraph.BatchCreate(ctx, dtcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } } } - mutation.done = true if err != nil { return nil, err } - id := specs[i].ID.Value.(int64) - nodes[i].ID = int(id) + mutation.id = &nodes[i].ID + mutation.done = true + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } return nodes[i], nil }) for i := len(builder.hooks) - 1; i >= 0; i-- { @@ -275,3 +298,16 @@ func (dtcb *DeviceTokenCreateBulk) SaveX(ctx context.Context) []*DeviceToken { } return v } + +// Exec executes the query. +func (dtcb *DeviceTokenCreateBulk) Exec(ctx context.Context) error { + _, err := dtcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (dtcb *DeviceTokenCreateBulk) ExecX(ctx context.Context) { + if err := dtcb.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/storage/ent/db/devicetoken_delete.go b/storage/ent/db/devicetoken_delete.go index 0ea9069d..f23cc50e 100644 --- a/storage/ent/db/devicetoken_delete.go +++ b/storage/ent/db/devicetoken_delete.go @@ -20,9 +20,9 @@ type DeviceTokenDelete struct { mutation *DeviceTokenMutation } -// Where adds a new predicate to the DeviceTokenDelete builder. +// Where appends a list predicates to the DeviceTokenDelete builder. func (dtd *DeviceTokenDelete) Where(ps ...predicate.DeviceToken) *DeviceTokenDelete { - dtd.mutation.predicates = append(dtd.mutation.predicates, ps...) + dtd.mutation.Where(ps...) return dtd } @@ -46,6 +46,9 @@ func (dtd *DeviceTokenDelete) Exec(ctx context.Context) (int, error) { return affected, err }) for i := len(dtd.hooks) - 1; i >= 0; i-- { + if dtd.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = dtd.hooks[i](mut) } if _, err := mut.Mutate(ctx, dtd.mutation); err != nil { diff --git a/storage/ent/db/devicetoken_query.go b/storage/ent/db/devicetoken_query.go index e085440d..2fb2712f 100644 --- a/storage/ent/db/devicetoken_query.go +++ b/storage/ent/db/devicetoken_query.go @@ -287,8 +287,8 @@ func (dtq *DeviceTokenQuery) GroupBy(field string, fields ...string) *DeviceToke // Select(devicetoken.FieldDeviceCode). // Scan(ctx, &v) // -func (dtq *DeviceTokenQuery) Select(field string, fields ...string) *DeviceTokenSelect { - dtq.fields = append([]string{field}, fields...) +func (dtq *DeviceTokenQuery) Select(fields ...string) *DeviceTokenSelect { + dtq.fields = append(dtq.fields, fields...) return &DeviceTokenSelect{DeviceTokenQuery: dtq} } @@ -398,10 +398,14 @@ func (dtq *DeviceTokenQuery) querySpec() *sqlgraph.QuerySpec { func (dtq *DeviceTokenQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(dtq.driver.Dialect()) t1 := builder.Table(devicetoken.Table) - selector := builder.Select(t1.Columns(devicetoken.Columns...)...).From(t1) + columns := dtq.fields + if len(columns) == 0 { + columns = devicetoken.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) if dtq.sql != nil { selector = dtq.sql - selector.Select(selector.Columns(devicetoken.Columns...)...) + selector.Select(selector.Columns(columns...)...) } for _, p := range dtq.predicates { p(selector) @@ -669,13 +673,24 @@ func (dtgb *DeviceTokenGroupBy) sqlScan(ctx context.Context, v interface{}) erro } func (dtgb *DeviceTokenGroupBy) sqlQuery() *sql.Selector { - selector := dtgb.sql - columns := make([]string, 0, len(dtgb.fields)+len(dtgb.fns)) - columns = append(columns, dtgb.fields...) + selector := dtgb.sql.Select() + aggregation := make([]string, 0, len(dtgb.fns)) for _, fn := range dtgb.fns { - columns = append(columns, fn(selector)) + aggregation = append(aggregation, fn(selector)) } - return selector.Select(columns...).GroupBy(dtgb.fields...) + // If no columns were selected in a custom aggregation function, the default + // selection is the fields used for "group-by", and the aggregation functions. + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(dtgb.fields)+len(dtgb.fns)) + for _, f := range dtgb.fields { + columns = append(columns, selector.C(f)) + } + for _, c := range aggregation { + columns = append(columns, c) + } + selector.Select(columns...) + } + return selector.GroupBy(selector.Columns(dtgb.fields...)...) } // DeviceTokenSelect is the builder for selecting fields of DeviceToken entities. @@ -891,16 +906,10 @@ func (dts *DeviceTokenSelect) BoolX(ctx context.Context) bool { func (dts *DeviceTokenSelect) sqlScan(ctx context.Context, v interface{}) error { rows := &sql.Rows{} - query, args := dts.sqlQuery().Query() + query, args := dts.sql.Query() if err := dts.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } - -func (dts *DeviceTokenSelect) sqlQuery() sql.Querier { - selector := dts.sql - selector.Select(selector.Columns(dts.fields...)...) - return selector -} diff --git a/storage/ent/db/devicetoken_update.go b/storage/ent/db/devicetoken_update.go index 51a4efe0..a61a8b0b 100644 --- a/storage/ent/db/devicetoken_update.go +++ b/storage/ent/db/devicetoken_update.go @@ -21,9 +21,9 @@ type DeviceTokenUpdate struct { mutation *DeviceTokenMutation } -// Where adds a new predicate for the DeviceTokenUpdate builder. +// Where appends a list predicates to the DeviceTokenUpdate builder. func (dtu *DeviceTokenUpdate) Where(ps ...predicate.DeviceToken) *DeviceTokenUpdate { - dtu.mutation.predicates = append(dtu.mutation.predicates, ps...) + dtu.mutation.Where(ps...) return dtu } @@ -107,6 +107,9 @@ func (dtu *DeviceTokenUpdate) Save(ctx context.Context) (int, error) { return affected, err }) for i := len(dtu.hooks) - 1; i >= 0; i-- { + if dtu.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = dtu.hooks[i](mut) } if _, err := mut.Mutate(ctx, dtu.mutation); err != nil { @@ -229,8 +232,8 @@ func (dtu *DeviceTokenUpdate) sqlSave(ctx context.Context) (n int, err error) { if n, err = sqlgraph.UpdateNodes(ctx, dtu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{devicetoken.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return 0, err } @@ -332,6 +335,9 @@ func (dtuo *DeviceTokenUpdateOne) Save(ctx context.Context) (*DeviceToken, error return node, err }) for i := len(dtuo.hooks) - 1; i >= 0; i-- { + if dtuo.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = dtuo.hooks[i](mut) } if _, err := mut.Mutate(ctx, dtuo.mutation); err != nil { @@ -474,8 +480,8 @@ func (dtuo *DeviceTokenUpdateOne) sqlSave(ctx context.Context) (_node *DeviceTok if err = sqlgraph.UpdateNode(ctx, dtuo.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{devicetoken.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } diff --git a/storage/ent/db/ent.go b/storage/ent/db/ent.go index d84e721d..42d7ad52 100644 --- a/storage/ent/db/ent.go +++ b/storage/ent/db/ent.go @@ -7,9 +7,7 @@ import ( "fmt" "entgo.io/ent" - "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" - "entgo.io/ent/dialect/sql/sqlgraph" "github.com/dexidp/dex/storage/ent/db/authcode" "github.com/dexidp/dex/storage/ent/db/authrequest" "github.com/dexidp/dex/storage/ent/db/connector" @@ -177,7 +175,7 @@ func (e *ValidationError) Unwrap() error { return e.err } -// IsValidationError returns a boolean indicating whether the error is a validaton error. +// IsValidationError returns a boolean indicating whether the error is a validation error. func IsValidationError(err error) bool { if err == nil { return false @@ -277,21 +275,3 @@ func IsConstraintError(err error) bool { var e *ConstraintError return errors.As(err, &e) } - -func isSQLConstraintError(err error) (*ConstraintError, bool) { - if sqlgraph.IsConstraintError(err) { - return &ConstraintError{err.Error(), err}, true - } - return nil, false -} - -// rollback calls tx.Rollback and wraps the given error with the rollback error if present. -func rollback(tx dialect.Tx, err error) error { - if rerr := tx.Rollback(); rerr != nil { - err = fmt.Errorf("%w: %v", err, rerr) - } - if err, ok := isSQLConstraintError(err); ok { - return err - } - return err -} diff --git a/storage/ent/db/keys.go b/storage/ent/db/keys.go index d0312c92..acdd8257 100644 --- a/storage/ent/db/keys.go +++ b/storage/ent/db/keys.go @@ -62,7 +62,6 @@ func (k *Keys) assignValues(columns []string, values []interface{}) error { k.ID = value.String } case keys.FieldVerificationKeys: - if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field verification_keys", values[i]) } else if value != nil && len(*value) > 0 { @@ -71,7 +70,6 @@ func (k *Keys) assignValues(columns []string, values []interface{}) error { } } case keys.FieldSigningKey: - if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field signing_key", values[i]) } else if value != nil && len(*value) > 0 { @@ -80,7 +78,6 @@ func (k *Keys) assignValues(columns []string, values []interface{}) error { } } case keys.FieldSigningKeyPub: - if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field signing_key_pub", values[i]) } else if value != nil && len(*value) > 0 { diff --git a/storage/ent/db/keys_create.go b/storage/ent/db/keys_create.go index 4dfae78b..18e58f8b 100644 --- a/storage/ent/db/keys_create.go +++ b/storage/ent/db/keys_create.go @@ -78,11 +78,17 @@ func (kc *KeysCreate) Save(ctx context.Context) (*Keys, error) { return nil, err } kc.mutation = mutation - node, err = kc.sqlSave(ctx) + if node, err = kc.sqlSave(ctx); err != nil { + return nil, err + } + mutation.id = &node.ID mutation.done = true return node, err }) for i := len(kc.hooks) - 1; i >= 0; i-- { + if kc.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = kc.hooks[i](mut) } if _, err := mut.Mutate(ctx, kc.mutation); err != nil { @@ -101,23 +107,36 @@ func (kc *KeysCreate) SaveX(ctx context.Context) *Keys { return v } +// Exec executes the query. +func (kc *KeysCreate) Exec(ctx context.Context) error { + _, err := kc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (kc *KeysCreate) ExecX(ctx context.Context) { + if err := kc.Exec(ctx); err != nil { + panic(err) + } +} + // check runs all checks and user-defined validators on the builder. func (kc *KeysCreate) check() error { if _, ok := kc.mutation.VerificationKeys(); !ok { - return &ValidationError{Name: "verification_keys", err: errors.New("db: missing required field \"verification_keys\"")} + return &ValidationError{Name: "verification_keys", err: errors.New(`db: missing required field "verification_keys"`)} } if _, ok := kc.mutation.SigningKey(); !ok { - return &ValidationError{Name: "signing_key", err: errors.New("db: missing required field \"signing_key\"")} + return &ValidationError{Name: "signing_key", err: errors.New(`db: missing required field "signing_key"`)} } if _, ok := kc.mutation.SigningKeyPub(); !ok { - return &ValidationError{Name: "signing_key_pub", err: errors.New("db: missing required field \"signing_key_pub\"")} + return &ValidationError{Name: "signing_key_pub", err: errors.New(`db: missing required field "signing_key_pub"`)} } if _, ok := kc.mutation.NextRotation(); !ok { - return &ValidationError{Name: "next_rotation", err: errors.New("db: missing required field \"next_rotation\"")} + return &ValidationError{Name: "next_rotation", err: errors.New(`db: missing required field "next_rotation"`)} } if v, ok := kc.mutation.ID(); ok { if err := keys.IDValidator(v); err != nil { - return &ValidationError{Name: "id", err: fmt.Errorf("db: validator failed for field \"id\": %w", err)} + return &ValidationError{Name: "id", err: fmt.Errorf(`db: validator failed for field "id": %w`, err)} } } return nil @@ -126,8 +145,8 @@ func (kc *KeysCreate) check() error { func (kc *KeysCreate) sqlSave(ctx context.Context) (*Keys, error) { _node, _spec := kc.createSpec() if err := sqlgraph.CreateNode(ctx, kc.driver, _spec); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } @@ -212,17 +231,19 @@ func (kcb *KeysCreateBulk) Save(ctx context.Context) ([]*Keys, error) { if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, kcb.builders[i+1].mutation) } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} // Invoke the actual operation on the latest mutation in the chain. - if err = sqlgraph.BatchCreate(ctx, kcb.driver, &sqlgraph.BatchCreateSpec{Nodes: specs}); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if err = sqlgraph.BatchCreate(ctx, kcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } } } - mutation.done = true if err != nil { return nil, err } + mutation.id = &nodes[i].ID + mutation.done = true return nodes[i], nil }) for i := len(builder.hooks) - 1; i >= 0; i-- { @@ -247,3 +268,16 @@ func (kcb *KeysCreateBulk) SaveX(ctx context.Context) []*Keys { } return v } + +// Exec executes the query. +func (kcb *KeysCreateBulk) Exec(ctx context.Context) error { + _, err := kcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (kcb *KeysCreateBulk) ExecX(ctx context.Context) { + if err := kcb.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/storage/ent/db/keys_delete.go b/storage/ent/db/keys_delete.go index 620f8f10..6950c257 100644 --- a/storage/ent/db/keys_delete.go +++ b/storage/ent/db/keys_delete.go @@ -20,9 +20,9 @@ type KeysDelete struct { mutation *KeysMutation } -// Where adds a new predicate to the KeysDelete builder. +// Where appends a list predicates to the KeysDelete builder. func (kd *KeysDelete) Where(ps ...predicate.Keys) *KeysDelete { - kd.mutation.predicates = append(kd.mutation.predicates, ps...) + kd.mutation.Where(ps...) return kd } @@ -46,6 +46,9 @@ func (kd *KeysDelete) Exec(ctx context.Context) (int, error) { return affected, err }) for i := len(kd.hooks) - 1; i >= 0; i-- { + if kd.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = kd.hooks[i](mut) } if _, err := mut.Mutate(ctx, kd.mutation); err != nil { diff --git a/storage/ent/db/keys_query.go b/storage/ent/db/keys_query.go index 6d6b00f9..ee102e27 100644 --- a/storage/ent/db/keys_query.go +++ b/storage/ent/db/keys_query.go @@ -287,8 +287,8 @@ func (kq *KeysQuery) GroupBy(field string, fields ...string) *KeysGroupBy { // Select(keys.FieldVerificationKeys). // Scan(ctx, &v) // -func (kq *KeysQuery) Select(field string, fields ...string) *KeysSelect { - kq.fields = append([]string{field}, fields...) +func (kq *KeysQuery) Select(fields ...string) *KeysSelect { + kq.fields = append(kq.fields, fields...) return &KeysSelect{KeysQuery: kq} } @@ -398,10 +398,14 @@ func (kq *KeysQuery) querySpec() *sqlgraph.QuerySpec { func (kq *KeysQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(kq.driver.Dialect()) t1 := builder.Table(keys.Table) - selector := builder.Select(t1.Columns(keys.Columns...)...).From(t1) + columns := kq.fields + if len(columns) == 0 { + columns = keys.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) if kq.sql != nil { selector = kq.sql - selector.Select(selector.Columns(keys.Columns...)...) + selector.Select(selector.Columns(columns...)...) } for _, p := range kq.predicates { p(selector) @@ -669,13 +673,24 @@ func (kgb *KeysGroupBy) sqlScan(ctx context.Context, v interface{}) error { } func (kgb *KeysGroupBy) sqlQuery() *sql.Selector { - selector := kgb.sql - columns := make([]string, 0, len(kgb.fields)+len(kgb.fns)) - columns = append(columns, kgb.fields...) + selector := kgb.sql.Select() + aggregation := make([]string, 0, len(kgb.fns)) for _, fn := range kgb.fns { - columns = append(columns, fn(selector)) + aggregation = append(aggregation, fn(selector)) } - return selector.Select(columns...).GroupBy(kgb.fields...) + // If no columns were selected in a custom aggregation function, the default + // selection is the fields used for "group-by", and the aggregation functions. + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(kgb.fields)+len(kgb.fns)) + for _, f := range kgb.fields { + columns = append(columns, selector.C(f)) + } + for _, c := range aggregation { + columns = append(columns, c) + } + selector.Select(columns...) + } + return selector.GroupBy(selector.Columns(kgb.fields...)...) } // KeysSelect is the builder for selecting fields of Keys entities. @@ -891,16 +906,10 @@ func (ks *KeysSelect) BoolX(ctx context.Context) bool { func (ks *KeysSelect) sqlScan(ctx context.Context, v interface{}) error { rows := &sql.Rows{} - query, args := ks.sqlQuery().Query() + query, args := ks.sql.Query() if err := ks.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } - -func (ks *KeysSelect) sqlQuery() sql.Querier { - selector := ks.sql - selector.Select(selector.Columns(ks.fields...)...) - return selector -} diff --git a/storage/ent/db/keys_update.go b/storage/ent/db/keys_update.go index 8bc0ed3e..0e40565f 100644 --- a/storage/ent/db/keys_update.go +++ b/storage/ent/db/keys_update.go @@ -23,9 +23,9 @@ type KeysUpdate struct { mutation *KeysMutation } -// Where adds a new predicate for the KeysUpdate builder. +// Where appends a list predicates to the KeysUpdate builder. func (ku *KeysUpdate) Where(ps ...predicate.Keys) *KeysUpdate { - ku.mutation.predicates = append(ku.mutation.predicates, ps...) + ku.mutation.Where(ps...) return ku } @@ -78,6 +78,9 @@ func (ku *KeysUpdate) Save(ctx context.Context) (int, error) { return affected, err }) for i := len(ku.hooks) - 1; i >= 0; i-- { + if ku.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = ku.hooks[i](mut) } if _, err := mut.Mutate(ctx, ku.mutation); err != nil { @@ -158,8 +161,8 @@ func (ku *KeysUpdate) sqlSave(ctx context.Context) (n int, err error) { if n, err = sqlgraph.UpdateNodes(ctx, ku.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{keys.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return 0, err } @@ -230,6 +233,9 @@ func (kuo *KeysUpdateOne) Save(ctx context.Context) (*Keys, error) { return node, err }) for i := len(kuo.hooks) - 1; i >= 0; i-- { + if kuo.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = kuo.hooks[i](mut) } if _, err := mut.Mutate(ctx, kuo.mutation); err != nil { @@ -330,8 +336,8 @@ func (kuo *KeysUpdateOne) sqlSave(ctx context.Context) (_node *Keys, err error) if err = sqlgraph.UpdateNode(ctx, kuo.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{keys.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } diff --git a/storage/ent/db/migrate/schema.go b/storage/ent/db/migrate/schema.go index d5b1f535..d8b8b62c 100644 --- a/storage/ent/db/migrate/schema.go +++ b/storage/ent/db/migrate/schema.go @@ -10,196 +10,186 @@ import ( var ( // AuthCodesColumns holds the columns for the "auth_codes" table. AuthCodesColumns = []*schema.Column{ - {Name: "id", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "client_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "id", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "client_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "scopes", Type: field.TypeJSON, Nullable: true}, - {Name: "nonce", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "redirect_uri", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "claims_user_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "claims_username", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "claims_email", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "nonce", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "redirect_uri", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "claims_user_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "claims_username", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "claims_email", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "claims_email_verified", Type: field.TypeBool}, {Name: "claims_groups", Type: field.TypeJSON, Nullable: true}, - {Name: "claims_preferred_username", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "connector_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "claims_preferred_username", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "connector_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "connector_data", Type: field.TypeBytes, Nullable: true}, - {Name: "expiry", Type: field.TypeTime}, - {Name: "code_challenge", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "code_challenge_method", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "expiry", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, + {Name: "code_challenge", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "code_challenge_method", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, } // AuthCodesTable holds the schema information for the "auth_codes" table. AuthCodesTable = &schema.Table{ - Name: "auth_codes", - Columns: AuthCodesColumns, - PrimaryKey: []*schema.Column{AuthCodesColumns[0]}, - ForeignKeys: []*schema.ForeignKey{}, + Name: "auth_codes", + Columns: AuthCodesColumns, + PrimaryKey: []*schema.Column{AuthCodesColumns[0]}, } // AuthRequestsColumns holds the columns for the "auth_requests" table. AuthRequestsColumns = []*schema.Column{ - {Name: "id", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "client_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "id", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "client_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "scopes", Type: field.TypeJSON, Nullable: true}, {Name: "response_types", Type: field.TypeJSON, Nullable: true}, - {Name: "redirect_uri", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "nonce", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "state", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "redirect_uri", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "nonce", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "state", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "force_approval_prompt", Type: field.TypeBool}, {Name: "logged_in", Type: field.TypeBool}, - {Name: "claims_user_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "claims_username", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "claims_email", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "claims_user_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "claims_username", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "claims_email", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "claims_email_verified", Type: field.TypeBool}, {Name: "claims_groups", Type: field.TypeJSON, Nullable: true}, - {Name: "claims_preferred_username", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "connector_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "claims_preferred_username", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "connector_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "connector_data", Type: field.TypeBytes, Nullable: true}, - {Name: "expiry", Type: field.TypeTime}, - {Name: "code_challenge", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "code_challenge_method", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "expiry", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, + {Name: "code_challenge", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "code_challenge_method", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, } // AuthRequestsTable holds the schema information for the "auth_requests" table. AuthRequestsTable = &schema.Table{ - Name: "auth_requests", - Columns: AuthRequestsColumns, - PrimaryKey: []*schema.Column{AuthRequestsColumns[0]}, - ForeignKeys: []*schema.ForeignKey{}, + Name: "auth_requests", + Columns: AuthRequestsColumns, + PrimaryKey: []*schema.Column{AuthRequestsColumns[0]}, } // ConnectorsColumns holds the columns for the "connectors" table. ConnectorsColumns = []*schema.Column{ - {Name: "id", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "type", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "name", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "resource_version", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "id", Type: field.TypeString, Unique: true, Size: 100, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "type", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "name", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "resource_version", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "config", Type: field.TypeBytes}, } // ConnectorsTable holds the schema information for the "connectors" table. ConnectorsTable = &schema.Table{ - Name: "connectors", - Columns: ConnectorsColumns, - PrimaryKey: []*schema.Column{ConnectorsColumns[0]}, - ForeignKeys: []*schema.ForeignKey{}, + Name: "connectors", + Columns: ConnectorsColumns, + PrimaryKey: []*schema.Column{ConnectorsColumns[0]}, } // DeviceRequestsColumns holds the columns for the "device_requests" table. DeviceRequestsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "user_code", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "device_code", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "client_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "client_secret", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "user_code", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "device_code", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "client_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "client_secret", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "scopes", Type: field.TypeJSON, Nullable: true}, - {Name: "expiry", Type: field.TypeTime}, + {Name: "expiry", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, } // DeviceRequestsTable holds the schema information for the "device_requests" table. DeviceRequestsTable = &schema.Table{ - Name: "device_requests", - Columns: DeviceRequestsColumns, - PrimaryKey: []*schema.Column{DeviceRequestsColumns[0]}, - ForeignKeys: []*schema.ForeignKey{}, + Name: "device_requests", + Columns: DeviceRequestsColumns, + PrimaryKey: []*schema.Column{DeviceRequestsColumns[0]}, } // DeviceTokensColumns holds the columns for the "device_tokens" table. DeviceTokensColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "device_code", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "status", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "device_code", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "status", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "token", Type: field.TypeBytes, Nullable: true}, - {Name: "expiry", Type: field.TypeTime}, - {Name: "last_request", Type: field.TypeTime}, + {Name: "expiry", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, + {Name: "last_request", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, {Name: "poll_interval", Type: field.TypeInt}, } // DeviceTokensTable holds the schema information for the "device_tokens" table. DeviceTokensTable = &schema.Table{ - Name: "device_tokens", - Columns: DeviceTokensColumns, - PrimaryKey: []*schema.Column{DeviceTokensColumns[0]}, - ForeignKeys: []*schema.ForeignKey{}, + Name: "device_tokens", + Columns: DeviceTokensColumns, + PrimaryKey: []*schema.Column{DeviceTokensColumns[0]}, } // KeysColumns holds the columns for the "keys" table. KeysColumns = []*schema.Column{ - {Name: "id", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "id", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "verification_keys", Type: field.TypeJSON}, {Name: "signing_key", Type: field.TypeJSON}, {Name: "signing_key_pub", Type: field.TypeJSON}, - {Name: "next_rotation", Type: field.TypeTime}, + {Name: "next_rotation", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, } // KeysTable holds the schema information for the "keys" table. KeysTable = &schema.Table{ - Name: "keys", - Columns: KeysColumns, - PrimaryKey: []*schema.Column{KeysColumns[0]}, - ForeignKeys: []*schema.ForeignKey{}, + Name: "keys", + Columns: KeysColumns, + PrimaryKey: []*schema.Column{KeysColumns[0]}, } // Oauth2clientsColumns holds the columns for the "oauth2clients" table. Oauth2clientsColumns = []*schema.Column{ - {Name: "id", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "secret", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "id", Type: field.TypeString, Unique: true, Size: 100, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "secret", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "redirect_uris", Type: field.TypeJSON, Nullable: true}, {Name: "trusted_peers", Type: field.TypeJSON, Nullable: true}, {Name: "public", Type: field.TypeBool}, - {Name: "name", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "logo_url", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "name", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "logo_url", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, } // Oauth2clientsTable holds the schema information for the "oauth2clients" table. Oauth2clientsTable = &schema.Table{ - Name: "oauth2clients", - Columns: Oauth2clientsColumns, - PrimaryKey: []*schema.Column{Oauth2clientsColumns[0]}, - ForeignKeys: []*schema.ForeignKey{}, + Name: "oauth2clients", + Columns: Oauth2clientsColumns, + PrimaryKey: []*schema.Column{Oauth2clientsColumns[0]}, } // OfflineSessionsColumns holds the columns for the "offline_sessions" table. OfflineSessionsColumns = []*schema.Column{ - {Name: "id", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "user_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "conn_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "id", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "user_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "conn_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "refresh", Type: field.TypeBytes}, {Name: "connector_data", Type: field.TypeBytes, Nullable: true}, } // OfflineSessionsTable holds the schema information for the "offline_sessions" table. OfflineSessionsTable = &schema.Table{ - Name: "offline_sessions", - Columns: OfflineSessionsColumns, - PrimaryKey: []*schema.Column{OfflineSessionsColumns[0]}, - ForeignKeys: []*schema.ForeignKey{}, + Name: "offline_sessions", + Columns: OfflineSessionsColumns, + PrimaryKey: []*schema.Column{OfflineSessionsColumns[0]}, } // PasswordsColumns holds the columns for the "passwords" table. PasswordsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "email", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "email", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "hash", Type: field.TypeBytes}, - {Name: "username", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "user_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "username", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "user_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, } // PasswordsTable holds the schema information for the "passwords" table. PasswordsTable = &schema.Table{ - Name: "passwords", - Columns: PasswordsColumns, - PrimaryKey: []*schema.Column{PasswordsColumns[0]}, - ForeignKeys: []*schema.ForeignKey{}, + Name: "passwords", + Columns: PasswordsColumns, + PrimaryKey: []*schema.Column{PasswordsColumns[0]}, } // RefreshTokensColumns holds the columns for the "refresh_tokens" table. RefreshTokensColumns = []*schema.Column{ - {Name: "id", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "client_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "id", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "client_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "scopes", Type: field.TypeJSON, Nullable: true}, - {Name: "nonce", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "claims_user_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "claims_username", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "claims_email", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "nonce", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "claims_user_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "claims_username", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "claims_email", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "claims_email_verified", Type: field.TypeBool}, {Name: "claims_groups", Type: field.TypeJSON, Nullable: true}, - {Name: "claims_preferred_username", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "connector_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "claims_preferred_username", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "connector_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "connector_data", Type: field.TypeBytes, Nullable: true}, - {Name: "token", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "obsolete_token", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"sqlite3": "text"}}, - {Name: "created_at", Type: field.TypeTime}, - {Name: "last_used", Type: field.TypeTime}, + {Name: "token", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "obsolete_token", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, + {Name: "last_used", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, } // RefreshTokensTable holds the schema information for the "refresh_tokens" table. RefreshTokensTable = &schema.Table{ - Name: "refresh_tokens", - Columns: RefreshTokensColumns, - PrimaryKey: []*schema.Column{RefreshTokensColumns[0]}, - ForeignKeys: []*schema.ForeignKey{}, + Name: "refresh_tokens", + Columns: RefreshTokensColumns, + PrimaryKey: []*schema.Column{RefreshTokensColumns[0]}, } // Tables holds all the tables in the schema. Tables = []*schema.Table{ diff --git a/storage/ent/db/mutation.go b/storage/ent/db/mutation.go index 7ccab3f2..d6d1dbab 100644 --- a/storage/ent/db/mutation.go +++ b/storage/ent/db/mutation.go @@ -149,8 +149,8 @@ func (m *AuthCodeMutation) SetID(id string) { m.id = &id } -// ID returns the ID value in the mutation. Note that the ID -// is only available if it was provided to the builder. +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. func (m *AuthCodeMutation) ID() (id string, exists bool) { if m.id == nil { return @@ -737,6 +737,11 @@ func (m *AuthCodeMutation) ResetCodeChallengeMethod() { m.code_challenge_method = nil } +// Where appends a list predicates to the AuthCodeMutation builder. +func (m *AuthCodeMutation) Where(ps ...predicate.AuthCode) { + m.predicates = append(m.predicates, ps...) +} + // Op returns the operation name. func (m *AuthCodeMutation) Op() Op { return m.op @@ -1262,8 +1267,8 @@ func (m *AuthRequestMutation) SetID(id string) { m.id = &id } -// ID returns the ID value in the mutation. Note that the ID -// is only available if it was provided to the builder. +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. func (m *AuthRequestMutation) ID() (id string, exists bool) { if m.id == nil { return @@ -2007,6 +2012,11 @@ func (m *AuthRequestMutation) ResetCodeChallengeMethod() { m.code_challenge_method = nil } +// Where appends a list predicates to the AuthRequestMutation builder. +func (m *AuthRequestMutation) Where(ps ...predicate.AuthRequest) { + m.predicates = append(m.predicates, ps...) +} + // Op returns the operation name. func (m *AuthRequestMutation) Op() Op { return m.op @@ -2591,8 +2601,8 @@ func (m *ConnectorMutation) SetID(id string) { m.id = &id } -// ID returns the ID value in the mutation. Note that the ID -// is only available if it was provided to the builder. +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. func (m *ConnectorMutation) ID() (id string, exists bool) { if m.id == nil { return @@ -2744,6 +2754,11 @@ func (m *ConnectorMutation) ResetConfig() { m._config = nil } +// Where appends a list predicates to the ConnectorMutation builder. +func (m *ConnectorMutation) Where(ps ...predicate.Connector) { + m.predicates = append(m.predicates, ps...) +} + // Op returns the operation name. func (m *ConnectorMutation) Op() Op { return m.op @@ -3042,8 +3057,8 @@ func (m DeviceRequestMutation) Tx() (*Tx, error) { return tx, nil } -// ID returns the ID value in the mutation. Note that the ID -// is only available if it was provided to the builder. +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. func (m *DeviceRequestMutation) ID() (id int, exists bool) { if m.id == nil { return @@ -3280,6 +3295,11 @@ func (m *DeviceRequestMutation) ResetExpiry() { m.expiry = nil } +// Where appends a list predicates to the DeviceRequestMutation builder. +func (m *DeviceRequestMutation) Where(ps ...predicate.DeviceRequest) { + m.predicates = append(m.predicates, ps...) +} + // Op returns the operation name. func (m *DeviceRequestMutation) Op() Op { return m.op @@ -3622,8 +3642,8 @@ func (m DeviceTokenMutation) Tx() (*Tx, error) { return tx, nil } -// ID returns the ID value in the mutation. Note that the ID -// is only available if it was provided to the builder. +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. func (m *DeviceTokenMutation) ID() (id int, exists bool) { if m.id == nil { return @@ -3880,6 +3900,11 @@ func (m *DeviceTokenMutation) ResetPollInterval() { m.addpoll_interval = nil } +// Where appends a list predicates to the DeviceTokenMutation builder. +func (m *DeviceTokenMutation) Where(ps ...predicate.DeviceToken) { + m.predicates = append(m.predicates, ps...) +} + // Op returns the operation name. func (m *DeviceTokenMutation) Op() Op { return m.op @@ -4240,8 +4265,8 @@ func (m *KeysMutation) SetID(id string) { m.id = &id } -// ID returns the ID value in the mutation. Note that the ID -// is only available if it was provided to the builder. +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. func (m *KeysMutation) ID() (id string, exists bool) { if m.id == nil { return @@ -4393,6 +4418,11 @@ func (m *KeysMutation) ResetNextRotation() { m.next_rotation = nil } +// Where appends a list predicates to the KeysMutation builder. +func (m *KeysMutation) Where(ps ...predicate.Keys) { + m.predicates = append(m.predicates, ps...) +} + // Op returns the operation name. func (m *KeysMutation) Op() Op { return m.op @@ -4697,8 +4727,8 @@ func (m *OAuth2ClientMutation) SetID(id string) { m.id = &id } -// ID returns the ID value in the mutation. Note that the ID -// is only available if it was provided to the builder. +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. func (m *OAuth2ClientMutation) ID() (id string, exists bool) { if m.id == nil { return @@ -4948,6 +4978,11 @@ func (m *OAuth2ClientMutation) ResetLogoURL() { m.logo_url = nil } +// Where appends a list predicates to the OAuth2ClientMutation builder. +func (m *OAuth2ClientMutation) Where(ps ...predicate.OAuth2Client) { + m.predicates = append(m.predicates, ps...) +} + // Op returns the operation name. func (m *OAuth2ClientMutation) Op() Op { return m.op @@ -5299,8 +5334,8 @@ func (m *OfflineSessionMutation) SetID(id string) { m.id = &id } -// ID returns the ID value in the mutation. Note that the ID -// is only available if it was provided to the builder. +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. func (m *OfflineSessionMutation) ID() (id string, exists bool) { if m.id == nil { return @@ -5465,6 +5500,11 @@ func (m *OfflineSessionMutation) ResetConnectorData() { delete(m.clearedFields, offlinesession.FieldConnectorData) } +// Where appends a list predicates to the OfflineSessionMutation builder. +func (m *OfflineSessionMutation) Where(ps ...predicate.OfflineSession) { + m.predicates = append(m.predicates, ps...) +} + // Op returns the operation name. func (m *OfflineSessionMutation) Op() Op { return m.op @@ -5770,8 +5810,8 @@ func (m PasswordMutation) Tx() (*Tx, error) { return tx, nil } -// ID returns the ID value in the mutation. Note that the ID -// is only available if it was provided to the builder. +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. func (m *PasswordMutation) ID() (id int, exists bool) { if m.id == nil { return @@ -5923,6 +5963,11 @@ func (m *PasswordMutation) ResetUserID() { m.user_id = nil } +// Where appends a list predicates to the PasswordMutation builder. +func (m *PasswordMutation) Where(ps ...predicate.Password) { + m.predicates = append(m.predicates, ps...) +} + // Op returns the operation name. func (m *PasswordMutation) Op() Op { return m.op @@ -6236,8 +6281,8 @@ func (m *RefreshTokenMutation) SetID(id string) { m.id = &id } -// ID returns the ID value in the mutation. Note that the ID -// is only available if it was provided to the builder. +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. func (m *RefreshTokenMutation) ID() (id string, exists bool) { if m.id == nil { return @@ -6824,6 +6869,11 @@ func (m *RefreshTokenMutation) ResetLastUsed() { m.last_used = nil } +// Where appends a list predicates to the RefreshTokenMutation builder. +func (m *RefreshTokenMutation) Where(ps ...predicate.RefreshToken) { + m.predicates = append(m.predicates, ps...) +} + // Op returns the operation name. func (m *RefreshTokenMutation) Op() Op { return m.op diff --git a/storage/ent/db/oauth2client.go b/storage/ent/db/oauth2client.go index 687a6e69..57d64a49 100644 --- a/storage/ent/db/oauth2client.go +++ b/storage/ent/db/oauth2client.go @@ -69,7 +69,6 @@ func (o *OAuth2Client) assignValues(columns []string, values []interface{}) erro o.Secret = value.String } case oauth2client.FieldRedirectUris: - if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field redirect_uris", values[i]) } else if value != nil && len(*value) > 0 { @@ -78,7 +77,6 @@ func (o *OAuth2Client) assignValues(columns []string, values []interface{}) erro } } case oauth2client.FieldTrustedPeers: - if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field trusted_peers", values[i]) } else if value != nil && len(*value) > 0 { diff --git a/storage/ent/db/oauth2client_create.go b/storage/ent/db/oauth2client_create.go index 259b9473..b141352b 100644 --- a/storage/ent/db/oauth2client_create.go +++ b/storage/ent/db/oauth2client_create.go @@ -87,11 +87,17 @@ func (oc *OAuth2ClientCreate) Save(ctx context.Context) (*OAuth2Client, error) { return nil, err } oc.mutation = mutation - node, err = oc.sqlSave(ctx) + if node, err = oc.sqlSave(ctx); err != nil { + return nil, err + } + mutation.id = &node.ID mutation.done = true return node, err }) for i := len(oc.hooks) - 1; i >= 0; i-- { + if oc.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = oc.hooks[i](mut) } if _, err := mut.Mutate(ctx, oc.mutation); err != nil { @@ -110,38 +116,51 @@ func (oc *OAuth2ClientCreate) SaveX(ctx context.Context) *OAuth2Client { return v } +// Exec executes the query. +func (oc *OAuth2ClientCreate) Exec(ctx context.Context) error { + _, err := oc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (oc *OAuth2ClientCreate) ExecX(ctx context.Context) { + if err := oc.Exec(ctx); err != nil { + panic(err) + } +} + // check runs all checks and user-defined validators on the builder. func (oc *OAuth2ClientCreate) check() error { if _, ok := oc.mutation.Secret(); !ok { - return &ValidationError{Name: "secret", err: errors.New("db: missing required field \"secret\"")} + return &ValidationError{Name: "secret", err: errors.New(`db: missing required field "secret"`)} } if v, ok := oc.mutation.Secret(); ok { if err := oauth2client.SecretValidator(v); err != nil { - return &ValidationError{Name: "secret", err: fmt.Errorf("db: validator failed for field \"secret\": %w", err)} + return &ValidationError{Name: "secret", err: fmt.Errorf(`db: validator failed for field "secret": %w`, err)} } } if _, ok := oc.mutation.Public(); !ok { - return &ValidationError{Name: "public", err: errors.New("db: missing required field \"public\"")} + return &ValidationError{Name: "public", err: errors.New(`db: missing required field "public"`)} } if _, ok := oc.mutation.Name(); !ok { - return &ValidationError{Name: "name", err: errors.New("db: missing required field \"name\"")} + return &ValidationError{Name: "name", err: errors.New(`db: missing required field "name"`)} } if v, ok := oc.mutation.Name(); ok { if err := oauth2client.NameValidator(v); err != nil { - return &ValidationError{Name: "name", err: fmt.Errorf("db: validator failed for field \"name\": %w", err)} + return &ValidationError{Name: "name", err: fmt.Errorf(`db: validator failed for field "name": %w`, err)} } } if _, ok := oc.mutation.LogoURL(); !ok { - return &ValidationError{Name: "logo_url", err: errors.New("db: missing required field \"logo_url\"")} + return &ValidationError{Name: "logo_url", err: errors.New(`db: missing required field "logo_url"`)} } if v, ok := oc.mutation.LogoURL(); ok { if err := oauth2client.LogoURLValidator(v); err != nil { - return &ValidationError{Name: "logo_url", err: fmt.Errorf("db: validator failed for field \"logo_url\": %w", err)} + return &ValidationError{Name: "logo_url", err: fmt.Errorf(`db: validator failed for field "logo_url": %w`, err)} } } if v, ok := oc.mutation.ID(); ok { if err := oauth2client.IDValidator(v); err != nil { - return &ValidationError{Name: "id", err: fmt.Errorf("db: validator failed for field \"id\": %w", err)} + return &ValidationError{Name: "id", err: fmt.Errorf(`db: validator failed for field "id": %w`, err)} } } return nil @@ -150,8 +169,8 @@ func (oc *OAuth2ClientCreate) check() error { func (oc *OAuth2ClientCreate) sqlSave(ctx context.Context) (*OAuth2Client, error) { _node, _spec := oc.createSpec() if err := sqlgraph.CreateNode(ctx, oc.driver, _spec); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } @@ -252,17 +271,19 @@ func (ocb *OAuth2ClientCreateBulk) Save(ctx context.Context) ([]*OAuth2Client, e if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, ocb.builders[i+1].mutation) } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} // Invoke the actual operation on the latest mutation in the chain. - if err = sqlgraph.BatchCreate(ctx, ocb.driver, &sqlgraph.BatchCreateSpec{Nodes: specs}); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if err = sqlgraph.BatchCreate(ctx, ocb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } } } - mutation.done = true if err != nil { return nil, err } + mutation.id = &nodes[i].ID + mutation.done = true return nodes[i], nil }) for i := len(builder.hooks) - 1; i >= 0; i-- { @@ -287,3 +308,16 @@ func (ocb *OAuth2ClientCreateBulk) SaveX(ctx context.Context) []*OAuth2Client { } return v } + +// Exec executes the query. +func (ocb *OAuth2ClientCreateBulk) Exec(ctx context.Context) error { + _, err := ocb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (ocb *OAuth2ClientCreateBulk) ExecX(ctx context.Context) { + if err := ocb.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/storage/ent/db/oauth2client_delete.go b/storage/ent/db/oauth2client_delete.go index ab0a45f6..71a33c76 100644 --- a/storage/ent/db/oauth2client_delete.go +++ b/storage/ent/db/oauth2client_delete.go @@ -20,9 +20,9 @@ type OAuth2ClientDelete struct { mutation *OAuth2ClientMutation } -// Where adds a new predicate to the OAuth2ClientDelete builder. +// Where appends a list predicates to the OAuth2ClientDelete builder. func (od *OAuth2ClientDelete) Where(ps ...predicate.OAuth2Client) *OAuth2ClientDelete { - od.mutation.predicates = append(od.mutation.predicates, ps...) + od.mutation.Where(ps...) return od } @@ -46,6 +46,9 @@ func (od *OAuth2ClientDelete) Exec(ctx context.Context) (int, error) { return affected, err }) for i := len(od.hooks) - 1; i >= 0; i-- { + if od.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = od.hooks[i](mut) } if _, err := mut.Mutate(ctx, od.mutation); err != nil { diff --git a/storage/ent/db/oauth2client_query.go b/storage/ent/db/oauth2client_query.go index 558542f1..d2363288 100644 --- a/storage/ent/db/oauth2client_query.go +++ b/storage/ent/db/oauth2client_query.go @@ -287,8 +287,8 @@ func (oq *OAuth2ClientQuery) GroupBy(field string, fields ...string) *OAuth2Clie // Select(oauth2client.FieldSecret). // Scan(ctx, &v) // -func (oq *OAuth2ClientQuery) Select(field string, fields ...string) *OAuth2ClientSelect { - oq.fields = append([]string{field}, fields...) +func (oq *OAuth2ClientQuery) Select(fields ...string) *OAuth2ClientSelect { + oq.fields = append(oq.fields, fields...) return &OAuth2ClientSelect{OAuth2ClientQuery: oq} } @@ -398,10 +398,14 @@ func (oq *OAuth2ClientQuery) querySpec() *sqlgraph.QuerySpec { func (oq *OAuth2ClientQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(oq.driver.Dialect()) t1 := builder.Table(oauth2client.Table) - selector := builder.Select(t1.Columns(oauth2client.Columns...)...).From(t1) + columns := oq.fields + if len(columns) == 0 { + columns = oauth2client.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) if oq.sql != nil { selector = oq.sql - selector.Select(selector.Columns(oauth2client.Columns...)...) + selector.Select(selector.Columns(columns...)...) } for _, p := range oq.predicates { p(selector) @@ -669,13 +673,24 @@ func (ogb *OAuth2ClientGroupBy) sqlScan(ctx context.Context, v interface{}) erro } func (ogb *OAuth2ClientGroupBy) sqlQuery() *sql.Selector { - selector := ogb.sql - columns := make([]string, 0, len(ogb.fields)+len(ogb.fns)) - columns = append(columns, ogb.fields...) + selector := ogb.sql.Select() + aggregation := make([]string, 0, len(ogb.fns)) for _, fn := range ogb.fns { - columns = append(columns, fn(selector)) + aggregation = append(aggregation, fn(selector)) } - return selector.Select(columns...).GroupBy(ogb.fields...) + // If no columns were selected in a custom aggregation function, the default + // selection is the fields used for "group-by", and the aggregation functions. + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(ogb.fields)+len(ogb.fns)) + for _, f := range ogb.fields { + columns = append(columns, selector.C(f)) + } + for _, c := range aggregation { + columns = append(columns, c) + } + selector.Select(columns...) + } + return selector.GroupBy(selector.Columns(ogb.fields...)...) } // OAuth2ClientSelect is the builder for selecting fields of OAuth2Client entities. @@ -891,16 +906,10 @@ func (os *OAuth2ClientSelect) BoolX(ctx context.Context) bool { func (os *OAuth2ClientSelect) sqlScan(ctx context.Context, v interface{}) error { rows := &sql.Rows{} - query, args := os.sqlQuery().Query() + query, args := os.sql.Query() if err := os.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } - -func (os *OAuth2ClientSelect) sqlQuery() sql.Querier { - selector := os.sql - selector.Select(selector.Columns(os.fields...)...) - return selector -} diff --git a/storage/ent/db/oauth2client_update.go b/storage/ent/db/oauth2client_update.go index 32982418..29b7f090 100644 --- a/storage/ent/db/oauth2client_update.go +++ b/storage/ent/db/oauth2client_update.go @@ -20,9 +20,9 @@ type OAuth2ClientUpdate struct { mutation *OAuth2ClientMutation } -// Where adds a new predicate for the OAuth2ClientUpdate builder. +// Where appends a list predicates to the OAuth2ClientUpdate builder. func (ou *OAuth2ClientUpdate) Where(ps ...predicate.OAuth2Client) *OAuth2ClientUpdate { - ou.mutation.predicates = append(ou.mutation.predicates, ps...) + ou.mutation.Where(ps...) return ou } @@ -105,6 +105,9 @@ func (ou *OAuth2ClientUpdate) Save(ctx context.Context) (int, error) { return affected, err }) for i := len(ou.hooks) - 1; i >= 0; i-- { + if ou.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = ou.hooks[i](mut) } if _, err := mut.Mutate(ctx, ou.mutation); err != nil { @@ -231,8 +234,8 @@ func (ou *OAuth2ClientUpdate) sqlSave(ctx context.Context) (n int, err error) { if n, err = sqlgraph.UpdateNodes(ctx, ou.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{oauth2client.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return 0, err } @@ -333,6 +336,9 @@ func (ouo *OAuth2ClientUpdateOne) Save(ctx context.Context) (*OAuth2Client, erro return node, err }) for i := len(ouo.hooks) - 1; i >= 0; i-- { + if ouo.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = ouo.hooks[i](mut) } if _, err := mut.Mutate(ctx, ouo.mutation); err != nil { @@ -479,8 +485,8 @@ func (ouo *OAuth2ClientUpdateOne) sqlSave(ctx context.Context) (_node *OAuth2Cli if err = sqlgraph.UpdateNode(ctx, ouo.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{oauth2client.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } diff --git a/storage/ent/db/offlinesession_create.go b/storage/ent/db/offlinesession_create.go index 1103e8ee..add3912a 100644 --- a/storage/ent/db/offlinesession_create.go +++ b/storage/ent/db/offlinesession_create.go @@ -75,11 +75,17 @@ func (osc *OfflineSessionCreate) Save(ctx context.Context) (*OfflineSession, err return nil, err } osc.mutation = mutation - node, err = osc.sqlSave(ctx) + if node, err = osc.sqlSave(ctx); err != nil { + return nil, err + } + mutation.id = &node.ID mutation.done = true return node, err }) for i := len(osc.hooks) - 1; i >= 0; i-- { + if osc.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = osc.hooks[i](mut) } if _, err := mut.Mutate(ctx, osc.mutation); err != nil { @@ -98,30 +104,43 @@ func (osc *OfflineSessionCreate) SaveX(ctx context.Context) *OfflineSession { return v } +// Exec executes the query. +func (osc *OfflineSessionCreate) Exec(ctx context.Context) error { + _, err := osc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (osc *OfflineSessionCreate) ExecX(ctx context.Context) { + if err := osc.Exec(ctx); err != nil { + panic(err) + } +} + // check runs all checks and user-defined validators on the builder. func (osc *OfflineSessionCreate) check() error { if _, ok := osc.mutation.UserID(); !ok { - return &ValidationError{Name: "user_id", err: errors.New("db: missing required field \"user_id\"")} + return &ValidationError{Name: "user_id", err: errors.New(`db: missing required field "user_id"`)} } if v, ok := osc.mutation.UserID(); ok { if err := offlinesession.UserIDValidator(v); err != nil { - return &ValidationError{Name: "user_id", err: fmt.Errorf("db: validator failed for field \"user_id\": %w", err)} + return &ValidationError{Name: "user_id", err: fmt.Errorf(`db: validator failed for field "user_id": %w`, err)} } } if _, ok := osc.mutation.ConnID(); !ok { - return &ValidationError{Name: "conn_id", err: errors.New("db: missing required field \"conn_id\"")} + return &ValidationError{Name: "conn_id", err: errors.New(`db: missing required field "conn_id"`)} } if v, ok := osc.mutation.ConnID(); ok { if err := offlinesession.ConnIDValidator(v); err != nil { - return &ValidationError{Name: "conn_id", err: fmt.Errorf("db: validator failed for field \"conn_id\": %w", err)} + return &ValidationError{Name: "conn_id", err: fmt.Errorf(`db: validator failed for field "conn_id": %w`, err)} } } if _, ok := osc.mutation.Refresh(); !ok { - return &ValidationError{Name: "refresh", err: errors.New("db: missing required field \"refresh\"")} + return &ValidationError{Name: "refresh", err: errors.New(`db: missing required field "refresh"`)} } if v, ok := osc.mutation.ID(); ok { if err := offlinesession.IDValidator(v); err != nil { - return &ValidationError{Name: "id", err: fmt.Errorf("db: validator failed for field \"id\": %w", err)} + return &ValidationError{Name: "id", err: fmt.Errorf(`db: validator failed for field "id": %w`, err)} } } return nil @@ -130,8 +149,8 @@ func (osc *OfflineSessionCreate) check() error { func (osc *OfflineSessionCreate) sqlSave(ctx context.Context) (*OfflineSession, error) { _node, _spec := osc.createSpec() if err := sqlgraph.CreateNode(ctx, osc.driver, _spec); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } @@ -216,17 +235,19 @@ func (oscb *OfflineSessionCreateBulk) Save(ctx context.Context) ([]*OfflineSessi if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, oscb.builders[i+1].mutation) } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} // Invoke the actual operation on the latest mutation in the chain. - if err = sqlgraph.BatchCreate(ctx, oscb.driver, &sqlgraph.BatchCreateSpec{Nodes: specs}); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if err = sqlgraph.BatchCreate(ctx, oscb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } } } - mutation.done = true if err != nil { return nil, err } + mutation.id = &nodes[i].ID + mutation.done = true return nodes[i], nil }) for i := len(builder.hooks) - 1; i >= 0; i-- { @@ -251,3 +272,16 @@ func (oscb *OfflineSessionCreateBulk) SaveX(ctx context.Context) []*OfflineSessi } return v } + +// Exec executes the query. +func (oscb *OfflineSessionCreateBulk) Exec(ctx context.Context) error { + _, err := oscb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (oscb *OfflineSessionCreateBulk) ExecX(ctx context.Context) { + if err := oscb.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/storage/ent/db/offlinesession_delete.go b/storage/ent/db/offlinesession_delete.go index 8ca83378..3b2e9143 100644 --- a/storage/ent/db/offlinesession_delete.go +++ b/storage/ent/db/offlinesession_delete.go @@ -20,9 +20,9 @@ type OfflineSessionDelete struct { mutation *OfflineSessionMutation } -// Where adds a new predicate to the OfflineSessionDelete builder. +// Where appends a list predicates to the OfflineSessionDelete builder. func (osd *OfflineSessionDelete) Where(ps ...predicate.OfflineSession) *OfflineSessionDelete { - osd.mutation.predicates = append(osd.mutation.predicates, ps...) + osd.mutation.Where(ps...) return osd } @@ -46,6 +46,9 @@ func (osd *OfflineSessionDelete) Exec(ctx context.Context) (int, error) { return affected, err }) for i := len(osd.hooks) - 1; i >= 0; i-- { + if osd.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = osd.hooks[i](mut) } if _, err := mut.Mutate(ctx, osd.mutation); err != nil { diff --git a/storage/ent/db/offlinesession_query.go b/storage/ent/db/offlinesession_query.go index a4fbe1fd..4306deaa 100644 --- a/storage/ent/db/offlinesession_query.go +++ b/storage/ent/db/offlinesession_query.go @@ -287,8 +287,8 @@ func (osq *OfflineSessionQuery) GroupBy(field string, fields ...string) *Offline // Select(offlinesession.FieldUserID). // Scan(ctx, &v) // -func (osq *OfflineSessionQuery) Select(field string, fields ...string) *OfflineSessionSelect { - osq.fields = append([]string{field}, fields...) +func (osq *OfflineSessionQuery) Select(fields ...string) *OfflineSessionSelect { + osq.fields = append(osq.fields, fields...) return &OfflineSessionSelect{OfflineSessionQuery: osq} } @@ -398,10 +398,14 @@ func (osq *OfflineSessionQuery) querySpec() *sqlgraph.QuerySpec { func (osq *OfflineSessionQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(osq.driver.Dialect()) t1 := builder.Table(offlinesession.Table) - selector := builder.Select(t1.Columns(offlinesession.Columns...)...).From(t1) + columns := osq.fields + if len(columns) == 0 { + columns = offlinesession.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) if osq.sql != nil { selector = osq.sql - selector.Select(selector.Columns(offlinesession.Columns...)...) + selector.Select(selector.Columns(columns...)...) } for _, p := range osq.predicates { p(selector) @@ -669,13 +673,24 @@ func (osgb *OfflineSessionGroupBy) sqlScan(ctx context.Context, v interface{}) e } func (osgb *OfflineSessionGroupBy) sqlQuery() *sql.Selector { - selector := osgb.sql - columns := make([]string, 0, len(osgb.fields)+len(osgb.fns)) - columns = append(columns, osgb.fields...) + selector := osgb.sql.Select() + aggregation := make([]string, 0, len(osgb.fns)) for _, fn := range osgb.fns { - columns = append(columns, fn(selector)) + aggregation = append(aggregation, fn(selector)) } - return selector.Select(columns...).GroupBy(osgb.fields...) + // If no columns were selected in a custom aggregation function, the default + // selection is the fields used for "group-by", and the aggregation functions. + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(osgb.fields)+len(osgb.fns)) + for _, f := range osgb.fields { + columns = append(columns, selector.C(f)) + } + for _, c := range aggregation { + columns = append(columns, c) + } + selector.Select(columns...) + } + return selector.GroupBy(selector.Columns(osgb.fields...)...) } // OfflineSessionSelect is the builder for selecting fields of OfflineSession entities. @@ -891,16 +906,10 @@ func (oss *OfflineSessionSelect) BoolX(ctx context.Context) bool { func (oss *OfflineSessionSelect) sqlScan(ctx context.Context, v interface{}) error { rows := &sql.Rows{} - query, args := oss.sqlQuery().Query() + query, args := oss.sql.Query() if err := oss.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } - -func (oss *OfflineSessionSelect) sqlQuery() sql.Querier { - selector := oss.sql - selector.Select(selector.Columns(oss.fields...)...) - return selector -} diff --git a/storage/ent/db/offlinesession_update.go b/storage/ent/db/offlinesession_update.go index d6edd522..20c9faf3 100644 --- a/storage/ent/db/offlinesession_update.go +++ b/storage/ent/db/offlinesession_update.go @@ -20,9 +20,9 @@ type OfflineSessionUpdate struct { mutation *OfflineSessionMutation } -// Where adds a new predicate for the OfflineSessionUpdate builder. +// Where appends a list predicates to the OfflineSessionUpdate builder. func (osu *OfflineSessionUpdate) Where(ps ...predicate.OfflineSession) *OfflineSessionUpdate { - osu.mutation.predicates = append(osu.mutation.predicates, ps...) + osu.mutation.Where(ps...) return osu } @@ -87,6 +87,9 @@ func (osu *OfflineSessionUpdate) Save(ctx context.Context) (int, error) { return affected, err }) for i := len(osu.hooks) - 1; i >= 0; i-- { + if osu.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = osu.hooks[i](mut) } if _, err := mut.Mutate(ctx, osu.mutation); err != nil { @@ -188,8 +191,8 @@ func (osu *OfflineSessionUpdate) sqlSave(ctx context.Context) (n int, err error) if n, err = sqlgraph.UpdateNodes(ctx, osu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{offlinesession.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return 0, err } @@ -272,6 +275,9 @@ func (osuo *OfflineSessionUpdateOne) Save(ctx context.Context) (*OfflineSession, return node, err }) for i := len(osuo.hooks) - 1; i >= 0; i-- { + if osuo.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = osuo.hooks[i](mut) } if _, err := mut.Mutate(ctx, osuo.mutation); err != nil { @@ -393,8 +399,8 @@ func (osuo *OfflineSessionUpdateOne) sqlSave(ctx context.Context) (_node *Offlin if err = sqlgraph.UpdateNode(ctx, osuo.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{offlinesession.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } diff --git a/storage/ent/db/password_create.go b/storage/ent/db/password_create.go index 2e01f4a2..af902ea3 100644 --- a/storage/ent/db/password_create.go +++ b/storage/ent/db/password_create.go @@ -69,11 +69,17 @@ func (pc *PasswordCreate) Save(ctx context.Context) (*Password, error) { return nil, err } pc.mutation = mutation - node, err = pc.sqlSave(ctx) + if node, err = pc.sqlSave(ctx); err != nil { + return nil, err + } + mutation.id = &node.ID mutation.done = true return node, err }) for i := len(pc.hooks) - 1; i >= 0; i-- { + if pc.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = pc.hooks[i](mut) } if _, err := mut.Mutate(ctx, pc.mutation); err != nil { @@ -92,33 +98,46 @@ func (pc *PasswordCreate) SaveX(ctx context.Context) *Password { return v } +// Exec executes the query. +func (pc *PasswordCreate) Exec(ctx context.Context) error { + _, err := pc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (pc *PasswordCreate) ExecX(ctx context.Context) { + if err := pc.Exec(ctx); err != nil { + panic(err) + } +} + // check runs all checks and user-defined validators on the builder. func (pc *PasswordCreate) check() error { if _, ok := pc.mutation.Email(); !ok { - return &ValidationError{Name: "email", err: errors.New("db: missing required field \"email\"")} + return &ValidationError{Name: "email", err: errors.New(`db: missing required field "email"`)} } if v, ok := pc.mutation.Email(); ok { if err := password.EmailValidator(v); err != nil { - return &ValidationError{Name: "email", err: fmt.Errorf("db: validator failed for field \"email\": %w", err)} + return &ValidationError{Name: "email", err: fmt.Errorf(`db: validator failed for field "email": %w`, err)} } } if _, ok := pc.mutation.Hash(); !ok { - return &ValidationError{Name: "hash", err: errors.New("db: missing required field \"hash\"")} + return &ValidationError{Name: "hash", err: errors.New(`db: missing required field "hash"`)} } if _, ok := pc.mutation.Username(); !ok { - return &ValidationError{Name: "username", err: errors.New("db: missing required field \"username\"")} + return &ValidationError{Name: "username", err: errors.New(`db: missing required field "username"`)} } if v, ok := pc.mutation.Username(); ok { if err := password.UsernameValidator(v); err != nil { - return &ValidationError{Name: "username", err: fmt.Errorf("db: validator failed for field \"username\": %w", err)} + return &ValidationError{Name: "username", err: fmt.Errorf(`db: validator failed for field "username": %w`, err)} } } if _, ok := pc.mutation.UserID(); !ok { - return &ValidationError{Name: "user_id", err: errors.New("db: missing required field \"user_id\"")} + return &ValidationError{Name: "user_id", err: errors.New(`db: missing required field "user_id"`)} } if v, ok := pc.mutation.UserID(); ok { if err := password.UserIDValidator(v); err != nil { - return &ValidationError{Name: "user_id", err: fmt.Errorf("db: validator failed for field \"user_id\": %w", err)} + return &ValidationError{Name: "user_id", err: fmt.Errorf(`db: validator failed for field "user_id": %w`, err)} } } return nil @@ -127,8 +146,8 @@ func (pc *PasswordCreate) check() error { func (pc *PasswordCreate) sqlSave(ctx context.Context) (*Password, error) { _node, _spec := pc.createSpec() if err := sqlgraph.CreateNode(ctx, pc.driver, _spec); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } @@ -211,19 +230,23 @@ func (pcb *PasswordCreateBulk) Save(ctx context.Context) ([]*Password, error) { if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, pcb.builders[i+1].mutation) } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} // Invoke the actual operation on the latest mutation in the chain. - if err = sqlgraph.BatchCreate(ctx, pcb.driver, &sqlgraph.BatchCreateSpec{Nodes: specs}); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if err = sqlgraph.BatchCreate(ctx, pcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } } } - mutation.done = true if err != nil { return nil, err } - id := specs[i].ID.Value.(int64) - nodes[i].ID = int(id) + mutation.id = &nodes[i].ID + mutation.done = true + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } return nodes[i], nil }) for i := len(builder.hooks) - 1; i >= 0; i-- { @@ -248,3 +271,16 @@ func (pcb *PasswordCreateBulk) SaveX(ctx context.Context) []*Password { } return v } + +// Exec executes the query. +func (pcb *PasswordCreateBulk) Exec(ctx context.Context) error { + _, err := pcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (pcb *PasswordCreateBulk) ExecX(ctx context.Context) { + if err := pcb.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/storage/ent/db/password_delete.go b/storage/ent/db/password_delete.go index 87d018fc..d1c59870 100644 --- a/storage/ent/db/password_delete.go +++ b/storage/ent/db/password_delete.go @@ -20,9 +20,9 @@ type PasswordDelete struct { mutation *PasswordMutation } -// Where adds a new predicate to the PasswordDelete builder. +// Where appends a list predicates to the PasswordDelete builder. func (pd *PasswordDelete) Where(ps ...predicate.Password) *PasswordDelete { - pd.mutation.predicates = append(pd.mutation.predicates, ps...) + pd.mutation.Where(ps...) return pd } @@ -46,6 +46,9 @@ func (pd *PasswordDelete) Exec(ctx context.Context) (int, error) { return affected, err }) for i := len(pd.hooks) - 1; i >= 0; i-- { + if pd.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = pd.hooks[i](mut) } if _, err := mut.Mutate(ctx, pd.mutation); err != nil { diff --git a/storage/ent/db/password_query.go b/storage/ent/db/password_query.go index 8bfe9a83..50e493fd 100644 --- a/storage/ent/db/password_query.go +++ b/storage/ent/db/password_query.go @@ -287,8 +287,8 @@ func (pq *PasswordQuery) GroupBy(field string, fields ...string) *PasswordGroupB // Select(password.FieldEmail). // Scan(ctx, &v) // -func (pq *PasswordQuery) Select(field string, fields ...string) *PasswordSelect { - pq.fields = append([]string{field}, fields...) +func (pq *PasswordQuery) Select(fields ...string) *PasswordSelect { + pq.fields = append(pq.fields, fields...) return &PasswordSelect{PasswordQuery: pq} } @@ -398,10 +398,14 @@ func (pq *PasswordQuery) querySpec() *sqlgraph.QuerySpec { func (pq *PasswordQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(pq.driver.Dialect()) t1 := builder.Table(password.Table) - selector := builder.Select(t1.Columns(password.Columns...)...).From(t1) + columns := pq.fields + if len(columns) == 0 { + columns = password.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) if pq.sql != nil { selector = pq.sql - selector.Select(selector.Columns(password.Columns...)...) + selector.Select(selector.Columns(columns...)...) } for _, p := range pq.predicates { p(selector) @@ -669,13 +673,24 @@ func (pgb *PasswordGroupBy) sqlScan(ctx context.Context, v interface{}) error { } func (pgb *PasswordGroupBy) sqlQuery() *sql.Selector { - selector := pgb.sql - columns := make([]string, 0, len(pgb.fields)+len(pgb.fns)) - columns = append(columns, pgb.fields...) + selector := pgb.sql.Select() + aggregation := make([]string, 0, len(pgb.fns)) for _, fn := range pgb.fns { - columns = append(columns, fn(selector)) + aggregation = append(aggregation, fn(selector)) } - return selector.Select(columns...).GroupBy(pgb.fields...) + // If no columns were selected in a custom aggregation function, the default + // selection is the fields used for "group-by", and the aggregation functions. + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(pgb.fields)+len(pgb.fns)) + for _, f := range pgb.fields { + columns = append(columns, selector.C(f)) + } + for _, c := range aggregation { + columns = append(columns, c) + } + selector.Select(columns...) + } + return selector.GroupBy(selector.Columns(pgb.fields...)...) } // PasswordSelect is the builder for selecting fields of Password entities. @@ -891,16 +906,10 @@ func (ps *PasswordSelect) BoolX(ctx context.Context) bool { func (ps *PasswordSelect) sqlScan(ctx context.Context, v interface{}) error { rows := &sql.Rows{} - query, args := ps.sqlQuery().Query() + query, args := ps.sql.Query() if err := ps.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } - -func (ps *PasswordSelect) sqlQuery() sql.Querier { - selector := ps.sql - selector.Select(selector.Columns(ps.fields...)...) - return selector -} diff --git a/storage/ent/db/password_update.go b/storage/ent/db/password_update.go index 0eb1cb61..f15fb017 100644 --- a/storage/ent/db/password_update.go +++ b/storage/ent/db/password_update.go @@ -20,9 +20,9 @@ type PasswordUpdate struct { mutation *PasswordMutation } -// Where adds a new predicate for the PasswordUpdate builder. +// Where appends a list predicates to the PasswordUpdate builder. func (pu *PasswordUpdate) Where(ps ...predicate.Password) *PasswordUpdate { - pu.mutation.predicates = append(pu.mutation.predicates, ps...) + pu.mutation.Where(ps...) return pu } @@ -81,6 +81,9 @@ func (pu *PasswordUpdate) Save(ctx context.Context) (int, error) { return affected, err }) for i := len(pu.hooks) - 1; i >= 0; i-- { + if pu.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = pu.hooks[i](mut) } if _, err := mut.Mutate(ctx, pu.mutation); err != nil { @@ -181,8 +184,8 @@ func (pu *PasswordUpdate) sqlSave(ctx context.Context) (n int, err error) { if n, err = sqlgraph.UpdateNodes(ctx, pu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{password.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return 0, err } @@ -259,6 +262,9 @@ func (puo *PasswordUpdateOne) Save(ctx context.Context) (*Password, error) { return node, err }) for i := len(puo.hooks) - 1; i >= 0; i-- { + if puo.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = puo.hooks[i](mut) } if _, err := mut.Mutate(ctx, puo.mutation); err != nil { @@ -379,8 +385,8 @@ func (puo *PasswordUpdateOne) sqlSave(ctx context.Context) (_node *Password, err if err = sqlgraph.UpdateNode(ctx, puo.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{password.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } diff --git a/storage/ent/db/refreshtoken.go b/storage/ent/db/refreshtoken.go index 7e527079..3c591206 100644 --- a/storage/ent/db/refreshtoken.go +++ b/storage/ent/db/refreshtoken.go @@ -90,7 +90,6 @@ func (rt *RefreshToken) assignValues(columns []string, values []interface{}) err rt.ClientID = value.String } case refreshtoken.FieldScopes: - if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field scopes", values[i]) } else if value != nil && len(*value) > 0 { @@ -129,7 +128,6 @@ func (rt *RefreshToken) assignValues(columns []string, values []interface{}) err rt.ClaimsEmailVerified = value.Bool } case refreshtoken.FieldClaimsGroups: - if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field claims_groups", values[i]) } else if value != nil && len(*value) > 0 { diff --git a/storage/ent/db/refreshtoken_create.go b/storage/ent/db/refreshtoken_create.go index e73f276a..00d29775 100644 --- a/storage/ent/db/refreshtoken_create.go +++ b/storage/ent/db/refreshtoken_create.go @@ -183,11 +183,17 @@ func (rtc *RefreshTokenCreate) Save(ctx context.Context) (*RefreshToken, error) return nil, err } rtc.mutation = mutation - node, err = rtc.sqlSave(ctx) + if node, err = rtc.sqlSave(ctx); err != nil { + return nil, err + } + mutation.id = &node.ID mutation.done = true return node, err }) for i := len(rtc.hooks) - 1; i >= 0; i-- { + if rtc.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = rtc.hooks[i](mut) } if _, err := mut.Mutate(ctx, rtc.mutation); err != nil { @@ -206,6 +212,19 @@ func (rtc *RefreshTokenCreate) SaveX(ctx context.Context) *RefreshToken { return v } +// Exec executes the query. +func (rtc *RefreshTokenCreate) Exec(ctx context.Context) error { + _, err := rtc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (rtc *RefreshTokenCreate) ExecX(ctx context.Context) { + if err := rtc.Exec(ctx); err != nil { + panic(err) + } +} + // defaults sets the default values of the builder before save. func (rtc *RefreshTokenCreate) defaults() { if _, ok := rtc.mutation.ClaimsPreferredUsername(); !ok { @@ -233,74 +252,74 @@ func (rtc *RefreshTokenCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (rtc *RefreshTokenCreate) check() error { if _, ok := rtc.mutation.ClientID(); !ok { - return &ValidationError{Name: "client_id", err: errors.New("db: missing required field \"client_id\"")} + return &ValidationError{Name: "client_id", err: errors.New(`db: missing required field "client_id"`)} } if v, ok := rtc.mutation.ClientID(); ok { if err := refreshtoken.ClientIDValidator(v); err != nil { - return &ValidationError{Name: "client_id", err: fmt.Errorf("db: validator failed for field \"client_id\": %w", err)} + return &ValidationError{Name: "client_id", err: fmt.Errorf(`db: validator failed for field "client_id": %w`, err)} } } if _, ok := rtc.mutation.Nonce(); !ok { - return &ValidationError{Name: "nonce", err: errors.New("db: missing required field \"nonce\"")} + return &ValidationError{Name: "nonce", err: errors.New(`db: missing required field "nonce"`)} } if v, ok := rtc.mutation.Nonce(); ok { if err := refreshtoken.NonceValidator(v); err != nil { - return &ValidationError{Name: "nonce", err: fmt.Errorf("db: validator failed for field \"nonce\": %w", err)} + return &ValidationError{Name: "nonce", err: fmt.Errorf(`db: validator failed for field "nonce": %w`, err)} } } if _, ok := rtc.mutation.ClaimsUserID(); !ok { - return &ValidationError{Name: "claims_user_id", err: errors.New("db: missing required field \"claims_user_id\"")} + return &ValidationError{Name: "claims_user_id", err: errors.New(`db: missing required field "claims_user_id"`)} } if v, ok := rtc.mutation.ClaimsUserID(); ok { if err := refreshtoken.ClaimsUserIDValidator(v); err != nil { - return &ValidationError{Name: "claims_user_id", err: fmt.Errorf("db: validator failed for field \"claims_user_id\": %w", err)} + return &ValidationError{Name: "claims_user_id", err: fmt.Errorf(`db: validator failed for field "claims_user_id": %w`, err)} } } if _, ok := rtc.mutation.ClaimsUsername(); !ok { - return &ValidationError{Name: "claims_username", err: errors.New("db: missing required field \"claims_username\"")} + return &ValidationError{Name: "claims_username", err: errors.New(`db: missing required field "claims_username"`)} } if v, ok := rtc.mutation.ClaimsUsername(); ok { if err := refreshtoken.ClaimsUsernameValidator(v); err != nil { - return &ValidationError{Name: "claims_username", err: fmt.Errorf("db: validator failed for field \"claims_username\": %w", err)} + return &ValidationError{Name: "claims_username", err: fmt.Errorf(`db: validator failed for field "claims_username": %w`, err)} } } if _, ok := rtc.mutation.ClaimsEmail(); !ok { - return &ValidationError{Name: "claims_email", err: errors.New("db: missing required field \"claims_email\"")} + return &ValidationError{Name: "claims_email", err: errors.New(`db: missing required field "claims_email"`)} } if v, ok := rtc.mutation.ClaimsEmail(); ok { if err := refreshtoken.ClaimsEmailValidator(v); err != nil { - return &ValidationError{Name: "claims_email", err: fmt.Errorf("db: validator failed for field \"claims_email\": %w", err)} + return &ValidationError{Name: "claims_email", err: fmt.Errorf(`db: validator failed for field "claims_email": %w`, err)} } } if _, ok := rtc.mutation.ClaimsEmailVerified(); !ok { - return &ValidationError{Name: "claims_email_verified", err: errors.New("db: missing required field \"claims_email_verified\"")} + return &ValidationError{Name: "claims_email_verified", err: errors.New(`db: missing required field "claims_email_verified"`)} } if _, ok := rtc.mutation.ClaimsPreferredUsername(); !ok { - return &ValidationError{Name: "claims_preferred_username", err: errors.New("db: missing required field \"claims_preferred_username\"")} + return &ValidationError{Name: "claims_preferred_username", err: errors.New(`db: missing required field "claims_preferred_username"`)} } if _, ok := rtc.mutation.ConnectorID(); !ok { - return &ValidationError{Name: "connector_id", err: errors.New("db: missing required field \"connector_id\"")} + return &ValidationError{Name: "connector_id", err: errors.New(`db: missing required field "connector_id"`)} } if v, ok := rtc.mutation.ConnectorID(); ok { if err := refreshtoken.ConnectorIDValidator(v); err != nil { - return &ValidationError{Name: "connector_id", err: fmt.Errorf("db: validator failed for field \"connector_id\": %w", err)} + return &ValidationError{Name: "connector_id", err: fmt.Errorf(`db: validator failed for field "connector_id": %w`, err)} } } if _, ok := rtc.mutation.Token(); !ok { - return &ValidationError{Name: "token", err: errors.New("db: missing required field \"token\"")} + return &ValidationError{Name: "token", err: errors.New(`db: missing required field "token"`)} } if _, ok := rtc.mutation.ObsoleteToken(); !ok { - return &ValidationError{Name: "obsolete_token", err: errors.New("db: missing required field \"obsolete_token\"")} + return &ValidationError{Name: "obsolete_token", err: errors.New(`db: missing required field "obsolete_token"`)} } if _, ok := rtc.mutation.CreatedAt(); !ok { - return &ValidationError{Name: "created_at", err: errors.New("db: missing required field \"created_at\"")} + return &ValidationError{Name: "created_at", err: errors.New(`db: missing required field "created_at"`)} } if _, ok := rtc.mutation.LastUsed(); !ok { - return &ValidationError{Name: "last_used", err: errors.New("db: missing required field \"last_used\"")} + return &ValidationError{Name: "last_used", err: errors.New(`db: missing required field "last_used"`)} } if v, ok := rtc.mutation.ID(); ok { if err := refreshtoken.IDValidator(v); err != nil { - return &ValidationError{Name: "id", err: fmt.Errorf("db: validator failed for field \"id\": %w", err)} + return &ValidationError{Name: "id", err: fmt.Errorf(`db: validator failed for field "id": %w`, err)} } } return nil @@ -309,8 +328,8 @@ func (rtc *RefreshTokenCreate) check() error { func (rtc *RefreshTokenCreate) sqlSave(ctx context.Context) (*RefreshToken, error) { _node, _spec := rtc.createSpec() if err := sqlgraph.CreateNode(ctx, rtc.driver, _spec); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } @@ -484,17 +503,19 @@ func (rtcb *RefreshTokenCreateBulk) Save(ctx context.Context) ([]*RefreshToken, if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, rtcb.builders[i+1].mutation) } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} // Invoke the actual operation on the latest mutation in the chain. - if err = sqlgraph.BatchCreate(ctx, rtcb.driver, &sqlgraph.BatchCreateSpec{Nodes: specs}); err != nil { - if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + if err = sqlgraph.BatchCreate(ctx, rtcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } } } - mutation.done = true if err != nil { return nil, err } + mutation.id = &nodes[i].ID + mutation.done = true return nodes[i], nil }) for i := len(builder.hooks) - 1; i >= 0; i-- { @@ -519,3 +540,16 @@ func (rtcb *RefreshTokenCreateBulk) SaveX(ctx context.Context) []*RefreshToken { } return v } + +// Exec executes the query. +func (rtcb *RefreshTokenCreateBulk) Exec(ctx context.Context) error { + _, err := rtcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (rtcb *RefreshTokenCreateBulk) ExecX(ctx context.Context) { + if err := rtcb.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/storage/ent/db/refreshtoken_delete.go b/storage/ent/db/refreshtoken_delete.go index 34671548..e5c882bb 100644 --- a/storage/ent/db/refreshtoken_delete.go +++ b/storage/ent/db/refreshtoken_delete.go @@ -20,9 +20,9 @@ type RefreshTokenDelete struct { mutation *RefreshTokenMutation } -// Where adds a new predicate to the RefreshTokenDelete builder. +// Where appends a list predicates to the RefreshTokenDelete builder. func (rtd *RefreshTokenDelete) Where(ps ...predicate.RefreshToken) *RefreshTokenDelete { - rtd.mutation.predicates = append(rtd.mutation.predicates, ps...) + rtd.mutation.Where(ps...) return rtd } @@ -46,6 +46,9 @@ func (rtd *RefreshTokenDelete) Exec(ctx context.Context) (int, error) { return affected, err }) for i := len(rtd.hooks) - 1; i >= 0; i-- { + if rtd.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = rtd.hooks[i](mut) } if _, err := mut.Mutate(ctx, rtd.mutation); err != nil { diff --git a/storage/ent/db/refreshtoken_query.go b/storage/ent/db/refreshtoken_query.go index 503e606f..2ee4d4f9 100644 --- a/storage/ent/db/refreshtoken_query.go +++ b/storage/ent/db/refreshtoken_query.go @@ -287,8 +287,8 @@ func (rtq *RefreshTokenQuery) GroupBy(field string, fields ...string) *RefreshTo // Select(refreshtoken.FieldClientID). // Scan(ctx, &v) // -func (rtq *RefreshTokenQuery) Select(field string, fields ...string) *RefreshTokenSelect { - rtq.fields = append([]string{field}, fields...) +func (rtq *RefreshTokenQuery) Select(fields ...string) *RefreshTokenSelect { + rtq.fields = append(rtq.fields, fields...) return &RefreshTokenSelect{RefreshTokenQuery: rtq} } @@ -398,10 +398,14 @@ func (rtq *RefreshTokenQuery) querySpec() *sqlgraph.QuerySpec { func (rtq *RefreshTokenQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(rtq.driver.Dialect()) t1 := builder.Table(refreshtoken.Table) - selector := builder.Select(t1.Columns(refreshtoken.Columns...)...).From(t1) + columns := rtq.fields + if len(columns) == 0 { + columns = refreshtoken.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) if rtq.sql != nil { selector = rtq.sql - selector.Select(selector.Columns(refreshtoken.Columns...)...) + selector.Select(selector.Columns(columns...)...) } for _, p := range rtq.predicates { p(selector) @@ -669,13 +673,24 @@ func (rtgb *RefreshTokenGroupBy) sqlScan(ctx context.Context, v interface{}) err } func (rtgb *RefreshTokenGroupBy) sqlQuery() *sql.Selector { - selector := rtgb.sql - columns := make([]string, 0, len(rtgb.fields)+len(rtgb.fns)) - columns = append(columns, rtgb.fields...) + selector := rtgb.sql.Select() + aggregation := make([]string, 0, len(rtgb.fns)) for _, fn := range rtgb.fns { - columns = append(columns, fn(selector)) + aggregation = append(aggregation, fn(selector)) } - return selector.Select(columns...).GroupBy(rtgb.fields...) + // If no columns were selected in a custom aggregation function, the default + // selection is the fields used for "group-by", and the aggregation functions. + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(rtgb.fields)+len(rtgb.fns)) + for _, f := range rtgb.fields { + columns = append(columns, selector.C(f)) + } + for _, c := range aggregation { + columns = append(columns, c) + } + selector.Select(columns...) + } + return selector.GroupBy(selector.Columns(rtgb.fields...)...) } // RefreshTokenSelect is the builder for selecting fields of RefreshToken entities. @@ -891,16 +906,10 @@ func (rts *RefreshTokenSelect) BoolX(ctx context.Context) bool { func (rts *RefreshTokenSelect) sqlScan(ctx context.Context, v interface{}) error { rows := &sql.Rows{} - query, args := rts.sqlQuery().Query() + query, args := rts.sql.Query() if err := rts.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } - -func (rts *RefreshTokenSelect) sqlQuery() sql.Querier { - selector := rts.sql - selector.Select(selector.Columns(rts.fields...)...) - return selector -} diff --git a/storage/ent/db/refreshtoken_update.go b/storage/ent/db/refreshtoken_update.go index 87ccfcd0..913666bb 100644 --- a/storage/ent/db/refreshtoken_update.go +++ b/storage/ent/db/refreshtoken_update.go @@ -21,9 +21,9 @@ type RefreshTokenUpdate struct { mutation *RefreshTokenMutation } -// Where adds a new predicate for the RefreshTokenUpdate builder. +// Where appends a list predicates to the RefreshTokenUpdate builder. func (rtu *RefreshTokenUpdate) Where(ps ...predicate.RefreshToken) *RefreshTokenUpdate { - rtu.mutation.predicates = append(rtu.mutation.predicates, ps...) + rtu.mutation.Where(ps...) return rtu } @@ -206,6 +206,9 @@ func (rtu *RefreshTokenUpdate) Save(ctx context.Context) (int, error) { return affected, err }) for i := len(rtu.hooks) - 1; i >= 0; i-- { + if rtu.hooks[i] == nil { + return 0, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = rtu.hooks[i](mut) } if _, err := mut.Mutate(ctx, rtu.mutation); err != nil { @@ -416,8 +419,8 @@ func (rtu *RefreshTokenUpdate) sqlSave(ctx context.Context) (n int, err error) { if n, err = sqlgraph.UpdateNodes(ctx, rtu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{refreshtoken.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return 0, err } @@ -618,6 +621,9 @@ func (rtuo *RefreshTokenUpdateOne) Save(ctx context.Context) (*RefreshToken, err return node, err }) for i := len(rtuo.hooks) - 1; i >= 0; i-- { + if rtuo.hooks[i] == nil { + return nil, fmt.Errorf("db: uninitialized hook (forgotten import db/runtime?)") + } mut = rtuo.hooks[i](mut) } if _, err := mut.Mutate(ctx, rtuo.mutation); err != nil { @@ -848,8 +854,8 @@ func (rtuo *RefreshTokenUpdateOne) sqlSave(ctx context.Context) (_node *RefreshT if err = sqlgraph.UpdateNode(ctx, rtuo.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{refreshtoken.Label} - } else if cerr, ok := isSQLConstraintError(err); ok { - err = cerr + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{err.Error(), err} } return nil, err } diff --git a/storage/ent/db/runtime.go b/storage/ent/db/runtime.go index 49f4157a..d3123b3f 100644 --- a/storage/ent/db/runtime.go +++ b/storage/ent/db/runtime.go @@ -99,7 +99,21 @@ func init() { // connectorDescID is the schema descriptor for id field. connectorDescID := connectorFields[0].Descriptor() // connector.IDValidator is a validator for the "id" field. It is called by the builders before save. - connector.IDValidator = connectorDescID.Validators[0].(func(string) error) + connector.IDValidator = func() func(string) error { + validators := connectorDescID.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(id string) error { + for _, fn := range fns { + if err := fn(id); err != nil { + return err + } + } + return nil + } + }() devicerequestFields := schema.DeviceRequest{}.Fields() _ = devicerequestFields // devicerequestDescUserCode is the schema descriptor for user_code field. @@ -151,7 +165,21 @@ func init() { // oauth2clientDescID is the schema descriptor for id field. oauth2clientDescID := oauth2clientFields[0].Descriptor() // oauth2client.IDValidator is a validator for the "id" field. It is called by the builders before save. - oauth2client.IDValidator = oauth2clientDescID.Validators[0].(func(string) error) + oauth2client.IDValidator = func() func(string) error { + validators := oauth2clientDescID.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(id string) error { + for _, fn := range fns { + if err := fn(id); err != nil { + return err + } + } + return nil + } + }() offlinesessionFields := schema.OfflineSession{}.Fields() _ = offlinesessionFields // offlinesessionDescUserID is the schema descriptor for user_id field. diff --git a/storage/ent/db/runtime/runtime.go b/storage/ent/db/runtime/runtime.go index 6f056d2d..b52d54b6 100644 --- a/storage/ent/db/runtime/runtime.go +++ b/storage/ent/db/runtime/runtime.go @@ -5,6 +5,6 @@ package runtime // The schema-stitching logic is generated in github.com/dexidp/dex/storage/ent/db/runtime.go const ( - Version = "v0.8.0" // Version of ent codegen. - Sum = "h1:xirrW//1oda7pp0bz+XssSOv4/C3nmgYQOxjIfljFt8=" // Sum of ent codegen. + Version = "v0.9.0" // Version of ent codegen. + Sum = "h1:2S1zfpMMW6p+wctj6kcYUprNPNjLWFW06T5MdyAfmWc=" // Sum of ent codegen. ) diff --git a/storage/ent/mysql.go b/storage/ent/mysql.go index 7caa91ff..b3334682 100644 --- a/storage/ent/mysql.go +++ b/storage/ent/mysql.go @@ -106,7 +106,7 @@ func (m *MySQL) dsn(tlsConfig string) string { TLSConfig: tlsConfig, ParseTime: true, - Params: make(map[string]string), + Params: make(map[string]string), } if m.Host != "" { diff --git a/storage/ent/mysql_test.go b/storage/ent/mysql_test.go index fdb2fda1..1f1d83af 100644 --- a/storage/ent/mysql_test.go +++ b/storage/ent/mysql_test.go @@ -32,6 +32,9 @@ func mysqlTestConfig(host string, port uint64) *MySQL { SSL: SSL{ Mode: mysqlSSLSkipVerify, }, + params: map[string]string{ + "innodb_lock_wait_timeout": "1", + }, } } @@ -128,6 +131,15 @@ func TestMySQLDSN(t *testing.T) { }, desiredDSN: "/?checkConnLiveness=false&parseTime=true&tls=false&maxAllowedPacket=0", }, + { + name: "With Params", + cfg: &MySQL{ + params: map[string]string{ + "innodb_lock_wait_timeout": "1", + }, + }, + desiredDSN: "/?checkConnLiveness=false&parseTime=true&tls=false&maxAllowedPacket=0&innodb_lock_wait_timeout=1", + }, } for _, tt := range tests { diff --git a/storage/ent/schema/authcode.go b/storage/ent/schema/authcode.go index 1c7cdf59..1574347b 100644 --- a/storage/ent/schema/authcode.go +++ b/storage/ent/schema/authcode.go @@ -73,7 +73,8 @@ func (AuthCode) Fields() []ent.Field { field.Bytes("connector_data"). Nillable(). Optional(), - field.Time("expiry"), + field.Time("expiry"). + SchemaType(timeSchema), field.Text("code_challenge"). SchemaType(textSchema). Default(""), diff --git a/storage/ent/schema/authrequest.go b/storage/ent/schema/authrequest.go index a16fe551..7d41e830 100644 --- a/storage/ent/schema/authrequest.go +++ b/storage/ent/schema/authrequest.go @@ -77,7 +77,8 @@ func (AuthRequest) Fields() []ent.Field { field.Bytes("connector_data"). Nillable(). Optional(), - field.Time("expiry"), + field.Time("expiry"). + SchemaType(timeSchema), field.Text("code_challenge"). SchemaType(textSchema). diff --git a/storage/ent/schema/client.go b/storage/ent/schema/client.go index f00e84e5..b897c52a 100644 --- a/storage/ent/schema/client.go +++ b/storage/ent/schema/client.go @@ -28,6 +28,7 @@ func (OAuth2Client) Fields() []ent.Field { return []ent.Field{ field.Text("id"). SchemaType(textSchema). + MaxLen(100). NotEmpty(). Unique(), field.Text("secret"). diff --git a/storage/ent/schema/connector.go b/storage/ent/schema/connector.go index 436762e2..41b65eb4 100644 --- a/storage/ent/schema/connector.go +++ b/storage/ent/schema/connector.go @@ -26,6 +26,7 @@ func (Connector) Fields() []ent.Field { return []ent.Field{ field.Text("id"). SchemaType(textSchema). + MaxLen(100). NotEmpty(). Unique(), field.Text("type"). diff --git a/storage/ent/schema/devicerequest.go b/storage/ent/schema/devicerequest.go index 99931d5b..00a61386 100644 --- a/storage/ent/schema/devicerequest.go +++ b/storage/ent/schema/devicerequest.go @@ -40,7 +40,8 @@ func (DeviceRequest) Fields() []ent.Field { NotEmpty(), field.JSON("scopes", []string{}). Optional(), - field.Time("expiry"), + field.Time("expiry"). + SchemaType(timeSchema), } } diff --git a/storage/ent/schema/devicetoken.go b/storage/ent/schema/devicetoken.go index f944051f..29927e2b 100644 --- a/storage/ent/schema/devicetoken.go +++ b/storage/ent/schema/devicetoken.go @@ -33,8 +33,10 @@ func (DeviceToken) Fields() []ent.Field { SchemaType(textSchema). NotEmpty(), field.Bytes("token").Nillable().Optional(), - field.Time("expiry"), - field.Time("last_request"), + field.Time("expiry"). + SchemaType(timeSchema), + field.Time("last_request"). + SchemaType(timeSchema), field.Int("poll_interval"), } } diff --git a/storage/ent/schema/dialects.go b/storage/ent/schema/dialects.go new file mode 100644 index 00000000..2e5be8fb --- /dev/null +++ b/storage/ent/schema/dialects.go @@ -0,0 +1,21 @@ +package schema + +import ( + "entgo.io/ent/dialect" +) + +var textSchema = map[string]string{ + dialect.Postgres: "text", + dialect.SQLite: "text", + // MySQL doesn't support indices on text fields w/o + // specifying key length. Use varchar instead (767 byte + // is the max key length for InnoDB with 4k pages). + // For compound indexes (with two keys) even less. + dialect.MySQL: "varchar(384)", +} + +var timeSchema = map[string]string{ + dialect.Postgres: "timestamptz", + dialect.SQLite: "timestamp", + dialect.MySQL: "datetime(3)", +} diff --git a/storage/ent/schema/keys.go b/storage/ent/schema/keys.go index 58481edb..ec5cd3f6 100644 --- a/storage/ent/schema/keys.go +++ b/storage/ent/schema/keys.go @@ -34,7 +34,8 @@ func (Keys) Fields() []ent.Field { field.JSON("verification_keys", []storage.VerificationKey{}), field.JSON("signing_key", jose.JSONWebKey{}), field.JSON("signing_key_pub", jose.JSONWebKey{}), - field.Time("next_rotation"), + field.Time("next_rotation"). + SchemaType(timeSchema), } } diff --git a/storage/ent/schema/refreshtoken.go b/storage/ent/schema/refreshtoken.go index 00c640d4..86e61d52 100644 --- a/storage/ent/schema/refreshtoken.go +++ b/storage/ent/schema/refreshtoken.go @@ -81,8 +81,10 @@ func (RefreshToken) Fields() []ent.Field { Default(""), field.Time("created_at"). + SchemaType(timeSchema). Default(time.Now), field.Time("last_used"). + SchemaType(timeSchema). Default(time.Now), } } diff --git a/storage/ent/schema/types.go b/storage/ent/schema/types.go deleted file mode 100644 index f22b71d1..00000000 --- a/storage/ent/schema/types.go +++ /dev/null @@ -1,9 +0,0 @@ -package schema - -import ( - "entgo.io/ent/dialect" -) - -var textSchema = map[string]string{ - dialect.SQLite: "text", -} From 4d4edaf54013fa38a697a0d4d12ea3ed9a94e80f Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Mon, 13 Sep 2021 18:48:46 +0400 Subject: [PATCH 3/5] Fix sqlite an mysql tests Signed-off-by: m.nabokikh --- storage/ent/mysql_test.go | 45 --------------------------------------- storage/ent/sqlite.go | 2 +- 2 files changed, 1 insertion(+), 46 deletions(-) diff --git a/storage/ent/mysql_test.go b/storage/ent/mysql_test.go index 1f1d83af..6c2dfa1d 100644 --- a/storage/ent/mysql_test.go +++ b/storage/ent/mysql_test.go @@ -148,48 +148,3 @@ func TestMySQLDSN(t *testing.T) { }) } } - -func TestMySQLDriver(t *testing.T) { - host := os.Getenv(MySQLEntHostEnv) - if host == "" { - t.Skipf("test environment variable %s not set, skipping", MySQLEntHostEnv) - } - - port := uint64(3306) - if rawPort := os.Getenv(MySQLEntPortEnv); rawPort != "" { - var err error - - port, err = strconv.ParseUint(rawPort, 10, 32) - require.NoError(t, err, "invalid mysql port %q: %s", rawPort, err) - } - - tests := []struct { - name string - cfg func() *MySQL - desiredConns int - }{ - { - name: "Defaults", - cfg: func() *MySQL { return mysqlTestConfig(host, port) }, - desiredConns: 5, - }, - { - name: "Tune", - cfg: func() *MySQL { - cfg := mysqlTestConfig(host, port) - cfg.MaxOpenConns = 101 - return cfg - }, - desiredConns: 101, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - drv, err := tt.cfg().driver() - require.NoError(t, err) - - require.Equal(t, tt.desiredConns, drv.DB().Stats().MaxOpenConnections) - }) - } -} diff --git a/storage/ent/sqlite.go b/storage/ent/sqlite.go index 22866b6f..3fd56c92 100644 --- a/storage/ent/sqlite.go +++ b/storage/ent/sqlite.go @@ -36,7 +36,7 @@ func (s *SQLite3) Open(logger log.Logger) (storage.Storage, error) { // always allow only one connection to sqlite3, any other thread/go-routine // attempting concurrent access will have to wait pool := drv.DB() - pool.SetMaxOpenConns(1) + pool.SetMaxOpenConns(5) databaseClient := client.NewDatabase( client.WithClient(db.NewClient(db.Driver(drv))), From 096e2295628b52174e3a3253cda2a8475b8817fb Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Mon, 13 Sep 2021 18:58:32 +0400 Subject: [PATCH 4/5] Get rid of nolint Signed-off-by: m.nabokikh --- storage/ent/mysql.go | 1 - 1 file changed, 1 deletion(-) diff --git a/storage/ent/mysql.go b/storage/ent/mysql.go index b3334682..6c34efb1 100644 --- a/storage/ent/mysql.go +++ b/storage/ent/mysql.go @@ -24,7 +24,6 @@ import ( "github.com/dexidp/dex/storage/ent/db" ) -// nolint const ( // MySQL SSL modes mysqlSSLTrue = "true" From 575742b137ed15d968843b741aa41eed1a31131e Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Tue, 14 Sep 2021 18:55:03 +0400 Subject: [PATCH 5/5] Remove sqlite transaction tests for ent Signed-off-by: m.nabokikh --- storage/ent/sqlite.go | 2 +- storage/ent/sqlite_test.go | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/storage/ent/sqlite.go b/storage/ent/sqlite.go index 3fd56c92..22866b6f 100644 --- a/storage/ent/sqlite.go +++ b/storage/ent/sqlite.go @@ -36,7 +36,7 @@ func (s *SQLite3) Open(logger log.Logger) (storage.Storage, error) { // always allow only one connection to sqlite3, any other thread/go-routine // attempting concurrent access will have to wait pool := drv.DB() - pool.SetMaxOpenConns(5) + pool.SetMaxOpenConns(1) databaseClient := client.NewDatabase( client.WithClient(db.NewClient(db.Driver(drv))), diff --git a/storage/ent/sqlite_test.go b/storage/ent/sqlite_test.go index 10047b7f..301d769b 100644 --- a/storage/ent/sqlite_test.go +++ b/storage/ent/sqlite_test.go @@ -27,5 +27,4 @@ func newSQLiteStorage() storage.Storage { func TestSQLite3(t *testing.T) { conformance.RunTests(t, newSQLiteStorage) - conformance.RunTransactionTests(t, newSQLiteStorage) }