196 lines
6.0 KiB
Rust
196 lines
6.0 KiB
Rust
/*
|
|
* mCaptcha - A proof of work based DoS protection system
|
|
* Copyright © 2023 Aravinth Manivannan <realravinth@batsense.net>
|
|
*
|
|
* This program is free software: you can redistribute it and/or modify
|
|
* it under the terms of the GNU Affero General Public License as
|
|
* published by the Free Software Foundation, either version 3 of the
|
|
* License, or (at your option) any later version.
|
|
*
|
|
* This program is distributed in the hope that it will be useful,
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
* GNU Affero General Public License for more details.
|
|
*
|
|
* You should have received a copy of the GNU Affero General Public License
|
|
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
*/
|
|
use std::collections::BTreeSet;
|
|
use std::sync::Arc;
|
|
use std::sync::Mutex;
|
|
use std::time::Duration;
|
|
|
|
use openraft::error::ForwardToLeader;
|
|
use openraft::error::NetworkError;
|
|
use openraft::error::RPCError;
|
|
use openraft::error::RemoteError;
|
|
use openraft::BasicNode;
|
|
use openraft::RaftMetrics;
|
|
use openraft::TryAsRef;
|
|
use reqwest::Client;
|
|
use serde::de::DeserializeOwned;
|
|
use serde::Deserialize;
|
|
use serde::Serialize;
|
|
use tokio::time::timeout;
|
|
|
|
use crate::typ;
|
|
use crate::DcacheNodeId;
|
|
use crate::DcacheRequest;
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct Empty {}
|
|
|
|
pub struct DcacheClient {
|
|
pub leader: Arc<Mutex<(DcacheNodeId, String)>>,
|
|
|
|
pub inner: Client,
|
|
}
|
|
|
|
impl DcacheClient {
|
|
pub fn new(leader_id: DcacheNodeId, leader_addr: String) -> Self {
|
|
Self {
|
|
leader: Arc::new(Mutex::new((leader_id, leader_addr))),
|
|
inner: reqwest::Client::new(),
|
|
}
|
|
}
|
|
|
|
pub async fn write(
|
|
&self,
|
|
req: &DcacheRequest,
|
|
) -> Result<typ::ClientWriteResponse, typ::RPCError<typ::ClientWriteError>> {
|
|
self.send_rpc_to_leader("write", Some(req)).await
|
|
}
|
|
|
|
pub async fn read(&self, req: &String) -> Result<String, typ::RPCError> {
|
|
self.do_send_rpc_to_leader("read", Some(req)).await
|
|
}
|
|
|
|
pub async fn consistent_read(
|
|
&self,
|
|
req: &String,
|
|
) -> Result<String, typ::RPCError<typ::CheckIsLeaderError>> {
|
|
self.do_send_rpc_to_leader("consistent_read", Some(req))
|
|
.await
|
|
}
|
|
|
|
pub async fn init(&self) -> Result<(), typ::RPCError<typ::InitializeError>> {
|
|
self.do_send_rpc_to_leader("init", Some(&Empty {})).await
|
|
}
|
|
|
|
pub async fn add_learner(
|
|
&self,
|
|
req: (DcacheNodeId, String),
|
|
) -> Result<typ::ClientWriteResponse, typ::RPCError<typ::ClientWriteError>> {
|
|
self.send_rpc_to_leader("add-learner", Some(&req)).await
|
|
}
|
|
|
|
pub async fn change_membership(
|
|
&self,
|
|
req: &BTreeSet<DcacheNodeId>,
|
|
) -> Result<typ::ClientWriteResponse, typ::RPCError<typ::ClientWriteError>> {
|
|
self.send_rpc_to_leader("change-membership", Some(req))
|
|
.await
|
|
}
|
|
|
|
pub async fn metrics(&self) -> Result<RaftMetrics<DcacheNodeId, BasicNode>, typ::RPCError> {
|
|
self.do_send_rpc_to_leader("metrics", None::<&()>).await
|
|
}
|
|
|
|
async fn do_send_rpc_to_leader<Req, Resp, Err>(
|
|
&self,
|
|
uri: &str,
|
|
req: Option<&Req>,
|
|
) -> Result<Resp, typ::RPCError<Err>>
|
|
where
|
|
Req: Serialize + 'static,
|
|
Resp: Serialize + DeserializeOwned,
|
|
Err: std::error::Error + Serialize + DeserializeOwned,
|
|
{
|
|
let (leader_id, url) = {
|
|
let t = self.leader.lock().unwrap();
|
|
let target_addr = &t.1;
|
|
(t.0, format!("http://{}/{}", target_addr, uri))
|
|
};
|
|
|
|
let fu = if let Some(r) = req {
|
|
tracing::debug!(
|
|
">>> client send request to {}: {}",
|
|
url,
|
|
serde_json::to_string_pretty(&r).unwrap()
|
|
);
|
|
self.inner.post(url.clone()).json(r)
|
|
} else {
|
|
tracing::debug!(">>> client send request to {}", url,);
|
|
self.inner.get(url.clone())
|
|
}
|
|
.send();
|
|
|
|
let res = timeout(Duration::from_millis(3_000), fu).await;
|
|
let resp = match res {
|
|
Ok(x) => x.map_err(|e| RPCError::Network(NetworkError::new(&e)))?,
|
|
Err(timeout_err) => {
|
|
tracing::error!("timeout {} to url: {}", timeout_err, url);
|
|
return Err(RPCError::Network(NetworkError::new(&timeout_err)));
|
|
}
|
|
};
|
|
|
|
let res: Result<Resp, typ::RaftError<Err>> = resp
|
|
.json()
|
|
.await
|
|
.map_err(|e| RPCError::Network(NetworkError::new(&e)))?;
|
|
tracing::debug!(
|
|
"<<< client recv reply from {}: {}",
|
|
url,
|
|
serde_json::to_string_pretty(&res).unwrap()
|
|
);
|
|
|
|
res.map_err(|e| RPCError::RemoteError(RemoteError::new(leader_id, e)))
|
|
}
|
|
|
|
async fn send_rpc_to_leader<Req, Resp, Err>(
|
|
&self,
|
|
uri: &str,
|
|
req: Option<&Req>,
|
|
) -> Result<Resp, typ::RPCError<Err>>
|
|
where
|
|
Req: Serialize + 'static,
|
|
Resp: Serialize + DeserializeOwned,
|
|
Err: std::error::Error
|
|
+ Serialize
|
|
+ DeserializeOwned
|
|
+ TryAsRef<typ::ForwardToLeader>
|
|
+ Clone,
|
|
{
|
|
// Retry at most 3 times to find a valid leader.
|
|
let mut n_retry = 3;
|
|
|
|
loop {
|
|
let res: Result<Resp, typ::RPCError<Err>> = self.do_send_rpc_to_leader(uri, req).await;
|
|
|
|
let rpc_err = match res {
|
|
Ok(x) => return Ok(x),
|
|
Err(rpc_err) => rpc_err,
|
|
};
|
|
|
|
if let Some(ForwardToLeader {
|
|
leader_id: Some(leader_id),
|
|
leader_node: Some(leader_node),
|
|
}) = rpc_err.forward_to_leader()
|
|
{
|
|
// Update target to the new leader.
|
|
{
|
|
let mut t = self.leader.lock().unwrap();
|
|
*t = (*leader_id, leader_node.addr.clone());
|
|
}
|
|
|
|
n_retry -= 1;
|
|
if n_retry > 0 {
|
|
continue;
|
|
}
|
|
}
|
|
|
|
return Err(rpc_err);
|
|
}
|
|
}
|
|
}
|