/* mCaptcha - A proof of work based DoS protection system * Copyright © 2021 Aravinth Manivannan * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as * published by the Free Software Foundation, either version 3 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see . */ use std::collections::HashMap; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; use std::time::Duration; use dashmap::DashMap; use serde::{Deserialize, Serialize}; use super::defense::Defense; use libmcaptcha::errors::*; use libmcaptcha::master::messages as MasterMessages; /// Builder for [MCaptcha] #[derive(Clone, Serialize, Deserialize, Debug)] pub struct MCaptchaBuilder { visitor_threshold: u32, defense: Option, duration: Option, } impl Default for MCaptchaBuilder { fn default() -> Self { MCaptchaBuilder { visitor_threshold: 0, defense: None, duration: None, } } } impl MCaptchaBuilder { /// set defense pub fn defense(&mut self, d: Defense) -> &mut Self { self.defense = Some(d); self } /// set duration pub fn duration(&mut self, d: u64) -> &mut Self { self.duration = Some(d); self } /// Builds new [MCaptcha] pub fn build(self: &mut MCaptchaBuilder) -> CaptchaResult { if self.duration.is_none() { Err(CaptchaError::PleaseSetValue("duration".into())) } else if self.defense.is_none() { Err(CaptchaError::PleaseSetValue("defense".into())) } else if self.duration <= Some(0) { Err(CaptchaError::CaptchaDurationZero) } else { let m = MCaptcha { duration: self.duration.unwrap(), defense: self.defense.clone().unwrap(), visitor_threshold: Arc::new(AtomicU32::new(self.visitor_threshold)), }; Ok(m) } } } #[derive(Clone, Serialize, Deserialize, Debug)] pub struct MCaptcha { visitor_threshold: Arc, defense: Defense, duration: u64, } impl MCaptcha { /// increments the visitor count by one #[inline] pub fn add_visitor(&self) -> u32 { // self.visitor_threshold += 1; let current_visitor_level = self.visitor_threshold.fetch_add(1, Ordering::SeqCst) + 1; let current_level = self.defense.current_level(current_visitor_level); current_level.difficulty_factor } /// decrements the visitor count by specified count #[inline] pub fn set_visitor_count(&self, new_current: u32) { self.visitor_threshold .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |mut current| { if current != new_current { Some(new_current) } else { None } }); } /// decrements the visitor count by specified count #[inline] pub fn decrement_visitor_by(&self, count: u32) { self.visitor_threshold .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |mut current| { if current > 0 { if current >= count { current -= count; } else { current = 0; } Some(current) } else { None } }); } } #[derive(Clone, Serialize, Deserialize)] pub struct Manager { pub captchas: Arc>>, pub gc: u64, } impl Manager { /// add [Counter] actor to [Master] pub fn add_captcha(&self, m: Arc, id: String) { self.captchas.insert(id, m); } /// create new master /// accepts a `u64` to configure garbage collection period pub fn new(gc: u64) -> Self { Manager { captchas: Arc::new(DashMap::new()), gc, } } /// get [Counter] actor from [Master] pub fn get_captcha(&self, id: &str) -> Option> { if let Some(captcha) = self.captchas.get(id) { Some(captcha.clone()) } else { None } } /// removes [Counter] actor from [Master] pub fn rm_captcha(&self, id: &str) -> Option<(String, Arc)> { self.captchas.remove(id) } /// renames [Counter] actor pub fn rename(&self, current_id: &str, new_id: String) { // If actor isn't present, it's okay to not throw an error // since actors are lazyily initialized and are cleaned up when inactive if let Some((_, captcha)) = self.captchas.remove(current_id) { self.add_captcha(captcha, new_id); } } pub async fn clean_all_after_cold_start(&self, updated: Manager) { updated.captchas.iter().map(|x| { self.captchas .insert(x.key().to_owned(), x.value().to_owned()) }); let captchas = self.clone(); let keys: Vec = captchas .captchas .clone() .iter() .map(|x| x.key().to_owned()) .collect(); let fut = async move { tokio::time::sleep(Duration::new(captchas.gc, 0)).await; for key in keys.iter() { captchas.rm_captcha(key); } }; tokio::spawn(fut); } pub fn add_visitor( &self, msg: &MasterMessages::AddVisitor, ) -> Option { if let Some(captcha) = self.captchas.get(&msg.0) { let difficulty_factor = captcha.add_visitor(); let c = captcha.clone(); let fut = async move { tokio::time::sleep(Duration::new(c.duration, 0)).await; c.decrement_visitor_by(1); }; tokio::spawn(fut); Some(libmcaptcha::master::AddVisitorResult { duration: captcha.duration, difficulty_factor, }) } else { None } } pub fn get_internal_data(&self) -> Arc>> { self.captchas.clone() } pub fn set_internal_data(&self, mut map: HashMap) { for (id, captcha) in map.drain() { let visitors = captcha.get_visitors(); let new_captcha: MCaptcha = (&captcha).into(); new_captcha.set_visitor_count(visitors); self.captchas.insert(id, Arc::new(new_captcha)); } } } impl From<&libmcaptcha::mcaptcha::MCaptcha> for MCaptcha { fn from(value: &libmcaptcha::mcaptcha::MCaptcha) -> Self { let mut defense = super::defense::DefenseBuilder::default(); for level in value.get_defense().get_levels() { let _ = defense.add_level(level); } let defense = defense.build().unwrap(); let new_captcha = MCaptchaBuilder::default() .defense(defense) .duration(value.get_duration()) .build() .unwrap(); new_captcha } }