From 8150aa9ca1a1a12e64710ef12f5d07bb516af92b Mon Sep 17 00:00:00 2001 From: realaravinth Date: Sat, 10 Sep 2022 19:21:49 +0530 Subject: [PATCH] feat: init and load db --- config/default.toml | 19 +++++++ src/ctx.rs | 9 ++- src/db.rs | 131 ++++++++++++++++++++++++++++++++++++++++++++ src/errors.rs | 2 +- src/settings.rs | 79 ++++++++++++++++++++++++-- src/tests.rs | 2 +- 6 files changed, 234 insertions(+), 8 deletions(-) create mode 100644 src/db.rs diff --git a/config/default.toml b/config/default.toml index 038c5e4..7f10de2 100644 --- a/config/default.toml +++ b/config/default.toml @@ -1,3 +1,4 @@ +debug = true # source code of your copy of pages server. source_code = "https://github.com/realaravinth/pages" @@ -20,3 +21,21 @@ ip= "0.0.0.0" # with one also. workers = 2 domain = "demo.librepages.org" + + + +[database] +# This section deals with the database location and how to access it +# Please note that at the moment, we have support for only postgresqa. +# Example, if you are Batman, your config would be: +# hostname = "batcave.org" +# port = "5432" +# username = "batman" +# password = "somereallycomplicatedBatmanpassword" +hostname = "localhost" +port = "5432" +username = "postgres" +password = "password" +name = "postgres" +pool = 4 +database_type="postgres" # "postgres" diff --git a/src/ctx.rs b/src/ctx.rs index 27ddb9f..117247c 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -16,15 +16,20 @@ */ use std::sync::Arc; +use crate::db::*; use crate::settings::Settings; +pub type ArcCtx = Arc; + #[derive(Clone)] pub struct Ctx { pub settings: Settings, + pub db: Database, } impl Ctx { - pub fn new(settings: Settings) -> Arc { - Arc::new(Self { settings }) + pub async fn new(settings: Settings) -> Arc { + let db = get_db(&settings).await; + Arc::new(Self { settings, db }) } } diff --git a/src/db.rs b/src/db.rs new file mode 100644 index 0000000..f5e05f5 --- /dev/null +++ b/src/db.rs @@ -0,0 +1,131 @@ +/* + * Copyright (C) 2022 Aravinth Manivannan + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +use std::str::FromStr; + +use sqlx::postgres::PgPoolOptions; +use sqlx::types::time::OffsetDateTime; +//use sqlx::types::Json; +use sqlx::ConnectOptions; +use sqlx::PgPool; + +use crate::errors::*; + +/// Connect to databse +pub enum ConnectionOptions { + /// fresh connection + Fresh(Fresh), + /// existing connection + Existing(Conn), +} + +/// Use an existing database pool +pub struct Conn(pub PgPool); + +pub struct Fresh { + pub pool_options: PgPoolOptions, + pub disable_logging: bool, + pub url: String, +} + +impl ConnectionOptions { + async fn connect(self) -> ServiceResult { + let pool = match self { + Self::Fresh(fresh) => { + let mut connect_options = + sqlx::postgres::PgConnectOptions::from_str(&fresh.url).unwrap(); + if fresh.disable_logging { + connect_options.disable_statement_logging(); + } + sqlx::postgres::PgConnectOptions::from_str(&fresh.url) + .unwrap() + .disable_statement_logging(); + fresh + .pool_options + .connect_with(connect_options) + .await + .unwrap() + //.map_err(|e| DBError::DBError(Box::new(e)))? + } + + Self::Existing(conn) => conn.0, + }; + Ok(Database { pool }) + } +} + +#[derive(Clone)] +pub struct Database { + pub pool: PgPool, +} + +impl Database { + pub async fn migrate(&self) -> ServiceResult<()> { + sqlx::migrate!("./migrations/") + .run(&self.pool) + .await + .unwrap(); + //.map_err(|e| DBError::DBError(Box::new(e)))?; + Ok(()) + } + + pub async fn ping(&self) -> bool { + use sqlx::Connection; + + if let Ok(mut con) = self.pool.acquire().await { + con.ping().await.is_ok() + } else { + false + } + } +} + +fn now_unix_time_stamp() -> OffsetDateTime { + OffsetDateTime::now_utc() +} + +pub async fn get_db(settings: &crate::settings::Settings) -> Database { + let pool_options = PgPoolOptions::new().max_connections(settings.database.pool); + ConnectionOptions::Fresh(Fresh { + pool_options, + url: settings.database.url.clone(), + disable_logging: !settings.debug, + }) + .connect() + .await + .unwrap() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::settings::Settings; + + #[actix_rt::test] + async fn db_works() { + let settings = Settings::new().unwrap(); + let pool_options = PgPoolOptions::new().max_connections(1); + let db = ConnectionOptions::Fresh(Fresh { + pool_options, + url: settings.database.url.clone(), + disable_logging: !settings.debug, + }) + .connect() + .await + .unwrap(); + assert!(db.ping().await); + } +} diff --git a/src/errors.rs b/src/errors.rs index ce9183d..e1085d3 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -14,7 +14,7 @@ * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see . */ -//! represents all the ways a trait can fail using this crate +//! Represents all the ways a trait can fail using this crate use std::convert::From; use std::io::Error as FSErrorInner; use std::sync::Arc; diff --git a/src/settings.rs b/src/settings.rs index e3d53d3..a5f6e9f 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -18,7 +18,8 @@ use std::env; use std::path::Path; use std::sync::Arc; -use config::{Config, Environment, File}; +use config::{Config, ConfigError, Environment, File}; +use derive_more::Display; #[cfg(not(test))] use log::{error, warn}; @@ -26,6 +27,7 @@ use log::{error, warn}; use std::{println as warn, println as error}; use serde::Deserialize; +use serde::Serialize; use url::Url; use crate::errors::*; @@ -46,11 +48,39 @@ impl Server { } } +#[derive(Deserialize, Serialize, Display, Eq, PartialEq, Clone, Debug)] +#[serde(rename_all = "lowercase")] +pub enum DBType { + #[display(fmt = "postgres")] + Postgres, + // #[display(fmt = "maria")] + // Maria, +} + +impl DBType { + fn from_url(url: &Url) -> Result { + match url.scheme() { + // "mysql" => Ok(Self::Maria), + "postgres" => Ok(Self::Postgres), + _ => Err(ConfigError::Message("Unknown database type".into())), + } + } +} + +#[derive(Debug, Clone, Deserialize)] +pub struct Database { + pub url: String, + pub pool: u32, + pub database_type: DBType, +} + #[derive(Debug, Clone, Deserialize)] pub struct Settings { + pub debug: bool, pub server: Server, pub source_code: String, pub pages: Vec>, + pub database: Database, } #[cfg(not(tarpaulin_include))] @@ -84,15 +114,56 @@ impl Settings { s = s.add_source(Environment::with_prefix("PAGES").separator("__")); - let mut settings = s.build()?.try_deserialize::()?; - settings.check_url(); match env::var("PORT") { Ok(val) => { - settings.server.port = val.parse().unwrap(); + s = s.set_override("server.port", val).unwrap(); + //settings.server.port = val.parse().unwrap(); } Err(e) => warn!("couldn't interpret PORT: {}", e), } + if let Ok(val) = env::var("DATABASE_URL") { + // match env::var("DATABASE_URL") { + // Ok(val) => { + let url = Url::parse(&val).expect("couldn't parse Database URL"); + s = s.set_override("database.url", url.to_string()).unwrap(); + let database_type = DBType::from_url(&url).unwrap(); + s = s + .set_override("database.database_type", database_type.to_string()) + .unwrap(); + } + + // Err(_e) => { + // } + + let intermediate_config = s.build_cloned().unwrap(); + + s = s + .set_override( + "database.url", + format!( + r"postgres://{}:{}@{}:{}/{}", + intermediate_config + .get::("database.username") + .expect("Couldn't access database username"), + intermediate_config + .get::("database.password") + .expect("Couldn't access database password"), + intermediate_config + .get::("database.hostname") + .expect("Couldn't access database hostname"), + intermediate_config + .get::("database.port") + .expect("Couldn't access database port"), + intermediate_config + .get::("database.name") + .expect("Couldn't access database name") + ), + ) + .expect("Couldn't set database url"); + + let settings = s.build()?.try_deserialize::()?; + settings.check_url(); settings.init(); Ok(settings) diff --git a/src/tests.rs b/src/tests.rs index 8694c55..a6367c1 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -49,7 +49,7 @@ pub async fn get_data() -> (Temp, Arc) { println!("[log] Initialzing settings again with test config"); settings.init(); - (tmp_dir, Ctx::new(settings)) + (tmp_dir, Ctx::new(settings).await) } #[allow(dead_code, clippy::upper_case_acronyms)]