dcache/src/client.rs

196 lines
6.0 KiB
Rust

/*
* mCaptcha - A proof of work based DoS protection system
* Copyright © 2023 Aravinth Manivannan <realravinth@batsense.net>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
use std::collections::BTreeSet;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
use openraft::error::ForwardToLeader;
use openraft::error::NetworkError;
use openraft::error::RPCError;
use openraft::error::RemoteError;
use openraft::BasicNode;
use openraft::RaftMetrics;
use openraft::TryAsRef;
use reqwest::Client;
use serde::de::DeserializeOwned;
use serde::Deserialize;
use serde::Serialize;
use tokio::time::timeout;
use crate::typ;
use crate::DcacheNodeId;
use crate::DcacheRequest;
/// Placeholder payload for endpoints that take no arguments (used by `init`).
///
/// Kept as a braced struct rather than a unit struct so that serde encodes it
/// as an empty JSON object (`{}`) instead of `null`.
#[derive(Clone, Debug)]
#[derive(Serialize, Deserialize)]
pub struct Empty {}
/// HTTP client for a dcache cluster.
///
/// Keeps track of the node it currently sends requests to (ideally the Raft
/// leader). The target is retargeted when a write is redirected with
/// `ForwardToLeader`.
pub struct DcacheClient {
    /// `(node id, address)` of the node requests are currently sent to.
    /// Shared behind a mutex so the redirect logic can update it through `&self`.
    pub leader: Arc<Mutex<(DcacheNodeId, String)>>,
    /// Underlying HTTP client used for every request.
    pub inner: Client,
}
impl DcacheClient {
pub fn new(leader_id: DcacheNodeId, leader_addr: String) -> Self {
Self {
leader: Arc::new(Mutex::new((leader_id, leader_addr))),
inner: reqwest::Client::new(),
}
}
pub async fn write(
&self,
req: &DcacheRequest,
) -> Result<typ::ClientWriteResponse, typ::RPCError<typ::ClientWriteError>> {
self.send_rpc_to_leader("write", Some(req)).await
}
pub async fn read(&self, req: &String) -> Result<String, typ::RPCError> {
self.do_send_rpc_to_leader("read", Some(req)).await
}
pub async fn consistent_read(
&self,
req: &String,
) -> Result<String, typ::RPCError<typ::CheckIsLeaderError>> {
self.do_send_rpc_to_leader("consistent_read", Some(req))
.await
}
pub async fn init(&self) -> Result<(), typ::RPCError<typ::InitializeError>> {
self.do_send_rpc_to_leader("init", Some(&Empty {})).await
}
pub async fn add_learner(
&self,
req: (DcacheNodeId, String),
) -> Result<typ::ClientWriteResponse, typ::RPCError<typ::ClientWriteError>> {
self.send_rpc_to_leader("add-learner", Some(&req)).await
}
pub async fn change_membership(
&self,
req: &BTreeSet<DcacheNodeId>,
) -> Result<typ::ClientWriteResponse, typ::RPCError<typ::ClientWriteError>> {
self.send_rpc_to_leader("change-membership", Some(req))
.await
}
pub async fn metrics(&self) -> Result<RaftMetrics<DcacheNodeId, BasicNode>, typ::RPCError> {
self.do_send_rpc_to_leader("metrics", None::<&()>).await
}
async fn do_send_rpc_to_leader<Req, Resp, Err>(
&self,
uri: &str,
req: Option<&Req>,
) -> Result<Resp, typ::RPCError<Err>>
where
Req: Serialize + 'static,
Resp: Serialize + DeserializeOwned,
Err: std::error::Error + Serialize + DeserializeOwned,
{
let (leader_id, url) = {
let t = self.leader.lock().unwrap();
let target_addr = &t.1;
(t.0, format!("http://{}/{}", target_addr, uri))
};
let fu = if let Some(r) = req {
tracing::debug!(
">>> client send request to {}: {}",
url,
serde_json::to_string_pretty(&r).unwrap()
);
self.inner.post(url.clone()).json(r)
} else {
tracing::debug!(">>> client send request to {}", url,);
self.inner.get(url.clone())
}
.send();
let res = timeout(Duration::from_millis(3_000), fu).await;
let resp = match res {
Ok(x) => x.map_err(|e| RPCError::Network(NetworkError::new(&e)))?,
Err(timeout_err) => {
tracing::error!("timeout {} to url: {}", timeout_err, url);
return Err(RPCError::Network(NetworkError::new(&timeout_err)));
}
};
let res: Result<Resp, typ::RaftError<Err>> = resp
.json()
.await
.map_err(|e| RPCError::Network(NetworkError::new(&e)))?;
tracing::debug!(
"<<< client recv reply from {}: {}",
url,
serde_json::to_string_pretty(&res).unwrap()
);
res.map_err(|e| RPCError::RemoteError(RemoteError::new(leader_id, e)))
}
async fn send_rpc_to_leader<Req, Resp, Err>(
&self,
uri: &str,
req: Option<&Req>,
) -> Result<Resp, typ::RPCError<Err>>
where
Req: Serialize + 'static,
Resp: Serialize + DeserializeOwned,
Err: std::error::Error
+ Serialize
+ DeserializeOwned
+ TryAsRef<typ::ForwardToLeader>
+ Clone,
{
// Retry at most 3 times to find a valid leader.
let mut n_retry = 3;
loop {
let res: Result<Resp, typ::RPCError<Err>> = self.do_send_rpc_to_leader(uri, req).await;
let rpc_err = match res {
Ok(x) => return Ok(x),
Err(rpc_err) => rpc_err,
};
if let Some(ForwardToLeader {
leader_id: Some(leader_id),
leader_node: Some(leader_node),
}) = rpc_err.forward_to_leader()
{
// Update target to the new leader.
{
let mut t = self.leader.lock().unwrap();
*t = (*leader_id, leader_node.addr.clone());
}
n_retry -= 1;
if n_retry > 0 {
continue;
}
}
return Err(rpc_err);
}
}
}