/* mCaptcha - A proof of work based DoS protection system * Copyright © 2021 Aravinth Manivannan * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as * published by the Free Software Foundation, either version 3 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see . */ use std::collections::HashMap; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; use std::time::Duration; use dashmap::DashMap; use serde::{Deserialize, Serialize}; use super::defense::Defense; use libmcaptcha::errors::*; use libmcaptcha::master::messages as ManagerMessages; /// Builder for [MCaptcha] #[derive(Clone, Serialize, Deserialize, Debug)] pub struct MCaptchaBuilder { visitor_threshold: u32, defense: Option, duration: Option, } impl Default for MCaptchaBuilder { fn default() -> Self { MCaptchaBuilder { visitor_threshold: 0, defense: None, duration: None, } } } impl MCaptchaBuilder { /// set defense pub fn defense(&mut self, d: Defense) -> &mut Self { self.defense = Some(d); self } /// set duration pub fn duration(&mut self, d: u64) -> &mut Self { self.duration = Some(d); self } /// Builds new [MCaptcha] pub fn build(self: &mut MCaptchaBuilder) -> CaptchaResult { if self.duration.is_none() { Err(CaptchaError::PleaseSetValue("duration".into())) } else if self.defense.is_none() { Err(CaptchaError::PleaseSetValue("defense".into())) } else if self.duration <= Some(0) { Err(CaptchaError::CaptchaDurationZero) } else { let m = MCaptcha { duration: self.duration.unwrap(), defense: self.defense.clone().unwrap(), visitor_threshold: Arc::new(AtomicU32::new(self.visitor_threshold)), }; Ok(m) } } } #[derive(Clone, Serialize, Deserialize, Debug)] pub struct MCaptcha { visitor_threshold: Arc, defense: Defense, duration: u64, } impl MCaptcha { /// increments the visitor count by one #[inline] pub fn add_visitor(&self) -> u32 { // self.visitor_threshold += 1; let current_visitor_level = self.visitor_threshold.fetch_add(1, Ordering::SeqCst) + 1; let current_level = self.defense.current_level(current_visitor_level); current_level.difficulty_factor } /// decrements the visitor count by specified count #[inline] pub fn set_visitor_count(&self, new_current: u32) { self.visitor_threshold .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |mut current| { if current != new_current { Some(new_current) } else { None } }); } /// decrements the visitor count by specified count #[inline] pub fn decrement_visitor_by(&self, count: u32) { self.visitor_threshold .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |mut current| { if current > 0 { if current >= count { current -= count; } else { current = 0; } Some(current) } else { None } }); } /// get [Counter]'s current visitor_threshold pub fn get_visitors(&self) -> u32 { self.visitor_threshold.load(Ordering::SeqCst) } } #[derive(Clone, Serialize, Deserialize)] pub struct Manager { pub captchas: Arc>>, pub gc: u64, } impl Manager { /// add [Counter] actor to [Manager] pub fn add_captcha(&self, m: Arc, id: String) { self.captchas.insert(id, m); } /// create new master /// accepts a `u64` to configure garbage collection period pub fn new(gc: u64) -> Self { Manager { captchas: Arc::new(DashMap::new()), gc, } } fn gc(captchas: Arc>>) { for captcha in captchas.iter() { let visitor = { captcha.value().get_visitors() }; if visitor == 0 { captchas.remove(captcha.key()); } } } /// get [Counter] actor from [Manager] pub fn get_captcha(&self, id: &str) -> Option> { if let Some(captcha) = self.captchas.get(id) { Some(captcha.clone()) } else { None } } /// removes [Counter] actor from [Manager] pub fn rm_captcha(&self, id: &str) -> Option<(String, Arc)> { self.captchas.remove(id) } /// renames [Counter] actor pub fn rename(&self, current_id: &str, new_id: String) { // If actor isn't present, it's okay to not throw an error // since actors are lazyily initialized and are cleaned up when inactive if let Some((_, captcha)) = self.captchas.remove(current_id) { self.add_captcha(captcha, new_id); } } pub async fn clean_all_after_cold_start(&self, updated: Manager) { updated.captchas.iter().for_each(|x| { self.captchas .insert(x.key().to_owned(), x.value().to_owned()); }); let captchas = self.clone(); let keys: Vec = captchas .captchas .clone() .iter() .map(|x| x.key().to_owned()) .collect(); let fut = async move { tokio::time::sleep(Duration::new(captchas.gc, 0)).await; for key in keys.iter() { captchas.rm_captcha(key); } }; tokio::spawn(fut); } pub fn add_visitor( &self, msg: &ManagerMessages::AddVisitor, ) -> Option { if let Some(captcha) = self.captchas.get(&msg.0) { let difficulty_factor = captcha.add_visitor(); // let id = msg.0.clone(); let c = captcha.clone(); let captchas = self.captchas.clone(); let fut = async move { tokio::time::sleep(Duration::new(c.duration, 0)).await; c.decrement_visitor_by(1); // Self::gc(captchas); // if c.get_visitors() == 0 { // println!("Removing captcha addvivi"); // captchas.remove(&id); // } }; tokio::spawn(fut); Some(libmcaptcha::master::AddVisitorResult { duration: captcha.duration, difficulty_factor, }) } else { None } } pub fn get_internal_data(&self) -> HashMap { let mut res = HashMap::with_capacity(self.captchas.len()); for value in self.captchas.iter() { res.insert(value.key().to_owned(), value.value().as_ref().into()); } res } pub fn set_internal_data(&self, mut map: HashMap) { for (id, captcha) in map.drain() { let visitors = captcha.get_visitors(); let new_captcha: MCaptcha = (&captcha).into(); let new_captcha = Arc::new(new_captcha); self.captchas.insert(id.clone(), new_captcha.clone()); let msg = ManagerMessages::AddVisitor(id); for _ in 0..visitors { self.add_visitor(&msg); } } } } impl From<&libmcaptcha::mcaptcha::MCaptcha> for MCaptcha { fn from(value: &libmcaptcha::mcaptcha::MCaptcha) -> Self { let mut defense = super::defense::DefenseBuilder::default(); for level in value.get_defense().get_levels() { let _ = defense.add_level(level); } let defense = defense.build().unwrap(); let new_captcha = MCaptchaBuilder::default() .defense(defense) .duration(value.get_duration()) .build() .unwrap(); // for _ in 0..value.get_visitors() { // new_captcha.add_visitor(); // } new_captcha } } impl From<&MCaptcha> for libmcaptcha::mcaptcha::MCaptcha { fn from(value: &MCaptcha) -> Self { let mut defense = libmcaptcha::defense::DefenseBuilder::default(); for level in value.defense.get_levels().drain(0..) { let _ = defense.add_level(level); } let defense = defense.build().unwrap(); let mut new_captcha = libmcaptcha::mcaptcha::MCaptchaBuilder::default() .defense(defense) .duration(value.duration) .build() .unwrap(); for _ in 0..value.get_visitors() { new_captcha.add_visitor(); } new_captcha } } #[cfg(test)] mod tests { use super::*; use libmcaptcha::defense::LevelBuilder; use libmcaptcha::master::messages::*; pub const LEVEL_1: (u32, u32) = (50, 50); pub const LEVEL_2: (u32, u32) = (500, 500); pub const DURATION: u64 = 5; use crate::mcaptcha::defense::*; pub fn get_defense() -> Defense { DefenseBuilder::default() .add_level( LevelBuilder::default() .visitor_threshold(LEVEL_1.0) .difficulty_factor(LEVEL_1.1) .unwrap() .build() .unwrap(), ) .unwrap() .add_level( LevelBuilder::default() .visitor_threshold(LEVEL_2.0) .difficulty_factor(LEVEL_2.1) .unwrap() .build() .unwrap(), ) .unwrap() .build() .unwrap() } async fn race(manager: &Manager, id: String, count: (u32, u32)) { let msg = ManagerMessages::AddVisitor(id); for _ in 0..count.0 as usize - 1 { manager.add_visitor(&msg); } } // pub fn get_counter() -> Counter { // get_mcaptcha().into() // } pub fn get_mcaptcha() -> MCaptcha { MCaptchaBuilder::default() .defense(get_defense()) .duration(DURATION) .build() .unwrap() } #[actix_rt::test] async fn manager_works() { let manager = Manager::new(1); // let get_add_site_msg = |id: String, mcaptcha: MCaptcha| { // AddSiteBuilder::default() // .id(id) // .mcaptcha(mcaptcha) // .build() // .unwrap() // }; let id = "yo"; manager.add_captcha(Arc::new(get_mcaptcha()), id.into()); let mcaptcha_addr = manager.get_captcha(id); assert!(mcaptcha_addr.is_some()); let mut mcaptcha_data = manager.get_internal_data(); mcaptcha_data.get_mut(id).unwrap().add_visitor(); mcaptcha_data.get_mut(id).unwrap().add_visitor(); mcaptcha_data.get_mut(id).unwrap().add_visitor(); // let mcaptcha_data: HashMap = { // let serialized = serde_json::to_string(&mcaptcha_data).unwrap(); // serde_json::from_str(&serialized).unwrap() // }; // println!("{:?}", mcaptcha_data); manager.set_internal_data(mcaptcha_data); let mcaptcha_data = manager.get_internal_data(); assert_eq!( manager.get_captcha(id).unwrap().get_visitors(), mcaptcha_data.get(id).unwrap().get_visitors() ); let new_id = "yoyo"; manager.rename(id, new_id.into()); { let mcaptcha_addr = manager.get_captcha(new_id); assert!(mcaptcha_addr.is_some()); let addr_doesnt_exist = manager.get_captcha(id); assert!(addr_doesnt_exist.is_none()); let timer_expire = Duration::new(DURATION, 0); tokio::time::sleep(timer_expire).await; tokio::time::sleep(timer_expire).await; } // Manager::gc(manager.captchas.clone()); // let mcaptcha_addr = manager.get_captcha(new_id); // assert_eq!(mcaptcha_addr.as_ref().unwrap().get_visitors(), 0); // assert!(mcaptcha_addr.is_none()); // // assert!( // manager.rm_captcha(new_id.into()).is_some()); } #[actix_rt::test] async fn counter_defense_works() { let manager = Manager::new(1); let id = "yo"; manager.add_captcha(Arc::new(get_mcaptcha()), id.into()); let mut mcaptcha = manager .add_visitor(&ManagerMessages::AddVisitor(id.to_string())) .unwrap(); assert_eq!(mcaptcha.difficulty_factor, LEVEL_1.0); race(&manager, id.to_string(), LEVEL_2).await; mcaptcha = manager .add_visitor(&ManagerMessages::AddVisitor(id.to_string())) .unwrap(); assert_eq!(mcaptcha.difficulty_factor, LEVEL_2.1); tokio::time::sleep(Duration::new(DURATION * 2, 0)).await; assert_eq!(manager.get_captcha(id).unwrap().get_visitors(), 0); } } // //#[cfg(test)] //pub mod tests { // use super::*; // use crate::defense::*; // use crate::errors::*; // use crate::mcaptcha; // use crate::mcaptcha::MCaptchaBuilder; // // // constants for testing // // (visitor count, level) // pub const LEVEL_1: (u32, u32) = (50, 50); // pub const LEVEL_2: (u32, u32) = (500, 500); // pub const DURATION: u64 = 5; // // type MyActor = Addr; // // pub fn get_defense() -> Defense { // DefenseBuilder::default() // .add_level( // LevelBuilder::default() // .visitor_threshold(LEVEL_1.0) // .difficulty_factor(LEVEL_1.1) // .unwrap() // .build() // .unwrap(), // ) // .unwrap() // .add_level( // LevelBuilder::default() // .visitor_threshold(LEVEL_2.0) // .difficulty_factor(LEVEL_2.1) // .unwrap() // .build() // .unwrap(), // ) // .unwrap() // .build() // .unwrap() // } // // async fn race(addr: Addr, count: (u32, u32)) { // for _ in 0..count.0 as usize - 1 { // let _ = addr.send(AddVisitor).await.unwrap(); // } // } // // pub fn get_counter() -> Counter { // get_mcaptcha().into() // } // // pub fn get_mcaptcha() -> MCaptcha { // MCaptchaBuilder::default() // .defense(get_defense()) // .duration(DURATION) // .build() // .unwrap() // } // // #[test] // fn mcaptcha_decrement_by_works() { // let mut m = get_mcaptcha(); // for _ in 0..100 { // m.add_visitor(); // } // m.decrement_visitor_by(50); // assert_eq!(m.get_visitors(), 50); // m.decrement_visitor_by(500); // assert_eq!(m.get_visitors(), 0); // } // // // #[actix_rt::test] // async fn counter_defense_loosenup_works() { // //use actix::clock::sleep; // //use actix::clock::delay_for; // let addr: MyActor = get_counter().start(); // // race(addr.clone(), LEVEL_2).await; // race(addr.clone(), LEVEL_2).await; // let mut mcaptcha = addr.send(AddVisitor).await.unwrap(); // assert_eq!(mcaptcha.difficulty_factor, LEVEL_2.1); // // let duration = Duration::new(DURATION, 0); // sleep(duration).await; // //delay_for(duration).await; // // mcaptcha = addr.send(AddVisitor).await.unwrap(); // assert_eq!(mcaptcha.difficulty_factor, LEVEL_1.1); // } // // #[test] // fn test_mcatcptha_builder() { // let defense = get_defense(); // let m = MCaptchaBuilder::default() // .duration(0) // .defense(defense.clone()) // .build(); // // assert_eq!(m.err(), Some(CaptchaError::CaptchaDurationZero)); // // let m = MCaptchaBuilder::default().duration(30).build(); // assert_eq!( // m.err(), // Some(CaptchaError::PleaseSetValue("defense".into())) // ); // // let m = MCaptchaBuilder::default().defense(defense).build(); // assert_eq!( // m.err(), // Some(CaptchaError::PleaseSetValue("duration".into())) // ); // } // // #[actix_rt::test] // async fn get_current_visitor_count_works() { // let addr: MyActor = get_counter().start(); // // addr.send(AddVisitor).await.unwrap(); // addr.send(AddVisitor).await.unwrap(); // addr.send(AddVisitor).await.unwrap(); // addr.send(AddVisitor).await.unwrap(); // let count = addr.send(GetCurrentVisitorCount).await.unwrap(); // // assert_eq!(count, 4); // } // // #[actix_rt::test] // #[should_panic] // async fn stop_works() { // let addr: MyActor = get_counter().start(); // addr.send(Stop).await.unwrap(); // addr.send(AddVisitor).await.unwrap(); // } // // #[actix_rt::test] // async fn get_set_internal_data_works() { // let addr: MyActor = get_counter().start(); // let mut mcaptcha = addr.send(GetInternalData).await.unwrap(); // mcaptcha.add_visitor(); // addr.send(SetInternalData(mcaptcha.clone())).await.unwrap(); // assert_eq!( // addr.send(GetInternalData).await.unwrap().get_visitors(), // mcaptcha.get_visitors() // ); // // let duration = Duration::new(mcaptcha.get_duration() + 3, 0); // sleep(duration).await; // assert_eq!(addr.send(GetCurrentVisitorCount).await.unwrap(), 0); // } // // #[actix_rt::test] // async fn bulk_delete_works() { // let addr: MyActor = get_counter().start(); // addr.send(AddVisitor).await.unwrap(); // addr.send(AddVisitor).await.unwrap(); // assert_eq!(addr.send(GetCurrentVisitorCount).await.unwrap(), 2); // addr.send(BulkDecrement(3)).await.unwrap(); // assert_eq!(addr.send(GetCurrentVisitorCount).await.unwrap(), 0); // } //}