feat: custom counter implementation to solve bottleneck

This commit is contained in:
Aravinth Manivannan 2023-12-29 20:08:14 +05:30
parent 59180fd86f
commit 45a49288b7
Signed by: realaravinth
GPG key ID: F8F50389936984FF
3 changed files with 861 additions and 0 deletions

395
src/mcaptcha/defense.rs Normal file
View file

@ -0,0 +1,395 @@
/*
* mCaptcha - A proof of work based DoS protection system
* Copyright © 2021 Aravinth Manivannan <realravinth@batsense.net>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
use serde::{Deserialize, Serialize};
use libmcaptcha::defense::Level;
use libmcaptcha::errors::*;
//
///// Level struct that describes threshold-difficulty factor mapping
//#[derive(Debug, Deserialize, Serialize, Copy, Clone, PartialEq)]
//pub struct Level {
// pub visitor_threshold: u32,
// pub difficulty_factor: u32,
//}
//
///// Bulder struct for [Level] to describe threshold-difficulty factor mapping
//#[derive(Debug, Copy, Clone, PartialEq)]
//pub struct LevelBuilder {
// visitor_threshold: Option<u32>,
// difficulty_factor: Option<u32>,
//}
//
//impl Default for LevelBuilder {
// fn default() -> Self {
// LevelBuilder {
// visitor_threshold: None,
// difficulty_factor: None,
// }
// }
//}
//
//impl LevelBuilder {
// /// set visitor count for level
// pub fn visitor_threshold(&mut self, visitor_threshold: u32) -> &mut Self {
// self.visitor_threshold = Some(visitor_threshold);
// self
// }
//
// /// set difficulty factor for level. difficulty_factor can't be zero because
// /// Difficulty is calculated as:
// /// ```no_run
// /// let difficulty_factor = 500;
// /// let difficulty = u128::max_value() - u128::max_value() / difficulty_factor;
// /// ```
// /// the higher the `difficulty_factor`, the higher the difficulty.
// pub fn difficulty_factor(&mut self, difficulty_factor: u32) -> CaptchaResult<&mut Self> {
// if difficulty_factor > 0 {
// self.difficulty_factor = Some(difficulty_factor);
// Ok(self)
// } else {
// Err(CaptchaError::DifficultyFactorZero)
// }
// }
//
// /// build Level struct
// pub fn build(&mut self) -> CaptchaResult<Level> {
// if self.visitor_threshold.is_none() {
// Err(CaptchaError::SetVisitorThreshold)
// } else if self.difficulty_factor.is_none() {
// Err(CaptchaError::SetDifficultyFactor)
// } else {
// Ok(Level {
// difficulty_factor: self.difficulty_factor.unwrap(),
// visitor_threshold: self.visitor_threshold.unwrap(),
// })
// }
// }
//}
//
/// Builder struct for [Defense]
#[derive(Debug, Clone, PartialEq)]
pub struct DefenseBuilder {
levels: Vec<Level>,
}
impl Default for DefenseBuilder {
fn default() -> Self {
DefenseBuilder { levels: vec![] }
}
}
impl DefenseBuilder {
/// add a level to [Defense]
pub fn add_level(&mut self, level: Level) -> CaptchaResult<&mut Self> {
for i in self.levels.iter() {
if i.visitor_threshold == level.visitor_threshold {
return Err(CaptchaError::DuplicateVisitorCount);
}
}
self.levels.push(level);
Ok(self)
}
/// Build [Defense]
pub fn build(&mut self) -> CaptchaResult<Defense> {
if !self.levels.is_empty() {
// sort levels to arrange in ascending order
self.levels.sort_by_key(|a| a.visitor_threshold);
for level in self.levels.iter() {
if level.difficulty_factor == 0 {
return Err(CaptchaError::DifficultyFactorZero);
}
}
// as visitor count increases, difficulty_factor too should increse
// if it decreses, an error must be thrown
for i in 0..self.levels.len() - 1 {
if self.levels[i].difficulty_factor > self.levels[i + 1].difficulty_factor {
return Err(CaptchaError::DecreaseingDifficultyFactor);
}
}
Ok(Defense {
levels: self.levels.to_owned(),
})
} else {
Err(CaptchaError::LevelEmpty)
}
}
}
/// struct describes all the different [Level]s at which an mCaptcha system operates
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Defense {
levels: Vec<Level>,
// index of current visitor threshold
}
impl From<Defense> for Vec<Level> {
fn from(d: Defense) -> Self {
d.levels
}
}
impl Defense {
///! Difficulty is calculated as:
///! ```rust
///! let difficulty = u128::max_value() - u128::max_value() / difficulty_factor;
///! ```
///! The higher the `difficulty_factor`, the higher the difficulty.
// /// Get difficulty factor of current level of defense
// pub fn get_difficulty(&self, current_visitor_threshold: usize) -> u32 {
// self.levels[current_visitor_threshold].difficulty_factor
// }
//
// /// tighten up defense. Increases defense level by a factor of one.
// /// When defense is at max level, calling this method will have no effect
// pub fn tighten_up(&mut self) {
// if self.current_visitor_threshold < self.levels.len() - 1 {
// self.current_visitor_threshold += 1;
// }
// }
// /// Loosen up defense. Decreases defense level by a factor of one.
// /// When defense is at the lowest level, calling this method will have no effect.
// pub fn loosen_up(&mut self) {
// if self.current_visitor_threshold > 0 {
// self.current_visitor_threshold -= 1;
// }
// }
//
// /// Set defense to maximum level
// pub fn max_defense(&mut self) {
// self.current_visitor_threshold = self.levels.len() - 1;
// }
//
// /// Set defense to minimum level
// pub fn min_defense(&mut self) {
// self.current_visitor_threshold = 0;
// }
//
/// Get current level's visitor threshold
pub fn current_level(&self, current_visitor_level: u32) -> &Level {
for level in self.levels.iter() {
if current_visitor_level <= level.visitor_threshold {
return level;
}
}
self.levels.last().as_ref().unwrap()
// &self.levels[self.current_visitor_threshold]
}
//
// /// Get current level's visitor threshold
// pub fn visitor_threshold(&self) -> u32 {
// self.levels[self.current_visitor_threshold].difficulty_factor
// }
}
#[cfg(test)]
mod tests {
use super::*;
use libmcaptcha::defense::Level;
use libmcaptcha::LevelBuilder;
#[test]
fn defense_builder_duplicate_visitor_threshold() {
let mut defense_builder = DefenseBuilder::default();
let err = defense_builder
.add_level(
LevelBuilder::default()
.visitor_threshold(50)
.difficulty_factor(50)
.unwrap()
.build()
.unwrap(),
)
.unwrap()
.add_level(
LevelBuilder::default()
.visitor_threshold(50)
.difficulty_factor(50)
.unwrap()
.build()
.unwrap(),
);
assert_eq!(err, Err(CaptchaError::DuplicateVisitorCount));
}
#[test]
fn defense_builder_decreasing_difficulty_factor() {
let mut defense_builder = DefenseBuilder::default();
let err = defense_builder
.add_level(
LevelBuilder::default()
.visitor_threshold(50)
.difficulty_factor(50)
.unwrap()
.build()
.unwrap(),
)
.unwrap()
.add_level(
LevelBuilder::default()
.visitor_threshold(500)
.difficulty_factor(10)
.unwrap()
.build()
.unwrap(),
)
.unwrap()
.build();
assert_eq!(err, Err(CaptchaError::DecreaseingDifficultyFactor));
}
#[test]
fn checking_for_integer_overflow() {
let mut defense = DefenseBuilder::default()
.add_level(
LevelBuilder::default()
.visitor_threshold(5)
.difficulty_factor(5)
.unwrap()
.build()
.unwrap(),
)
.unwrap()
.add_level(
LevelBuilder::default()
.visitor_threshold(10)
.difficulty_factor(50)
.unwrap()
.build()
.unwrap(),
)
.unwrap()
.add_level(
LevelBuilder::default()
.visitor_threshold(20)
.difficulty_factor(60)
.unwrap()
.build()
.unwrap(),
)
.unwrap()
.add_level(
LevelBuilder::default()
.visitor_threshold(30)
.difficulty_factor(65)
.unwrap()
.build()
.unwrap(),
)
.unwrap()
.build()
.unwrap();
// for _ in 0..500 {
// defense.tighten_up();
// }
//
// defense.get_difficulty();
// for _ in 0..500000 {
// defense.tighten_up();
// }
//
defense.current_level(10_000_000);
}
fn get_defense() -> Defense {
DefenseBuilder::default()
.add_level(
LevelBuilder::default()
.visitor_threshold(50)
.difficulty_factor(50)
.unwrap()
.build()
.unwrap(),
)
.unwrap()
.add_level(
LevelBuilder::default()
.visitor_threshold(500)
.difficulty_factor(5000)
.unwrap()
.build()
.unwrap(),
)
.unwrap()
.add_level(
LevelBuilder::default()
.visitor_threshold(5000)
.difficulty_factor(50000)
.unwrap()
.build()
.unwrap(),
)
.unwrap()
.add_level(
LevelBuilder::default()
.visitor_threshold(50000)
.difficulty_factor(500000)
.unwrap()
.build()
.unwrap(),
)
.unwrap()
.add_level(
LevelBuilder::default()
.visitor_threshold(500000)
.difficulty_factor(5000000)
.unwrap()
.build()
.unwrap(),
)
.unwrap()
.build()
.unwrap()
}
#[test]
fn defense_builder_works() {
let defense = get_defense();
assert_eq!(defense.levels[0].difficulty_factor, 50);
assert_eq!(defense.levels[1].difficulty_factor, 5000);
assert_eq!(defense.levels[2].difficulty_factor, 50_000);
assert_eq!(defense.levels[3].difficulty_factor, 500_000);
assert_eq!(defense.levels[4].difficulty_factor, 5_000_000);
}
#[test]
fn tighten_up_works() {
let defense = get_defense();
assert_eq!(defense.current_level(0).difficulty_factor, 50);
assert_eq!(defense.current_level(500).difficulty_factor, 5_000);
assert_eq!(defense.current_level(501).difficulty_factor, 50_000);
assert_eq!(defense.current_level(5_000).difficulty_factor, 50_000);
assert_eq!(defense.current_level(5_001).difficulty_factor, 500_000);
assert_eq!(defense.current_level(50_000).difficulty_factor, 500_000);
assert_eq!(defense.current_level(50_001).difficulty_factor, 5_000_000);
assert_eq!(defense.current_level(500_000).difficulty_factor, 5_000_000);
assert_eq!(defense.current_level(500_001).difficulty_factor, 5_000_000);
}
}

464
src/mcaptcha/mcaptcha.rs Normal file
View file

@ -0,0 +1,464 @@
/* mCaptcha - A proof of work based DoS protection system
* Copyright © 2021 Aravinth Manivannan <realravinth@batsense.net>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Duration;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use super::defense::Defense;
use libmcaptcha::errors::*;
use libmcaptcha::master::messages as MasterMessages;
/// Builder for [MCaptcha]
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct MCaptchaBuilder {
visitor_threshold: u32,
defense: Option<Defense>,
duration: Option<u64>,
}
impl Default for MCaptchaBuilder {
fn default() -> Self {
MCaptchaBuilder {
visitor_threshold: 0,
defense: None,
duration: None,
}
}
}
impl MCaptchaBuilder {
/// set defense
pub fn defense(&mut self, d: Defense) -> &mut Self {
self.defense = Some(d);
self
}
/// set duration
pub fn duration(&mut self, d: u64) -> &mut Self {
self.duration = Some(d);
self
}
/// Builds new [MCaptcha]
pub fn build(self: &mut MCaptchaBuilder) -> CaptchaResult<MCaptcha> {
if self.duration.is_none() {
Err(CaptchaError::PleaseSetValue("duration".into()))
} else if self.defense.is_none() {
Err(CaptchaError::PleaseSetValue("defense".into()))
} else if self.duration <= Some(0) {
Err(CaptchaError::CaptchaDurationZero)
} else {
let m = MCaptcha {
duration: self.duration.unwrap(),
defense: self.defense.clone().unwrap(),
visitor_threshold: Arc::new(AtomicU32::new(self.visitor_threshold)),
};
Ok(m)
}
}
}
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct MCaptcha {
visitor_threshold: Arc<AtomicU32>,
defense: Defense,
duration: u64,
}
//impl From<MCaptcha> for crate::master::CreateMCaptcha {
// fn from(m: MCaptcha) -> Self {
// Self {
// levels: m.defense.into(),
// duration: m.duration,
// }
// }
//}
impl MCaptcha {
/// increments the visitor count by one
#[inline]
pub fn add_visitor(&self) -> u32 {
// self.visitor_threshold += 1;
let current_visitor_level = self.visitor_threshold.fetch_add(1, Ordering::SeqCst) + 1;
let current_level = self.defense.current_level(current_visitor_level);
current_level.difficulty_factor
//if current_visitor_level > current_level.visitor_threshold {
// self.defense.tighten_up();
//} else {
// self.defense.loosen_up();
//}
}
/// decrements the visitor count by specified count
#[inline]
pub fn set_visitor_count(&self, new_current: u32) {
self.visitor_threshold
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |mut current| {
if current != new_current {
Some(new_current)
} else {
None
}
});
}
/// decrements the visitor count by specified count
#[inline]
pub fn decrement_visitor_by(&self, count: u32) {
self.visitor_threshold
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |mut current| {
if current > 0 {
if current >= count {
current -= count;
} else {
current = 0;
}
Some(current)
} else {
None
}
});
}
// /// get current difficulty factor
// #[inline]
// pub fn get_difficulty(&self) -> u32 {
// self.defense.get_difficulty()
// }
//
// /// get [Counter]'s lifetime
// #[inline]
// pub fn get_duration(&self) -> u64 {
// self.duration
// }
//
// /// get [Counter]'s current visitor_threshold
// #[inline]
// pub fn get_visitors(&self) -> u32 {
// self.visitor_threshold.load(Ordering::Relaxed)
// }
//
// /// get mCaptcha's defense configuration
// #[inline]
// pub fn get_defense(&self) -> Defense {
// self.defense.clone()
// }
}
#[derive(Clone, Serialize, Deserialize)]
pub struct Manager {
pub captchas: DashMap<String, Arc<MCaptcha>>,
pub gc: u64,
}
impl Manager {
/// add [Counter] actor to [Master]
pub fn add_captcha(&self, m: Arc<MCaptcha>, id: String) {
self.captchas.insert(id, m);
}
/// create new master
/// accepts a `u64` to configure garbage collection period
pub fn new(gc: u64) -> Self {
Manager {
captchas: DashMap::new(),
gc,
}
}
/// get [Counter] actor from [Master]
pub fn get_captcha(&self, id: &str) -> Option<Arc<MCaptcha>> {
if let Some(captcha) = self.captchas.get(id) {
Some(captcha.clone())
} else {
None
}
}
/// removes [Counter] actor from [Master]
pub fn rm_captcha(&self, id: &str) -> Option<(String, Arc<MCaptcha>)> {
self.captchas.remove(id)
}
/// renames [Counter] actor
pub fn rename(&self, current_id: &str, new_id: String) {
// If actor isn't present, it's okay to not throw an error
// since actors are lazyily initialized and are cleaned up when inactive
if let Some((_, captcha)) = self.captchas.remove(current_id) {
self.add_captcha(captcha, new_id);
}
}
pub fn add_visitor(
&self,
msg: &MasterMessages::AddVisitor,
) -> Option<libmcaptcha::master::AddVisitorResult> {
if let Some(captcha) = self.captchas.get(&msg.0) {
let difficulty_factor = captcha.add_visitor();
let c = captcha.clone();
let fut = async move {
tokio::time::sleep(Duration::new(c.duration, 0)).await;
c.decrement_visitor_by(1);
};
tokio::spawn(fut);
Some(libmcaptcha::master::AddVisitorResult {
duration: captcha.duration,
difficulty_factor,
})
} else {
None
}
}
pub fn get_internal_data(&self) -> DashMap<String, Arc<MCaptcha>> {
self.captchas.clone()
}
pub fn set_internal_data(&self, mut map: HashMap<String, libmcaptcha::mcaptcha::MCaptcha>) {
for (id, captcha) in map.drain() {
let visitors = captcha.get_visitors();
let new_captcha: MCaptcha = (&captcha).into();
new_captcha.set_visitor_count(visitors);
self.captchas.insert(id, Arc::new(new_captcha));
}
}
}
impl From<&libmcaptcha::mcaptcha::MCaptcha> for MCaptcha {
fn from(value: &libmcaptcha::mcaptcha::MCaptcha) -> Self {
let mut defense = super::defense::DefenseBuilder::default();
for level in value.get_defense().get_levels() {
let _ = defense.add_level(level);
}
let defense = defense.build().unwrap();
let new_captcha = MCaptchaBuilder::default()
.defense(defense)
.duration(value.get_duration())
.build()
.unwrap();
new_captcha
}
}
//impl Actor for Master {
// type Context = Context<Self>;
//
// fn started(&mut self, ctx: &mut Self::Context) {
// let addr = ctx.address();
// let task = async move {
// addr.send(CleanUp).await.unwrap();
// }
// .into_actor(self);
// ctx.spawn(task);
// }
//}
//
//impl Handler<AddVisitor> for Master {
// type Result = MessageResult<AddVisitor>;
//
// fn handle(&mut self, m: AddVisitor, ctx: &mut Self::Context) -> Self::Result {
// let (tx, rx) = channel();
// match self.get_site(&m.0) {
// None => {
// let _ = tx.send(Ok(None));
// }
// Some(addr) => {
// let fut = async move {
// match addr.send(super::counter::AddVisitor).await {
// Ok(val) => {
// let _ = tx.send(Ok(Some(val)));
// }
// Err(e) => {
// let err: CaptchaError = e.into();
// let _ = tx.send(Err(err));
// }
// }
// }
// .into_actor(self);
// ctx.spawn(fut);
// }
// }
// MessageResult(rx)
// }
//}
//
//impl Handler<Rename> for Master {
// type Result = MessageResult<Rename>;
//
// fn handle(&mut self, m: Rename, _ctx: &mut Self::Context) -> Self::Result {
// self.rename(m);
// let (tx, rx) = channel();
// let _ = tx.send(Ok(()));
// MessageResult(rx)
// }
//}
//
///// Message to get an [Counter] actor from master
//#[derive(Message)]
//#[rtype(result = "Option<Addr<Counter>>")]
//pub struct GetSite(pub String);
//
//impl Handler<GetSite> for Master {
// type Result = MessageResult<GetSite>;
//
// fn handle(&mut self, m: GetSite, _ctx: &mut Self::Context) -> Self::Result {
// let addr = self.get_site(&m.0);
// match addr {
// None => MessageResult(None),
// Some(addr) => MessageResult(Some(addr)),
// }
// }
//}
//
///// Message to clean up master of [Counter] actors with zero visitor count
//#[derive(Message)]
//#[rtype(result = "()")]
//pub struct CleanUp;
//
//impl Handler<CleanUp> for Master {
// type Result = ();
//
// fn handle(&mut self, _: CleanUp, ctx: &mut Self::Context) -> Self::Result {
// let sites = self.sites.clone();
// let gc = self.gc;
// let master = ctx.address();
// info!("init master actor cleanup up");
// let task = async move {
// for (id, (new, addr)) in sites.iter() {
// let visitor_count = addr.send(GetCurrentVisitorCount).await.unwrap();
// if visitor_count == 0 && new.is_some() {
// addr.send(Stop).await.unwrap();
// master.send(RemoveCaptcha(id.to_owned())).await.unwrap();
// }
// }
//
// let duration = Duration::new(gc, 0);
// sleep(duration).await;
// //delay_for(duration).await;
// master.send(CleanUp).await.unwrap();
// }
// .into_actor(self);
// ctx.spawn(task);
// }
//}
//
//impl Handler<RemoveCaptcha> for Master {
// type Result = MessageResult<RemoveCaptcha>;
//
// fn handle(&mut self, m: RemoveCaptcha, ctx: &mut Self::Context) -> Self::Result {
// let (tx, rx) = channel();
// if let Some((_, addr)) = self.rm_site(&m.0) {
// let fut = async move {
// //addr.send(Stop).await?;
// let res: CaptchaResult<()> = addr.send(Stop).await.map_err(|e| e.into());
// let _ = tx.send(res);
// }
// .into_actor(self);
// ctx.spawn(fut);
// } else {
// tx.send(Ok(())).unwrap();
// }
// MessageResult(rx)
// }
//}
//
//impl Handler<AddSite> for Master {
// type Result = MessageResult<AddSite>;
//
// fn handle(&mut self, m: AddSite, _ctx: &mut Self::Context) -> Self::Result {
// let (tx, rx) = channel();
// let counter: Counter = m.mcaptcha.into();
// let addr = counter.start();
// self.add_site(addr, m.id);
// tx.send(Ok(())).unwrap();
// MessageResult(rx)
// }
//}
//impl Handler<GetInternalData> for Master {
// type Result = MessageResult<GetInternalData>;
//
// fn handle(&mut self, _m: GetInternalData, ctx: &mut Self::Context) -> Self::Result {
// let (tx, rx) = channel();
// let mut data = HashMap::with_capacity(self.sites.len());
//
// let sites = self.sites.clone();
// let fut = async move {
// for (name, (_read_val, addr)) in sites.iter() {
// match addr.send(super::counter::GetInternalData).await {
// Ok(val) => {
// data.insert(name.to_owned(), val);
// }
// Err(_e) => {
// println!("Trying to get data {name}. Failed");
// continue;
// // best-effort basis persistence
// // let err: CaptchaError = e.into();
// // let _ = tx.send(Err(err));
// // break;
// }
// }
//
// }
// tx.send(Ok(data));
// }
// .into_actor(self);
// ctx.spawn(fut);
//
// MessageResult(rx)
// }
//}
//
//impl Handler<SetInternalData> for Master {
// type Result = MessageResult<SetInternalData>;
//
// fn handle(&mut self, mut m: SetInternalData, ctx: &mut Self::Context) -> Self::Result {
// let (tx, rx) = channel();
// for (name, mcaptcha) in m.mcaptcha.drain() {
// let addr = self.get_site(&name);
// let master = ctx.address();
// let fut = async move {
// match addr {
// None => {
// master.send(AddSite { id: name, mcaptcha }).await.unwrap();
// }
// Some(addr) => {
// let _ = addr.send(super::counter::SetInternalData(mcaptcha)).await;
// // best effort basis
// //let err: CaptchaError = e.into();
// //let _ = tx.send(Err(err));
// }
// }
// }
// .into_actor(self);
// ctx.spawn(fut);
// }
//
// let _ = tx.send(Ok(()));
// MessageResult(rx)
// }
//}
//

2
src/mcaptcha/mod.rs Normal file
View file

@ -0,0 +1,2 @@
mod defense;
pub mod mcaptcha;