dcache/src/mcaptcha/mcaptcha.rs

596 lines
19 KiB
Rust

/* mCaptcha - A proof of work based DoS protection system
* Copyright © 2021 Aravinth Manivannan <realravinth@batsense.net>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Duration;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use super::defense::Defense;
use libmcaptcha::errors::*;
use libmcaptcha::master::messages as ManagerMessages;
/// Builder for [MCaptcha]
///
/// Both `duration` and `defense` must be set before calling
/// [MCaptchaBuilder::build]. `Default` is derived: it yields a visitor
/// threshold of `0` and no defense or duration.
#[derive(Clone, Serialize, Deserialize, Debug, Default)]
pub struct MCaptchaBuilder {
    /// Initial visitor count copied into the built [MCaptcha]. There is no
    /// setter, so it stays `0` unless the builder was deserialized.
    visitor_threshold: u32,
    /// Difficulty levels for the built captcha; required by `build`.
    defense: Option<Defense>,
    /// Seconds each visitor stays counted; required by `build` and must be non-zero.
    duration: Option<u64>,
}
impl MCaptchaBuilder {
    /// Sets the [Defense] (difficulty levels) that the built [MCaptcha] will use.
    pub fn defense(&mut self, d: Defense) -> &mut Self {
        self.defense = Some(d);
        self
    }

    /// Sets how long, in seconds, each visitor stays counted.
    pub fn duration(&mut self, d: u64) -> &mut Self {
        self.duration = Some(d);
        self
    }

    /// Builds a new [MCaptcha] from the configured values.
    ///
    /// # Errors
    /// Checks are applied in this order:
    /// - `CaptchaError::PleaseSetValue("duration")` if no duration was set;
    /// - `CaptchaError::PleaseSetValue("defense")` if no defense was set;
    /// - `CaptchaError::CaptchaDurationZero` if the duration is `0`.
    pub fn build(&mut self) -> CaptchaResult<MCaptcha> {
        let secs = match self.duration {
            Some(secs) => secs,
            None => return Err(CaptchaError::PleaseSetValue("duration".into())),
        };
        let levels = match &self.defense {
            Some(levels) => levels,
            None => return Err(CaptchaError::PleaseSetValue("defense".into())),
        };
        if secs == 0 {
            return Err(CaptchaError::CaptchaDurationZero);
        }
        Ok(MCaptcha {
            duration: secs,
            defense: levels.clone(),
            visitor_threshold: Arc::new(AtomicU32::new(self.visitor_threshold)),
        })
    }
}
/// A captcha's configuration together with its live visitor counter.
///
/// The counter lives behind an [Arc], so clones of an `MCaptcha` share (and
/// mutate) the same visitor count.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct MCaptcha {
    /// Number of visitors currently counted against this captcha.
    visitor_threshold: Arc<AtomicU32>,
    /// Difficulty levels; the current visitor count selects the difficulty factor.
    defense: Defense,
    /// Seconds each visitor stays counted (used as a `Duration::new(duration, 0)` delay).
    duration: u64,
}
impl MCaptcha {
    /// Increments the visitor count by one and returns the difficulty factor of
    /// the defense level matching the new count.
    #[inline]
    pub fn add_visitor(&self) -> u32 {
        // `fetch_add` returns the previous value; `+ 1` yields the count after this visitor.
        let current_visitor_level = self.visitor_threshold.fetch_add(1, Ordering::SeqCst) + 1;
        let current_level = self.defense.current_level(current_visitor_level);
        current_level.difficulty_factor
    }

    /// Overwrites the visitor count with `new_current`.
    #[inline]
    pub fn set_visitor_count(&self, new_current: u32) {
        // Writing a value equal to the current one is a no-op, so a plain store is
        // equivalent to a compare-then-swap and cannot fail.
        self.visitor_threshold.store(new_current, Ordering::SeqCst);
    }

    /// Decrements the visitor count by `count`, saturating at zero.
    #[inline]
    pub fn decrement_visitor_by(&self, count: u32) {
        // The closure always returns `Some`, so `fetch_update` never yields `Err`;
        // the previous value it returns is not needed.
        let _ = self
            .visitor_threshold
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
                Some(current.saturating_sub(count))
            });
    }

    /// Returns the current visitor count.
    pub fn get_visitors(&self) -> u32 {
        self.visitor_threshold.load(Ordering::SeqCst)
    }
}
/// Registry of captchas keyed by id.
///
/// The map sits behind an [Arc], so clones of a `Manager` share the same captchas.
#[derive(Clone, Serialize, Deserialize)]
pub struct Manager {
    /// Captchas by id.
    pub captchas: Arc<DashMap<String, Arc<MCaptcha>>>,
    /// Delay in seconds that `clean_all_after_cold_start` waits before removing captchas.
    pub gc: u64,
}
impl Manager {
    /// Registers `m` under `id`, replacing any captcha already stored under that id.
    pub fn add_captcha(&self, m: Arc<MCaptcha>, id: String) {
        self.captchas.insert(id, m);
    }

    /// Creates an empty manager.
    ///
    /// `gc` is the delay, in seconds, that [Manager::clean_all_after_cold_start]
    /// waits before removing the captchas it saw.
    pub fn new(gc: u64) -> Self {
        Manager {
            captchas: Arc::new(DashMap::new()),
            gc,
        }
    }

    /// Removes every captcha whose visitor count is zero.
    ///
    /// `DashMap::remove` may deadlock when called while a reference into the map
    /// (such as an `iter()` item) is still held. So stale entries are dropped with
    /// `retain` instead of removing from inside an iteration loop.
    fn gc(captchas: Arc<DashMap<String, Arc<MCaptcha>>>) {
        captchas.retain(|_, captcha| captcha.get_visitors() != 0);
    }

    /// Returns a handle to the captcha registered under `id`, if any.
    ///
    /// The map guard is released before returning. The returned [Arc] shares its
    /// visitor counter with the stored entry.
    pub fn get_captcha(&self, id: &str) -> Option<Arc<MCaptcha>> {
        self.captchas.get(id).map(|entry| Arc::clone(entry.value()))
    }

    /// Removes the captcha registered under `id`, returning the removed id/captcha pair.
    pub fn rm_captcha(&self, id: &str) -> Option<(String, Arc<MCaptcha>)> {
        self.captchas.remove(id)
    }

    /// Moves the captcha stored under `current_id` to `new_id`.
    ///
    /// Does nothing if no captcha is registered under `current_id`.
    pub fn rename(&self, current_id: &str, new_id: String) {
        // A missing captcha is not treated as an error: captchas are created
        // lazily and may already have been cleaned up.
        if let Some((_, captcha)) = self.captchas.remove(current_id) {
            self.add_captcha(captcha, new_id);
        }
    }

    /// Merges every captcha from `updated` into this manager, overwriting entries
    /// with the same id.
    ///
    /// It then spawns a task that, after `self.gc` seconds, removes every id
    /// present right after the merge. This includes ids that were already here
    /// before the merge, and any captcha re-registered under one of those ids in
    /// the meantime. Must be called from within a Tokio runtime, because
    /// `tokio::spawn` is used.
    pub async fn clean_all_after_cold_start(&self, updated: Manager) {
        updated.captchas.iter().for_each(|entry| {
            self.captchas
                .insert(entry.key().to_owned(), entry.value().to_owned());
        });
        let manager = self.clone();
        let keys: Vec<String> = manager
            .captchas
            .iter()
            .map(|entry| entry.key().to_owned())
            .collect();
        let fut = async move {
            tokio::time::sleep(Duration::new(manager.gc, 0)).await;
            for key in keys.iter() {
                manager.rm_captcha(key);
            }
        };
        tokio::spawn(fut);
    }

    /// Counts one more visitor for the captcha named in `msg`.
    ///
    /// Returns `None` if no captcha is registered under that id. Otherwise it
    /// returns the difficulty factor for the new visitor count and the captcha's
    /// duration, and spawns a task that removes this visitor again after
    /// `duration` seconds. Must be called from within a Tokio runtime, because
    /// `tokio::spawn` is used.
    pub fn add_visitor(
        &self,
        msg: &ManagerMessages::AddVisitor,
    ) -> Option<libmcaptcha::master::AddVisitorResult> {
        // Clone the Arc out so no DashMap guard is held while spawning.
        let captcha = self.get_captcha(&msg.0)?;
        let difficulty_factor = captcha.add_visitor();
        let duration = captcha.duration;
        tokio::spawn(async move {
            tokio::time::sleep(Duration::new(duration, 0)).await;
            captcha.decrement_visitor_by(1);
        });
        Some(libmcaptcha::master::AddVisitorResult {
            duration,
            difficulty_factor,
        })
    }

    /// Snapshots every captcha, including its current visitor count, into
    /// libmcaptcha's representation.
    pub fn get_internal_data(&self) -> HashMap<String, libmcaptcha::mcaptcha::MCaptcha> {
        let mut res = HashMap::with_capacity(self.captchas.len());
        for entry in self.captchas.iter() {
            res.insert(entry.key().to_owned(), entry.value().as_ref().into());
        }
        res
    }

    /// Replaces/inserts captchas from a libmcaptcha snapshot.
    ///
    /// Each captcha is inserted with a zero count. Its stored visitors are then
    /// re-added one by one through [Manager::add_visitor], so each visitor gets
    /// its own expiry timer.
    pub fn set_internal_data(&self, map: HashMap<String, libmcaptcha::mcaptcha::MCaptcha>) {
        for (id, captcha) in map {
            let visitors = captcha.get_visitors();
            let new_captcha: MCaptcha = (&captcha).into();
            self.captchas.insert(id.clone(), Arc::new(new_captcha));
            let msg = ManagerMessages::AddVisitor(id);
            for _ in 0..visitors {
                self.add_visitor(&msg);
            }
        }
    }
}
/// Converts libmcaptcha's captcha into this crate's [MCaptcha].
///
/// Only the defense levels and the duration are carried over. The new value's
/// visitor count starts at zero, because the default builder's threshold is `0`.
/// Callers that need the count, such as [Manager::set_internal_data], re-add the
/// visitors themselves.
///
/// # Panics
/// Panics if building the defense or the captcha returns an error.
impl From<&libmcaptcha::mcaptcha::MCaptcha> for MCaptcha {
    fn from(value: &libmcaptcha::mcaptcha::MCaptcha) -> Self {
        let mut levels_builder = super::defense::DefenseBuilder::default();
        for lvl in value.get_defense().get_levels() {
            // Errors from individual levels are deliberately ignored.
            let _ = levels_builder.add_level(lvl);
        }
        let converted_defense = levels_builder.build().unwrap();
        MCaptchaBuilder::default()
            .defense(converted_defense)
            .duration(value.get_duration())
            .build()
            .unwrap()
    }
}
/// Converts this crate's [MCaptcha] into libmcaptcha's captcha type.
///
/// Copies the defense levels and duration, then replays the current visitor
/// count through libmcaptcha's `add_visitor`.
///
/// # Panics
/// Panics if building the libmcaptcha defense or captcha returns an error.
impl From<&MCaptcha> for libmcaptcha::mcaptcha::MCaptcha {
    fn from(value: &MCaptcha) -> Self {
        let mut levels_builder = libmcaptcha::defense::DefenseBuilder::default();
        for lvl in value.defense.get_levels().drain(..) {
            // Errors from individual levels are deliberately ignored.
            let _ = levels_builder.add_level(lvl);
        }
        let converted_defense = levels_builder.build().unwrap();
        let mut converted = libmcaptcha::mcaptcha::MCaptchaBuilder::default()
            .defense(converted_defense)
            .duration(value.duration)
            .build()
            .unwrap();
        let visitor_count = value.get_visitors();
        (0..visitor_count).for_each(|_| {
            converted.add_visitor();
        });
        converted
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcaptcha::defense::*;
    use libmcaptcha::defense::LevelBuilder;

    // (visitor threshold, difficulty factor) for each defense level
    pub const LEVEL_1: (u32, u32) = (50, 50);
    pub const LEVEL_2: (u32, u32) = (500, 500);
    /// Seconds each visitor stays counted in test captchas.
    pub const DURATION: u64 = 5;

    /// Builds a two-level defense from `LEVEL_1` and `LEVEL_2`.
    pub fn get_defense() -> Defense {
        DefenseBuilder::default()
            .add_level(
                LevelBuilder::default()
                    .visitor_threshold(LEVEL_1.0)
                    .difficulty_factor(LEVEL_1.1)
                    .unwrap()
                    .build()
                    .unwrap(),
            )
            .unwrap()
            .add_level(
                LevelBuilder::default()
                    .visitor_threshold(LEVEL_2.0)
                    .difficulty_factor(LEVEL_2.1)
                    .unwrap()
                    .build()
                    .unwrap(),
            )
            .unwrap()
            .build()
            .unwrap()
    }

    /// Adds `count.0 - 1` visitors to the captcha registered under `id`.
    async fn race(manager: &Manager, id: String, count: (u32, u32)) {
        let msg = ManagerMessages::AddVisitor(id);
        for _ in 0..count.0 as usize - 1 {
            manager.add_visitor(&msg);
        }
    }

    /// Builds a test captcha using [get_defense] and [DURATION].
    pub fn get_mcaptcha() -> MCaptcha {
        MCaptchaBuilder::default()
            .defense(get_defense())
            .duration(DURATION)
            .build()
            .unwrap()
    }

    #[actix_rt::test]
    async fn manager_works() {
        let manager = Manager::new(1);
        let id = "yo";
        manager.add_captcha(Arc::new(get_mcaptcha()), id.into());
        assert!(manager.get_captcha(id).is_some());

        // Round-trip through libmcaptcha's representation with three added visitors.
        let mut mcaptcha_data = manager.get_internal_data();
        for _ in 0..3 {
            mcaptcha_data.get_mut(id).unwrap().add_visitor();
        }
        manager.set_internal_data(mcaptcha_data);
        let mcaptcha_data = manager.get_internal_data();
        assert_eq!(
            manager.get_captcha(id).unwrap().get_visitors(),
            mcaptcha_data.get(id).unwrap().get_visitors()
        );

        let new_id = "yoyo";
        manager.rename(id, new_id.into());
        assert!(manager.get_captcha(new_id).is_some());
        assert!(manager.get_captcha(id).is_none());

        // Every restored visitor expires after DURATION seconds. Renaming keeps the
        // same Arc, so the renamed captcha's counter must drain back to zero.
        tokio::time::sleep(Duration::new(DURATION * 2, 0)).await;
        assert_eq!(manager.get_captcha(new_id).unwrap().get_visitors(), 0);
    }

    #[actix_rt::test]
    async fn counter_defense_works() {
        let manager = Manager::new(1);
        let id = "yo";
        manager.add_captcha(Arc::new(get_mcaptcha()), id.into());
        let msg = ManagerMessages::AddVisitor(id.to_string());

        let first = manager.add_visitor(&msg).unwrap();
        // Compare against the difficulty factor, not the visitor threshold.
        assert_eq!(first.difficulty_factor, LEVEL_1.1);

        race(&manager, id.to_string(), LEVEL_2).await;
        let after_race = manager.add_visitor(&msg).unwrap();
        assert_eq!(after_race.difficulty_factor, LEVEL_2.1);

        tokio::time::sleep(Duration::new(DURATION * 2, 0)).await;
        assert_eq!(manager.get_captcha(id).unwrap().get_visitors(), 0);
    }
}
//
//#[cfg(test)]
//pub mod tests {
// use super::*;
// use crate::defense::*;
// use crate::errors::*;
// use crate::mcaptcha;
// use crate::mcaptcha::MCaptchaBuilder;
//
// // constants for testing
// // (visitor count, level)
// pub const LEVEL_1: (u32, u32) = (50, 50);
// pub const LEVEL_2: (u32, u32) = (500, 500);
// pub const DURATION: u64 = 5;
//
// type MyActor = Addr<Counter>;
//
// pub fn get_defense() -> Defense {
// DefenseBuilder::default()
// .add_level(
// LevelBuilder::default()
// .visitor_threshold(LEVEL_1.0)
// .difficulty_factor(LEVEL_1.1)
// .unwrap()
// .build()
// .unwrap(),
// )
// .unwrap()
// .add_level(
// LevelBuilder::default()
// .visitor_threshold(LEVEL_2.0)
// .difficulty_factor(LEVEL_2.1)
// .unwrap()
// .build()
// .unwrap(),
// )
// .unwrap()
// .build()
// .unwrap()
// }
//
// async fn race(addr: Addr<Counter>, count: (u32, u32)) {
// for _ in 0..count.0 as usize - 1 {
// let _ = addr.send(AddVisitor).await.unwrap();
// }
// }
//
// pub fn get_counter() -> Counter {
// get_mcaptcha().into()
// }
//
// pub fn get_mcaptcha() -> MCaptcha {
// MCaptchaBuilder::default()
// .defense(get_defense())
// .duration(DURATION)
// .build()
// .unwrap()
// }
//
// #[test]
// fn mcaptcha_decrement_by_works() {
// let mut m = get_mcaptcha();
// for _ in 0..100 {
// m.add_visitor();
// }
// m.decrement_visitor_by(50);
// assert_eq!(m.get_visitors(), 50);
// m.decrement_visitor_by(500);
// assert_eq!(m.get_visitors(), 0);
// }
//
//
// #[actix_rt::test]
// async fn counter_defense_loosenup_works() {
// //use actix::clock::sleep;
// //use actix::clock::delay_for;
// let addr: MyActor = get_counter().start();
//
// race(addr.clone(), LEVEL_2).await;
// race(addr.clone(), LEVEL_2).await;
// let mut mcaptcha = addr.send(AddVisitor).await.unwrap();
// assert_eq!(mcaptcha.difficulty_factor, LEVEL_2.1);
//
// let duration = Duration::new(DURATION, 0);
// sleep(duration).await;
// //delay_for(duration).await;
//
// mcaptcha = addr.send(AddVisitor).await.unwrap();
// assert_eq!(mcaptcha.difficulty_factor, LEVEL_1.1);
// }
//
// #[test]
// fn test_mcatcptha_builder() {
// let defense = get_defense();
// let m = MCaptchaBuilder::default()
// .duration(0)
// .defense(defense.clone())
// .build();
//
// assert_eq!(m.err(), Some(CaptchaError::CaptchaDurationZero));
//
// let m = MCaptchaBuilder::default().duration(30).build();
// assert_eq!(
// m.err(),
// Some(CaptchaError::PleaseSetValue("defense".into()))
// );
//
// let m = MCaptchaBuilder::default().defense(defense).build();
// assert_eq!(
// m.err(),
// Some(CaptchaError::PleaseSetValue("duration".into()))
// );
// }
//
// #[actix_rt::test]
// async fn get_current_visitor_count_works() {
// let addr: MyActor = get_counter().start();
//
// addr.send(AddVisitor).await.unwrap();
// addr.send(AddVisitor).await.unwrap();
// addr.send(AddVisitor).await.unwrap();
// addr.send(AddVisitor).await.unwrap();
// let count = addr.send(GetCurrentVisitorCount).await.unwrap();
//
// assert_eq!(count, 4);
// }
//
// #[actix_rt::test]
// #[should_panic]
// async fn stop_works() {
// let addr: MyActor = get_counter().start();
// addr.send(Stop).await.unwrap();
// addr.send(AddVisitor).await.unwrap();
// }
//
// #[actix_rt::test]
// async fn get_set_internal_data_works() {
// let addr: MyActor = get_counter().start();
// let mut mcaptcha = addr.send(GetInternalData).await.unwrap();
// mcaptcha.add_visitor();
// addr.send(SetInternalData(mcaptcha.clone())).await.unwrap();
// assert_eq!(
// addr.send(GetInternalData).await.unwrap().get_visitors(),
// mcaptcha.get_visitors()
// );
//
// let duration = Duration::new(mcaptcha.get_duration() + 3, 0);
// sleep(duration).await;
// assert_eq!(addr.send(GetCurrentVisitorCount).await.unwrap(), 0);
// }
//
// #[actix_rt::test]
// async fn bulk_delete_works() {
// let addr: MyActor = get_counter().start();
// addr.send(AddVisitor).await.unwrap();
// addr.send(AddVisitor).await.unwrap();
// assert_eq!(addr.send(GetCurrentVisitorCount).await.unwrap(), 2);
// addr.send(BulkDecrement(3)).await.unwrap();
// assert_eq!(addr.send(GetCurrentVisitorCount).await.unwrap(), 0);
// }
//}