218 lines
5.9 KiB
Rust
218 lines
5.9 KiB
Rust
// Copyright (C) 2022 Aravinth Manivannan <realaravinth@batsense.net>
|
|
// SPDX-FileCopyrightText: 2023 Aravinth Manivannan <realaravinth@batsense.net>
|
|
//
|
|
// SPDX-License-Identifier: AGPL-3.0-or-later
|
|
|
|
use std::str::FromStr;
|
|
|
|
use lazy_static::lazy_static;
|
|
use serde::{Deserialize, Serialize};
|
|
use sqlx::postgres::PgPoolOptions;
|
|
use sqlx::types::time::OffsetDateTime;
|
|
use sqlx::ConnectOptions;
|
|
use sqlx::PgPool;
|
|
|
|
use crate::errors::*;
|
|
|
|
/// Connect to databse
|
|
pub enum ConnectionOptions {
|
|
/// fresh connection
|
|
Fresh(Fresh),
|
|
/// existing connection
|
|
Existing(Conn),
|
|
}
|
|
|
|
/// Use an existing database pool
|
|
pub struct Conn(pub PgPool);
|
|
|
|
pub struct Fresh {
|
|
pub pool_options: PgPoolOptions,
|
|
pub disable_logging: bool,
|
|
pub url: String,
|
|
}
|
|
|
|
impl ConnectionOptions {
|
|
async fn connect(self) -> ServiceResult<Database> {
|
|
let pool = match self {
|
|
Self::Fresh(fresh) => {
|
|
let mut connect_options =
|
|
sqlx::postgres::PgConnectOptions::from_str(&fresh.url).unwrap();
|
|
if fresh.disable_logging {
|
|
connect_options.disable_statement_logging();
|
|
}
|
|
sqlx::postgres::PgConnectOptions::from_str(&fresh.url)
|
|
.unwrap()
|
|
.disable_statement_logging();
|
|
fresh
|
|
.pool_options
|
|
.connect_with(connect_options)
|
|
.await
|
|
.unwrap()
|
|
}
|
|
|
|
Self::Existing(conn) => conn.0,
|
|
};
|
|
Ok(Database { pool })
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct Database {
|
|
pub pool: PgPool,
|
|
}
|
|
|
|
#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)]
|
|
pub struct JobState {
|
|
pub name: String,
|
|
}
|
|
|
|
impl JobState {
|
|
pub fn new(name: String) -> Self {
|
|
Self { name }
|
|
}
|
|
}
|
|
|
|
lazy_static! {
    /// Job state named `job.state.create`.
    pub static ref JOB_STATE_CREATE: JobState = JobState::new(String::from("job.state.create"));
    /// Job state named `job.state.finish`.
    pub static ref JOB_STATE_FINISH: JobState = JobState::new(String::from("job.state.finish"));
    /// Job state named `job.state.running`.
    pub static ref JOB_STATE_RUNNING: JobState = JobState::new(String::from("job.state.running"));
    /// Every known job state; `Database::migrate` seeds each of these into
    /// `ftest_job_states`.
    pub static ref JOB_STATES: [&'static JobState; 3] = [
        &*JOB_STATE_CREATE,
        &*JOB_STATE_FINISH,
        &*JOB_STATE_RUNNING,
    ];
}
|
|
|
|
impl Database {
|
|
pub async fn migrate(&self) -> ServiceResult<()> {
|
|
sqlx::migrate!("./migrations/")
|
|
.run(&self.pool)
|
|
.await
|
|
.unwrap();
|
|
self.create_job_states().await?;
|
|
Ok(())
|
|
}
|
|
|
|
/// check if event type exists
|
|
async fn job_state_exists(&self, job_state: &JobState) -> ServiceResult<bool> {
|
|
let res = sqlx::query!(
|
|
"SELECT EXISTS (SELECT 1 from ftest_job_states WHERE name = $1)",
|
|
job_state.name,
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await
|
|
.map_err(map_register_err)?;
|
|
|
|
let mut resp = false;
|
|
if let Some(x) = res.exists {
|
|
resp = x;
|
|
}
|
|
|
|
Ok(resp)
|
|
}
|
|
|
|
async fn create_job_states(&self) -> ServiceResult<()> {
|
|
for j in &*JOB_STATES {
|
|
if !self.job_state_exists(j).await? {
|
|
sqlx::query!(
|
|
"INSERT INTO ftest_job_states
|
|
(name) VALUES ($1) ON CONFLICT (name) DO NOTHING;",
|
|
j.name
|
|
)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(map_register_err)?;
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn ping(&self) -> bool {
|
|
use sqlx::Connection;
|
|
|
|
if let Ok(mut con) = self.pool.acquire().await {
|
|
con.ping().await.is_ok()
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
}
|
|
|
|
fn now_unix_time_stamp() -> OffsetDateTime {
|
|
OffsetDateTime::now_utc()
|
|
}
|
|
|
|
pub async fn get_db(settings: &crate::settings::Settings) -> Database {
|
|
let pool_options = PgPoolOptions::new().max_connections(settings.database.pool);
|
|
ConnectionOptions::Fresh(Fresh {
|
|
pool_options,
|
|
url: settings.database.url.clone(),
|
|
disable_logging: !settings.debug,
|
|
})
|
|
.connect()
|
|
.await
|
|
.unwrap()
|
|
}
|
|
|
|
/// map custom row not found error to DB error
|
|
pub fn map_row_not_found_err(e: sqlx::Error, row_not_found: ServiceError) -> ServiceError {
|
|
if let sqlx::Error::RowNotFound = e {
|
|
row_not_found
|
|
} else {
|
|
map_register_err(e)
|
|
}
|
|
}
|
|
|
|
/// map postgres errors to [ServiceError](ServiceError) types
|
|
fn map_register_err(e: sqlx::Error) -> ServiceError {
|
|
use sqlx::Error;
|
|
use std::borrow::Cow;
|
|
|
|
if let Error::Database(err) = e {
|
|
if err.code() == Some(Cow::from("23505")) {
|
|
let msg = err.message();
|
|
unimplemented!("{}", msg);
|
|
// if msg.contains("librepages_users_name_key") {
|
|
// ServiceError::UsernameTaken
|
|
// } else {
|
|
// error!("{}", msg);
|
|
// ServiceError::InternalServerError
|
|
// }
|
|
} else {
|
|
ServiceError::InternalServerError
|
|
}
|
|
} else {
|
|
ServiceError::InternalServerError
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::settings::Settings;

    /// Connect with a single-connection pool and check liveness. Then run the
    /// migrations twice; the second run exercises the already-seeded path.
    /// Finally confirm every entry of `JOB_STATES` is present.
    #[actix_rt::test]
    async fn db_works() {
        let settings = Settings::new().unwrap();
        let pool_options = PgPoolOptions::new().max_connections(1);
        let db = ConnectionOptions::Fresh(Fresh {
            pool_options,
            url: settings.database.url.clone(),
            disable_logging: !settings.debug,
        })
        .connect()
        .await
        .unwrap();
        assert!(db.ping().await);

        db.migrate().await.unwrap();
        // Re-running must not fail: applied migrations and existing job
        // states are skipped.
        db.migrate().await.unwrap();

        for state in JOB_STATES.iter() {
            println!("checking job state {}", state.name);
            assert!(db.job_state_exists(state).await.unwrap());
        }
    }
}
|