// Copyright (C) 2022 Aravinth Manivannan // SPDX-FileCopyrightText: 2023 Aravinth Manivannan // // SPDX-License-Identifier: AGPL-3.0-or-later use std::str::FromStr; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use sqlx::postgres::PgPoolOptions; use sqlx::types::time::OffsetDateTime; use sqlx::ConnectOptions; use sqlx::PgPool; use crate::errors::*; /// Connect to databse pub enum ConnectionOptions { /// fresh connection Fresh(Fresh), /// existing connection Existing(Conn), } /// Use an existing database pool pub struct Conn(pub PgPool); pub struct Fresh { pub pool_options: PgPoolOptions, pub disable_logging: bool, pub url: String, } impl ConnectionOptions { async fn connect(self) -> ServiceResult { let pool = match self { Self::Fresh(fresh) => { let mut connect_options = sqlx::postgres::PgConnectOptions::from_str(&fresh.url).unwrap(); if fresh.disable_logging { connect_options = connect_options.disable_statement_logging(); } sqlx::postgres::PgConnectOptions::from_str(&fresh.url) .unwrap() .disable_statement_logging(); fresh .pool_options .connect_with(connect_options) .await .unwrap() } Self::Existing(conn) => conn.0, }; Ok(Database { pool }) } } #[derive(Clone)] pub struct Database { pub pool: PgPool, } #[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)] pub struct JobState { pub name: String, } impl JobState { pub fn new(name: String) -> Self { Self { name } } } lazy_static! { pub static ref JOB_STATE_CREATE: JobState = JobState::new("job.state.create".into()); pub static ref JOB_STATE_FINISH: JobState = JobState::new("job.state.finish".into()); pub static ref JOB_STATE_RUNNING: JobState = JobState::new("job.state.running".into()); pub static ref JOB_STATES: [&'static JobState; 3] = [&*JOB_STATE_CREATE, &*JOB_STATE_FINISH, &*JOB_STATE_RUNNING]; } impl Database { pub async fn migrate(&self) -> ServiceResult<()> { sqlx::migrate!("./migrations/") .run(&self.pool) .await .unwrap(); self.create_job_states().await?; Ok(()) } /// check if event type exists async fn job_state_exists(&self, job_state: &JobState) -> ServiceResult { let res = sqlx::query!( "SELECT EXISTS (SELECT 1 from ftest_job_states WHERE name = $1)", job_state.name, ) .fetch_one(&self.pool) .await .map_err(map_register_err)?; let mut resp = false; if let Some(x) = res.exists { resp = x; } Ok(resp) } async fn create_job_states(&self) -> ServiceResult<()> { for j in &*JOB_STATES { if !self.job_state_exists(j).await? { sqlx::query!( "INSERT INTO ftest_job_states (name) VALUES ($1) ON CONFLICT (name) DO NOTHING;", j.name ) .execute(&self.pool) .await .map_err(map_register_err)?; } } Ok(()) } pub async fn ping(&self) -> bool { use sqlx::Connection; if let Ok(mut con) = self.pool.acquire().await { con.ping().await.is_ok() } else { false } } pub async fn add_job(&self, commit_hash: &str) -> ServiceResult<()> { let now = now_unix_time_stamp(); sqlx::query!( "INSERT INTO ftest_jobs (commit_hash, job_state, created_at) VALUES ($1, (SELECT ID FROM ftest_job_states WHERE name = $2), $3)", commit_hash, &JOB_STATE_CREATE.name, now ) .execute(&self.pool) .await .map_err(map_register_err)?; Ok(()) } pub async fn mark_job_scheduled(&self, commit_hash: &str) -> ServiceResult<()> { let now = now_unix_time_stamp(); sqlx::query!( " UPDATE ftest_jobs SET job_state = (SELECT ID FROM ftest_job_states WHERE name = $1), scheduled_at = $2 WHERE commit_hash = $3;", &JOB_STATE_RUNNING.name, now, commit_hash, ) .execute(&self.pool) .await .map_err(map_register_err)?; Ok(()) } pub async fn mark_job_finished(&self, commit_hash: &str) -> ServiceResult<()> { let now = now_unix_time_stamp(); sqlx::query!( " UPDATE ftest_jobs SET job_state = (SELECT ID FROM ftest_job_states WHERE name = $1), finished_at = $2 WHERE commit_hash = $3;", &JOB_STATE_FINISH.name, now, commit_hash, ) .execute(&self.pool) .await .map_err(map_register_err)?; Ok(()) } pub async fn get_job(&self, commit_hash: &str) -> ServiceResult { let res = sqlx::query_as!( InnerJob, " SELECT ftest_jobs.ID, ftest_jobs.commit_hash, ftest_job_states.name, ftest_jobs.created_at, ftest_jobs.scheduled_at, ftest_jobs.finished_at FROM ftest_jobs INNER JOIN ftest_job_states ON ftest_job_states.ID = ftest_jobs.job_state WHERE ftest_jobs.commit_hash = $1", commit_hash ) .fetch_one(&self.pool) .await .map_err(map_register_err)?; Ok(res.into()) } pub async fn get_all_jobs_of_state(&self, state: &JobState) -> ServiceResult> { let mut res = sqlx::query_as!( InnerJob, " SELECT ftest_jobs.ID, ftest_jobs.commit_hash, ftest_job_states.name, ftest_jobs.created_at, ftest_jobs.scheduled_at, ftest_jobs.finished_at FROM ftest_jobs INNER JOIN ftest_job_states ON ftest_job_states.ID = ftest_jobs.job_state WHERE ftest_job_states.name = $1;", &state.name ) .fetch_all(&self.pool) .await .map_err(map_register_err)?; let res = res.drain(0..).map(|r| r.into()).collect(); Ok(res) } pub async fn delete_job(&self, commit_hash: &str) -> ServiceResult<()> { sqlx::query!( " DELETE FROM ftest_jobs WHERE commit_hash = $1;", commit_hash, ) .execute(&self.pool) .await .map_err(map_register_err)?; Ok(()) } pub async fn get_next_job_to_run(&self) -> ServiceResult { let res = sqlx::query_as!( SchedulerJob, "SELECT commit_hash FROM ftest_jobs WHERE job_state = (SELECT ID FROM ftest_job_states WHERE name = $1) AND finished_at is NULL AND scheduled_at is NULL ORDER BY created_at ASC;", &JOB_STATE_CREATE.name ) .fetch_one(&self.pool) .await .map_err(map_register_err)?; Ok(res) } } fn now_unix_time_stamp() -> OffsetDateTime { OffsetDateTime::now_utc() } pub async fn get_db(settings: &crate::settings::Settings) -> Database { let pool_options = PgPoolOptions::new().max_connections(settings.database.pool); ConnectionOptions::Fresh(Fresh { pool_options, url: settings.database.url.clone(), disable_logging: !settings.debug, }) .connect() .await .unwrap() } /// map custom row not found error to DB error pub fn map_row_not_found_err(e: sqlx::Error, row_not_found: ServiceError) -> ServiceError { if let sqlx::Error::RowNotFound = e { row_not_found } else { map_register_err(e) } } #[derive(Clone, Debug, PartialEq, Eq)] pub struct SchedulerJob { pub commit_hash: String, } #[derive(Clone, Debug, PartialEq, Eq)] pub struct Job { pub state: JobState, pub commit_hash: String, pub id: u32, pub created_at: OffsetDateTime, pub scheduled_at: Option, pub finished_at: Option, } struct InnerJob { name: String, commit_hash: String, id: i32, created_at: OffsetDateTime, scheduled_at: Option, finished_at: Option, } impl From for Job { fn from(j: InnerJob) -> Self { Job { state: (JOB_STATES) .iter() .find(|d| d.name == j.name) .unwrap() .to_owned() .to_owned(), commit_hash: j.commit_hash, id: j.id as u32, created_at: j.created_at, scheduled_at: j.scheduled_at, finished_at: j.finished_at, } } } /// map postgres errors to [ServiceError](ServiceError) types fn map_register_err(e: sqlx::Error) -> ServiceError { use sqlx::Error; use std::borrow::Cow; if let Error::Database(err) = e { if err.code() == Some(Cow::from("23505")) { let msg = err.message(); unimplemented!("{}", msg); // if msg.contains("librepages_users_name_key") { // ServiceError::UsernameTaken // } else { // error!("{}", msg); // ServiceError::InternalServerError // } } else { ServiceError::InternalServerError } } else { ServiceError::InternalServerError } } #[cfg(test)] mod tests { use super::*; use crate::settings::Settings; #[actix_rt::test] async fn db_works() { let settings = Settings::new().unwrap(); let pool_options = PgPoolOptions::new().max_connections(1); let db = ConnectionOptions::Fresh(Fresh { pool_options, url: settings.database.url.clone(), disable_logging: !settings.debug, }) .connect() .await .unwrap(); assert!(db.ping().await); const COMMIT_HASH: &str = "pasdfasdfasdfadf"; const COMMIT_HASH2: &str = "pasdfasdfasdfadf22"; db.migrate().await.unwrap(); for e in (*JOB_STATES).iter() { println!("checking job state {}", e.name); assert!(db.job_state_exists(e).await.unwrap()); } let _ = db.delete_job(COMMIT_HASH).await; let _ = db.delete_job(COMMIT_HASH2).await; db.add_job(COMMIT_HASH).await.unwrap(); let job = db.get_job(COMMIT_HASH).await.unwrap(); db.add_job(COMMIT_HASH2).await.unwrap(); let job2 = db.get_job(COMMIT_HASH2).await.unwrap(); assert_eq!( db.get_next_job_to_run().await.unwrap().commit_hash, job.commit_hash ); assert!(job.created_at < now_unix_time_stamp()); assert!(job.scheduled_at.is_none()); assert!(job.finished_at.is_none()); assert_eq!( db.get_all_jobs_of_state(&*JOB_STATE_CREATE).await.unwrap(), vec![job, job2.clone()] ); db.mark_job_scheduled(COMMIT_HASH).await.unwrap(); let job = db.get_job(COMMIT_HASH).await.unwrap(); assert!(job.scheduled_at.is_some()); assert_eq!( db.get_all_jobs_of_state(&*JOB_STATE_RUNNING).await.unwrap(), vec![job] ); db.mark_job_finished(COMMIT_HASH).await.unwrap(); let job = db.get_job(COMMIT_HASH).await.unwrap(); assert!(job.finished_at.is_some()); assert_eq!( db.get_all_jobs_of_state(&*JOB_STATE_FINISH).await.unwrap(), vec![job] ); assert_eq!( db.get_next_job_to_run().await.unwrap().commit_hash, job2.commit_hash ); } }