package client

import (
	"context"
	"database/sql"
	"hash"
	"time"

	"github.com/dexidp/dex/storage"
	"github.com/dexidp/dex/storage/ent/db"
	"github.com/dexidp/dex/storage/ent/db/authcode"
	"github.com/dexidp/dex/storage/ent/db/authrequest"
	"github.com/dexidp/dex/storage/ent/db/devicerequest"
	"github.com/dexidp/dex/storage/ent/db/devicetoken"
	"github.com/dexidp/dex/storage/ent/db/migrate"
)

// Compile-time check that *Database satisfies storage.Storage.
var _ storage.Storage = (*Database)(nil)

// Database implements storage.Storage on top of an ent client.
//
// Build it with NewDatabase and the With* option functions. Its methods
// dereference client, so a client must be supplied via WithClient.
type Database struct {
	// client is the ent client every query and transaction goes through.
	client *db.Client
	// txOptions is passed to BeginTx; it stays nil unless
	// WithTxIsolationLevel was applied.
	txOptions *sql.TxOptions
	// hasher constructs the hash set via WithHasher.
	// NOTE(review): not used in this chunk; presumably consumed elsewhere
	// in the package (e.g. for ID hashing) — confirm.
	hasher func() hash.Hash
}

// NewDatabase returns a Database with every option in opts applied in order.
// A later option overrides an earlier one that sets the same field.
func NewDatabase(opts ...func(*Database)) *Database {
	database := &Database{}
	for _, f := range opts {
		f(database)
	}
	return database
}

// WithClient sets the ent client used for all database operations.
func WithClient(c *db.Client) func(*Database) {
	return func(s *Database) {
		s.client = c
	}
}

// WithHasher sets the hasher option of a Database object: the constructor
// used to obtain a new hash.Hash.
func WithHasher(h func() hash.Hash) func(*Database) {
	return func(s *Database) {
		s.hasher = h
	}
}

// WithTxIsolationLevel sets the isolation level used by transactions
// started with BeginTx.
func WithTxIsolationLevel(level sql.IsolationLevel) func(*Database) {
	return func(s *Database) {
		s.txOptions = &sql.TxOptions{Isolation: level}
	}
}

// Schema exposes the migration schema so callers can perform migrations.
func (d *Database) Schema() *migrate.Schema {
	return d.client.Schema
}

// Close calls the corresponding method of the ent database client.
func (d *Database) Close() error {
	return d.client.Close()
}

// BeginTx begins a transaction with the options configured on d.
// If WithTxIsolationLevel was not applied, the options are nil.
// NOTE(review): nil presumably means the driver's default isolation — confirm.
func (d *Database) BeginTx(ctx context.Context) (*db.Tx, error) {
	return d.client.BeginTx(ctx, d.txOptions)
}

// GarbageCollect removes expired entities from the database. It deletes
// auth requests, auth codes, device requests and device tokens whose expiry
// is strictly before now, and reports how many of each were deleted.
//
// The deletes run one after another on the client, not inside a
// transaction. If one fails, the earlier deletes are not undone by this
// method. The returned GCResult then carries the counts for the kinds
// already processed, alongside the error.
func (d *Database) GarbageCollect(now time.Time) (storage.GCResult, error) {
	// This method's signature carries no context, so there is none to propagate.
	ctx := context.TODO()
	result := storage.GCResult{}
	// NOTE(review): presumably expiries are stored in UTC — confirm.
	utcNow := now.UTC()

	q, err := d.client.AuthRequest.Delete().
		Where(authrequest.ExpiryLT(utcNow)).
		Exec(ctx)
	if err != nil {
		return result, convertDBError("gc auth request: %w", err)
	}
	result.AuthRequests = int64(q)

	q, err = d.client.AuthCode.Delete().
		Where(authcode.ExpiryLT(utcNow)).
		Exec(ctx)
	if err != nil {
		return result, convertDBError("gc auth code: %w", err)
	}
	result.AuthCodes = int64(q)

	q, err = d.client.DeviceRequest.Delete().
		Where(devicerequest.ExpiryLT(utcNow)).
		Exec(ctx)
	if err != nil {
		return result, convertDBError("gc device request: %w", err)
	}
	result.DeviceRequests = int64(q)

	q, err = d.client.DeviceToken.Delete().
		Where(devicetoken.ExpiryLT(utcNow)).
		Exec(ctx)
	if err != nil {
		return result, convertDBError("gc device token: %w", err)
	}
	result.DeviceTokens = int64(q)

	// Every step above returns early on error, so success is explicit here.
	return result, nil
}