// dcache/src/store/mod.rs — 522 lines, 16 KiB, Rust

/*
* mCaptcha - A proof of work based DoS protection system
* Copyright © 2023 Aravinth Manivannan <realravinth@batsense.net>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::fmt::Debug;
use std::io::Cursor;
use std::ops::RangeBounds;
use std::sync::Arc;
use std::sync::Mutex;
use libmcaptcha::AddVisitorResult;
use libmcaptcha::MCaptcha;
use openraft::async_trait::async_trait;
use openraft::storage::LogState;
use openraft::storage::Snapshot;
use openraft::AnyError;
use openraft::BasicNode;
use openraft::Entry;
use openraft::EntryPayload;
use openraft::ErrorSubject;
use openraft::ErrorVerb;
use openraft::LogId;
use openraft::RaftLogReader;
use openraft::RaftSnapshotBuilder;
use openraft::RaftStorage;
use openraft::RaftTypeConfig;
use openraft::SnapshotMeta;
use openraft::StorageError;
use openraft::StorageIOError;
use openraft::StoredMembership;
use openraft::Vote;
use serde::Deserialize;
use serde::Serialize;
use tokio::sync::RwLock;
use crate::DcacheNodeId;
use crate::DcacheTypeConfig;
use libmcaptcha::cache::messages::{CachePoW, CacheResult, DeleteCaptchaResult, DeletePoW};
use libmcaptcha::master::messages::{
AddSite as AddCaptcha, AddVisitor, GetInternalData, RemoveCaptcha, Rename as RenameCaptcha,
SetInternalData,
};
use libmcaptcha::{master::embedded::master::Master as EmbeddedMaster, system::System, HashCache};
pub mod system;
/// A command replicated through the Raft log and applied to
/// [`DcacheStateMachine`] by `apply_to_state_machine`.
///
/// Each variant wraps the corresponding `libmcaptcha` message type.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum DcacheRequest {
    // master: operations handled by the state machine's `counter`
    /// Records a visitor; answered with `DcacheResponse::AddVisitorResult`.
    AddVisitor(AddVisitor),
    /// Registers a captcha via `counter.add_captcha`.
    AddCaptcha(AddCaptcha),
    /// Renames a captcha via `counter.rename`.
    RenameCaptcha(RenameCaptcha),
    /// Removes a captcha via `counter.rm_captcha`.
    RemoveCaptcha(RemoveCaptcha),
    // cache: operations handled by the state machine's `results`
    /// Stores a PoW config via `results.cache_pow`.
    CachePoW(CachePoW),
    /// Deletes a PoW config via `results.remove_pow_config`.
    DeletePoW(DeletePoW),
    /// Stores a captcha result via `results.cache_result`.
    CacheResult(CacheResult),
    /// Deletes a captcha result via `results.remove_cache_result`.
    DeleteCaptchaResult(DeleteCaptchaResult),
}
/// The result of applying one log entry to the state machine.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum DcacheResponse {
    /// Returned for `DcacheRequest::AddVisitor`; carries whatever
    /// `counter.add_visitor` produced.
    AddVisitorResult(Option<AddVisitorResult>),
    /// Returned for every other request (AddCaptcha, RenameCaptcha,
    /// RemoveCaptcha, CachePoW, DeletePoW, CacheResult, DeleteCaptchaResult)
    /// and for blank and membership entries.
    Empty,
}
/// A snapshot held in memory by [`DcacheStore`].
#[derive(Debug)]
pub struct DcacheSnapshot {
    /// openraft metadata (last included log id, membership, snapshot id).
    pub meta: SnapshotMeta<DcacheNodeId, BasicNode>,
    /// The serialized state machine. `build_snapshot` produces these bytes
    /// with `serde_json`, and `install_snapshot` decodes them the same way.
    pub data: Vec<u8>,
}
/// The live Raft state machine: Raft bookkeeping plus the application data
/// that committed [`DcacheRequest`]s mutate.
pub struct DcacheStateMachine {
    /// Id of the most recently applied log entry, if any.
    pub last_applied_log: Option<LogId<DcacheNodeId>>,
    /// The membership config from the last applied membership entry.
    pub last_membership: StoredMembership<DcacheNodeId, BasicNode>,
    /// Captcha registry and visitor counters (target of the "master" requests).
    pub counter: crate::mcaptcha::mcaptcha::Manager,
    /// PoW config and captcha result cache (target of the "cache" requests).
    pub results: crate::mcaptcha::cache::HashCache,
}
/// Serializable mirror of [`DcacheStateMachine`], used as the snapshot payload.
///
/// Field order determines the order of keys in the serialized snapshot.
#[derive(Serialize, Deserialize, Clone)]
struct PersistableStateMachine {
    last_applied_log: Option<LogId<DcacheNodeId>>,
    last_membership: StoredMembership<DcacheNodeId, BasicNode>,
    counter: crate::mcaptcha::mcaptcha::Manager,
    results: crate::mcaptcha::cache::HashCache,
}
impl PersistableStateMachine {
    /// Captures a serializable copy of `m`.
    ///
    /// The application stores (`counter`, `results`) are cloned first, then the
    /// Raft bookkeeping is copied.
    async fn from_statemachine(m: &DcacheStateMachine) -> Self {
        Self {
            counter: m.counter.clone(),
            results: m.results.clone(),
            last_applied_log: m.last_applied_log,
            last_membership: m.last_membership.clone(),
        }
    }

    /// Turns a decoded snapshot back into a live [`DcacheStateMachine`].
    ///
    /// `counter` and `results` are the stores currently in use. They are passed
    /// to `clean_all_after_cold_start` on the restored stores, `counter` first
    /// and then `results`, before the new state machine is assembled.
    async fn to_statemachine(
        self,
        counter: crate::mcaptcha::mcaptcha::Manager,
        results: crate::mcaptcha::cache::HashCache,
    ) -> DcacheStateMachine {
        let Self {
            last_applied_log,
            last_membership,
            counter: restored_counter,
            results: restored_results,
        } = self;

        restored_counter.clean_all_after_cold_start(counter).await;
        restored_results.clean_all_after_cold_start(results).await;

        DcacheStateMachine {
            last_applied_log,
            last_membership,
            counter: restored_counter,
            results: restored_results,
        }
    }
}
/// In-memory storage backing the dcache Raft node.
///
/// Every piece of state sits behind its own lock. The store is used through
/// `Arc<DcacheStore>`, which implements the openraft storage traits.
pub struct DcacheStore {
    /// Id of the newest entry removed by `purge_logs_upto`, if any.
    last_purged_log_id: RwLock<Option<LogId<DcacheNodeId>>>,
    /// The Raft log, keyed by log index.
    log: RwLock<BTreeMap<u64, Entry<DcacheTypeConfig>>>,
    /// The Raft state machine.
    pub state_machine: RwLock<DcacheStateMachine>,
    /// The current granted vote.
    vote: RwLock<Option<Vote<DcacheNodeId>>>,
    /// Counter incremented per `build_snapshot` call and embedded in snapshot ids.
    snapshot_idx: Arc<Mutex<u64>>,
    /// The most recently built or installed snapshot.
    current_snapshot: RwLock<Option<DcacheSnapshot>>,
}
impl DcacheStore {
    /// Creates an empty store: no log entries, no vote, no snapshot, and a
    /// fresh state machine.
    ///
    /// The salt argument is currently not used. It is kept so existing callers
    /// keep compiling. NOTE(review): confirm whether it was meant to configure
    /// the captcha manager or the result cache.
    pub fn new(_salt: String) -> Self {
        let state_machine = RwLock::new(DcacheStateMachine {
            last_applied_log: Default::default(),
            last_membership: Default::default(),
            // NOTE(review): the meaning and units of `30` are defined by
            // `Manager::new`; confirm there.
            counter: crate::mcaptcha::mcaptcha::Manager::new(30),
            results: crate::mcaptcha::cache::HashCache::default(),
        });
        Self {
            last_purged_log_id: Default::default(),
            log: Default::default(),
            state_machine,
            vote: Default::default(),
            snapshot_idx: Default::default(),
            current_snapshot: Default::default(),
        }
    }
}
#[async_trait]
impl RaftLogReader<DcacheTypeConfig> for Arc<DcacheStore> {
    /// Reports the last purged log id and the id of the newest log entry.
    ///
    /// If the log holds no entries, the last log id falls back to the last
    /// purged id.
    async fn get_log_state(
        &mut self,
    ) -> Result<LogState<DcacheTypeConfig>, StorageError<DcacheNodeId>> {
        let log = self.log.read().await;
        // Keys are log indexes (see `append_to_log`), so the greatest key is
        // the newest entry.
        let last_in_log = log.values().next_back().map(|ent| ent.log_id);
        let last_purged = *self.last_purged_log_id.read().await;
        Ok(LogState {
            last_purged_log_id: last_purged,
            last_log_id: last_in_log.or(last_purged),
        })
    }

    /// Returns clones of all log entries whose index falls within `range`.
    async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + Send + Sync>(
        &mut self,
        range: RB,
    ) -> Result<Vec<Entry<DcacheTypeConfig>>, StorageError<DcacheNodeId>> {
        let log = self.log.read().await;
        Ok(log.range(range).map(|(_, entry)| entry.clone()).collect())
    }
}
#[async_trait]
impl RaftSnapshotBuilder<DcacheTypeConfig> for Arc<DcacheStore> {
    /// Serializes the current state machine to JSON and stores the result as
    /// the store's current snapshot. Returns a copy of it to openraft.
    #[tracing::instrument(level = "trace", skip(self))]
    async fn build_snapshot(
        &mut self,
    ) -> Result<Snapshot<DcacheTypeConfig>, StorageError<DcacheNodeId>> {
        // Read everything needed from the state machine under one read lock.
        // The lock is released when this block ends.
        let (data, last_applied_log, last_membership) = {
            let sm = self.state_machine.read().await;
            let persistable = PersistableStateMachine::from_statemachine(&sm).await;
            let bytes = serde_json::to_vec(&persistable).map_err(|e| {
                StorageIOError::new(
                    ErrorSubject::StateMachine,
                    ErrorVerb::Read,
                    AnyError::new(&e),
                )
            })?;
            (bytes, sm.last_applied_log, sm.last_membership.clone())
        };

        // Per-store counter, so every snapshot id this store builds is distinct.
        let snapshot_idx = {
            let mut counter = self.snapshot_idx.lock().unwrap();
            *counter += 1;
            *counter
        };

        let snapshot_id = match last_applied_log {
            Some(last) => format!("{}-{}-{}", last.leader_id, last.index, snapshot_idx),
            None => format!("--{}", snapshot_idx),
        };

        let meta = SnapshotMeta {
            last_log_id: last_applied_log,
            last_membership,
            snapshot_id,
        };

        // One copy is kept for `get_current_snapshot`; the other goes to the caller.
        let retained = DcacheSnapshot {
            meta: meta.clone(),
            data: data.clone(),
        };
        {
            let mut slot = self.current_snapshot.write().await;
            *slot = Some(retained);
        }

        Ok(Snapshot {
            meta,
            snapshot: Box::new(Cursor::new(data)),
        })
    }
}
#[async_trait]
impl RaftStorage<DcacheTypeConfig> for Arc<DcacheStore> {
    type LogReader = Self;
    type SnapshotBuilder = Self;

    /// Records the vote granted by this node in `DcacheStore::vote` (memory only).
    #[tracing::instrument(level = "trace", skip(self))]
    async fn save_vote(
        &mut self,
        vote: &Vote<DcacheNodeId>,
    ) -> Result<(), StorageError<DcacheNodeId>> {
        let mut v = self.vote.write().await;
        *v = Some(*vote);
        Ok(())
    }

    /// Returns the vote last stored by `save_vote`, if any.
    async fn read_vote(
        &mut self,
    ) -> Result<Option<Vote<DcacheNodeId>>, StorageError<DcacheNodeId>> {
        Ok(*self.vote.read().await)
    }

    /// Appends entries to the log, keyed by log index. An entry whose index is
    /// already present replaces the existing one.
    #[tracing::instrument(level = "trace", skip(self, entries))]
    async fn append_to_log<I>(&mut self, entries: I) -> Result<(), StorageError<DcacheNodeId>>
    where
        I: IntoIterator<Item = Entry<DcacheTypeConfig>> + Send,
    {
        let mut log = self.log.write().await;
        for entry in entries {
            log.insert(entry.log_id.index, entry);
        }
        Ok(())
    }

    /// Removes every log entry with index `>= log_id.index`.
    #[tracing::instrument(level = "debug", skip(self))]
    async fn delete_conflict_logs_since(
        &mut self,
        log_id: LogId<DcacheNodeId>,
    ) -> Result<(), StorageError<DcacheNodeId>> {
        tracing::debug!("delete_log: [{:?}, +oo)", log_id);
        let mut log = self.log.write().await;
        // `split_off` detaches every key >= the given index in one step; the
        // detached tail is the conflicting suffix, so it is simply dropped.
        drop(log.split_off(&log_id.index));
        Ok(())
    }

    /// Removes every log entry with index `<= log_id.index` and records
    /// `log_id` as the last purged id.
    ///
    /// # Panics
    /// Panics if `log_id` compares less than the previously recorded purge point.
    #[tracing::instrument(level = "debug", skip(self))]
    async fn purge_logs_upto(
        &mut self,
        log_id: LogId<DcacheNodeId>,
    ) -> Result<(), StorageError<DcacheNodeId>> {
        // Purging drops the log *prefix* up to and including `log_id`.
        tracing::debug!("delete_log: (-oo, {:?}]", log_id);
        {
            let mut ld = self.last_purged_log_id.write().await;
            assert!(*ld <= Some(log_id));
            *ld = Some(log_id);
        }
        {
            let mut log = self.log.write().await;
            let keys = log
                .range(..=log_id.index)
                .map(|(k, _v)| *k)
                .collect::<Vec<_>>();
            for key in keys {
                log.remove(&key);
            }
        }
        Ok(())
    }

    /// Returns the id of the last applied entry and the last applied membership.
    async fn last_applied_state(
        &mut self,
    ) -> Result<
        (
            Option<LogId<DcacheNodeId>>,
            StoredMembership<DcacheNodeId, BasicNode>,
        ),
        StorageError<DcacheNodeId>,
    > {
        let state_machine = self.state_machine.read().await;
        Ok((
            state_machine.last_applied_log,
            state_machine.last_membership.clone(),
        ))
    }

    /// Applies entries in order while holding a single write lock on the state
    /// machine. Returns one [`DcacheResponse`] per entry, in the same order.
    #[tracing::instrument(level = "trace", skip(self, entries))]
    async fn apply_to_state_machine(
        &mut self,
        entries: &[Entry<DcacheTypeConfig>],
    ) -> Result<Vec<DcacheResponse>, StorageError<DcacheNodeId>> {
        let mut res = Vec::with_capacity(entries.len());
        let mut sm = self.state_machine.write().await;
        for entry in entries {
            tracing::debug!(%entry.log_id, "replicate to sm");
            sm.last_applied_log = Some(entry.log_id);
            match entry.payload {
                EntryPayload::Blank => res.push(DcacheResponse::Empty),
                EntryPayload::Normal(ref req) => match req {
                    // master
                    DcacheRequest::AddVisitor(msg) => {
                        let r = sm.counter.add_visitor(msg);
                        res.push(DcacheResponse::AddVisitorResult(r));
                    }
                    DcacheRequest::AddCaptcha(msg) => {
                        sm.counter
                            .add_captcha(Arc::new((&msg.mcaptcha).into()), msg.id.clone());
                        res.push(DcacheResponse::Empty);
                    }
                    DcacheRequest::RenameCaptcha(msg) => {
                        sm.counter.rename(&msg.name, msg.rename_to.clone());
                        res.push(DcacheResponse::Empty);
                    }
                    DcacheRequest::RemoveCaptcha(msg) => {
                        sm.counter.rm_captcha(&msg.0);
                        res.push(DcacheResponse::Empty);
                    }
                    // cache
                    DcacheRequest::CachePoW(msg) => {
                        sm.results.cache_pow(msg.clone());
                        res.push(DcacheResponse::Empty);
                    }
                    DcacheRequest::DeletePoW(msg) => {
                        sm.results.remove_pow_config(&msg.0);
                        res.push(DcacheResponse::Empty);
                    }
                    DcacheRequest::CacheResult(msg) => {
                        sm.results.cache_result(msg.clone());
                        res.push(DcacheResponse::Empty);
                    }
                    DcacheRequest::DeleteCaptchaResult(msg) => {
                        sm.results.remove_cache_result(&msg.token);
                        res.push(DcacheResponse::Empty);
                    }
                },
                EntryPayload::Membership(ref mem) => {
                    sm.last_membership = StoredMembership::new(Some(entry.log_id), mem.clone());
                    res.push(DcacheResponse::Empty)
                }
            };
        }
        Ok(res)
    }

    /// Returns an empty in-memory buffer into which openraft writes incoming
    /// snapshot chunks.
    #[tracing::instrument(level = "trace", skip(self))]
    async fn begin_receiving_snapshot(
        &mut self,
    ) -> Result<Box<<DcacheTypeConfig as RaftTypeConfig>::SnapshotData>, StorageError<DcacheNodeId>>
    {
        Ok(Box::new(Cursor::new(Vec::new())))
    }

    /// Replaces the state machine with the one decoded from `snapshot` and
    /// records the snapshot as current.
    ///
    /// Decoding happens before the state-machine write lock is taken, so a
    /// malformed snapshot returns an error and leaves the current state untouched.
    #[tracing::instrument(level = "trace", skip(self, snapshot))]
    async fn install_snapshot(
        &mut self,
        meta: &SnapshotMeta<DcacheNodeId, BasicNode>,
        snapshot: Box<<DcacheTypeConfig as RaftTypeConfig>::SnapshotData>,
    ) -> Result<(), StorageError<DcacheNodeId>> {
        tracing::info!(
            { snapshot_size = snapshot.get_ref().len() },
            "decoding snapshot for installation"
        );
        let new_snapshot = DcacheSnapshot {
            meta: meta.clone(),
            data: snapshot.into_inner(),
        };
        // Update the state machine.
        {
            let updated_state_machine: PersistableStateMachine =
                serde_json::from_slice(&new_snapshot.data).map_err(|e| {
                    StorageIOError::read_snapshot(Some(new_snapshot.meta.signature()), &e)
                })?;
            let mut state_machine = self.state_machine.write().await;
            // The stores currently in use are passed to `to_statemachine`.
            let updated_state_machine = updated_state_machine
                .to_statemachine(state_machine.counter.clone(), state_machine.results.clone())
                .await;
            *state_machine = updated_state_machine;
        }
        // Update current snapshot.
        let mut current_snapshot = self.current_snapshot.write().await;
        *current_snapshot = Some(new_snapshot);
        Ok(())
    }

    /// Returns a fresh reader over the current snapshot's bytes, if any.
    #[tracing::instrument(level = "trace", skip(self))]
    async fn get_current_snapshot(
        &mut self,
    ) -> Result<Option<Snapshot<DcacheTypeConfig>>, StorageError<DcacheNodeId>> {
        match &*self.current_snapshot.read().await {
            Some(snapshot) => {
                let data = snapshot.data.clone();
                Ok(Some(Snapshot {
                    meta: snapshot.meta.clone(),
                    snapshot: Box::new(Cursor::new(data)),
                }))
            }
            None => Ok(None),
        }
    }

    /// Returns another handle (`Arc` clone) to the same store.
    async fn get_log_reader(&mut self) -> Self::LogReader {
        self.clone()
    }

    /// Returns another handle (`Arc` clone) to the same store.
    async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
        self.clone()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh, empty store for each case of the openraft storage suite.
    async fn provision_dcache_store() -> Arc<DcacheStore> {
        let salt = String::from("adsfasdfasdfadsfadfadfadfadsfasdfasdfasdfasdf");
        Arc::new(DcacheStore::new(salt))
    }

    /// Runs openraft's generic storage conformance suite against `DcacheStore`.
    #[test]
    fn test_dcache_store() {
        openraft::testing::Suite::test_all(provision_dcache_store).unwrap()
    }
}