forked from mystiq/dex
93 lines
2.6 KiB
Go
93 lines
2.6 KiB
Go
|
package client
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
|
||
|
"github.com/dexidp/dex/storage"
|
||
|
)
|
||
|
|
||
|
// CreateClient saves provided oauth2 client settings into the database.
|
||
|
func (d *Database) CreateClient(client storage.Client) error {
|
||
|
_, err := d.client.OAuth2Client.Create().
|
||
|
SetID(client.ID).
|
||
|
SetName(client.Name).
|
||
|
SetSecret(client.Secret).
|
||
|
SetPublic(client.Public).
|
||
|
SetLogoURL(client.LogoURL).
|
||
|
SetRedirectUris(client.RedirectURIs).
|
||
|
SetTrustedPeers(client.TrustedPeers).
|
||
|
Save(context.TODO())
|
||
|
if err != nil {
|
||
|
return convertDBError("create oauth2 client: %w", err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// ListClients extracts an array of oauth2 clients from the database.
|
||
|
func (d *Database) ListClients() ([]storage.Client, error) {
|
||
|
clients, err := d.client.OAuth2Client.Query().All(context.TODO())
|
||
|
if err != nil {
|
||
|
return nil, convertDBError("list clients: %w", err)
|
||
|
}
|
||
|
|
||
|
storageClients := make([]storage.Client, 0, len(clients))
|
||
|
for _, c := range clients {
|
||
|
storageClients = append(storageClients, toStorageClient(c))
|
||
|
}
|
||
|
return storageClients, nil
|
||
|
}
|
||
|
|
||
|
// GetClient extracts an oauth2 client from the database by id.
|
||
|
func (d *Database) GetClient(id string) (storage.Client, error) {
|
||
|
client, err := d.client.OAuth2Client.Get(context.TODO(), id)
|
||
|
if err != nil {
|
||
|
return storage.Client{}, convertDBError("get client: %w", err)
|
||
|
}
|
||
|
return toStorageClient(client), nil
|
||
|
}
|
||
|
|
||
|
// DeleteClient deletes an oauth2 client from the database by id.
|
||
|
func (d *Database) DeleteClient(id string) error {
|
||
|
err := d.client.OAuth2Client.DeleteOneID(id).Exec(context.TODO())
|
||
|
if err != nil {
|
||
|
return convertDBError("delete client: %w", err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// UpdateClient changes an oauth2 client by id using an updater function and saves it to the database.
|
||
|
func (d *Database) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
|
||
|
tx, err := d.client.Tx(context.TODO())
|
||
|
if err != nil {
|
||
|
return convertDBError("update client tx: %w", err)
|
||
|
}
|
||
|
|
||
|
client, err := tx.OAuth2Client.Get(context.TODO(), id)
|
||
|
if err != nil {
|
||
|
return rollback(tx, "update client database: %w", err)
|
||
|
}
|
||
|
|
||
|
newClient, err := updater(toStorageClient(client))
|
||
|
if err != nil {
|
||
|
return rollback(tx, "update client updating: %w", err)
|
||
|
}
|
||
|
|
||
|
_, err = tx.OAuth2Client.UpdateOneID(newClient.ID).
|
||
|
SetName(newClient.Name).
|
||
|
SetSecret(newClient.Secret).
|
||
|
SetPublic(newClient.Public).
|
||
|
SetLogoURL(newClient.LogoURL).
|
||
|
SetRedirectUris(newClient.RedirectURIs).
|
||
|
SetTrustedPeers(newClient.TrustedPeers).
|
||
|
Save(context.TODO())
|
||
|
if err != nil {
|
||
|
return rollback(tx, "update client uploading: %w", err)
|
||
|
}
|
||
|
|
||
|
if err = tx.Commit(); err != nil {
|
||
|
return rollback(tx, "update auth request commit: %w", err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|