Aravinth Manivannan
77d4720e7d
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
246 lines
7.6 KiB
Rust
246 lines
7.6 KiB
Rust
use std::collections::BTreeMap;
|
|
use std::collections::BTreeSet;
|
|
use std::collections::HashSet;
|
|
/*
|
|
* mCaptcha - A proof of work based DoS protection system
|
|
* Copyright © 2023 Aravinth Manivannan <realravinth@batsense.net>
|
|
*
|
|
* This program is free software: you can redistribute it and/or modify
|
|
* it under the terms of the GNU Affero General Public License as
|
|
* published by the Free Software Foundation, either version 3 of the
|
|
* License, or (at your option) any later version.
|
|
*
|
|
* This program is distributed in the hope that it will be useful,
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
* GNU Affero General Public License for more details.
|
|
*
|
|
* You should have received a copy of the GNU Affero General Public License
|
|
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
*/
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::sync::RwLock;
|
|
use std::time::Duration;
|
|
use std::time::Instant;
|
|
|
|
use futures_util::{future, pin_mut, StreamExt};
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
|
|
|
|
use tokio::sync::mpsc;
|
|
|
|
use async_trait::async_trait;
|
|
use openraft::error::InstallSnapshotError;
|
|
use openraft::error::NetworkError;
|
|
use openraft::error::RPCError;
|
|
use openraft::error::RaftError;
|
|
use openraft::error::RemoteError;
|
|
use openraft::raft::AppendEntriesRequest;
|
|
use openraft::raft::AppendEntriesResponse;
|
|
use openraft::raft::InstallSnapshotRequest;
|
|
use openraft::raft::InstallSnapshotResponse;
|
|
use openraft::raft::VoteRequest;
|
|
use openraft::raft::VoteResponse;
|
|
use openraft::BasicNode;
|
|
use openraft::RaftNetwork;
|
|
use openraft::RaftNetworkFactory;
|
|
use reqwest::Client;
|
|
use serde::de::DeserializeOwned;
|
|
use serde::Serialize;
|
|
use tokio::sync::mpsc::Sender;
|
|
|
|
use super::management::HealthStatus;
|
|
use super::raft::{RaftMessage, RaftRes};
|
|
use crate::store::DcacheRequest;
|
|
use crate::store::DcacheResponse;
|
|
use crate::DcacheNodeId;
|
|
use crate::DcacheTypeConfig;
|
|
|
|
#[derive(Clone)]
|
|
pub struct DcacheNetwork {
|
|
pub signal: Sender<HealthStatus>,
|
|
pub client: Client,
|
|
}
|
|
|
|
impl DcacheNetwork {
|
|
pub fn new(signal: Sender<HealthStatus>, client: Client) -> Self {
|
|
Self { signal, client }
|
|
}
|
|
pub async fn send_rpc<Req, Resp, Err>(
|
|
&self,
|
|
target: DcacheNodeId,
|
|
target_node: &BasicNode,
|
|
uri: &str,
|
|
req: Req,
|
|
) -> Result<Resp, RPCError<DcacheNodeId, BasicNode, Err>>
|
|
where
|
|
Req: Serialize,
|
|
Err: std::error::Error + DeserializeOwned,
|
|
Resp: DeserializeOwned,
|
|
{
|
|
let addr = &target_node.addr;
|
|
|
|
let url = format!("http://{}/{}", addr, uri);
|
|
|
|
tracing::debug!("send_rpc to url: {}", url);
|
|
|
|
let resp = match self.client.post(url).json(&req).send().await {
|
|
Ok(resp) => Ok(resp),
|
|
Err(e) => {
|
|
self.signal.send(HealthStatus::Down(target)).await;
|
|
Err(RPCError::Network(NetworkError::new(&e)))
|
|
}
|
|
}?;
|
|
|
|
tracing::debug!("client.post() is sent");
|
|
|
|
let res: Result<Resp, Err> = resp
|
|
.json()
|
|
.await
|
|
.map_err(|e| RPCError::Network(NetworkError::new(&e)))?;
|
|
|
|
let res = res.map_err(|e| RPCError::RemoteError(RemoteError::new(target, e)));
|
|
if res.is_ok() {
|
|
let signal2 = self.signal.clone();
|
|
let fut = async move {
|
|
let _ = signal2.send(HealthStatus::Healthy(target)).await;
|
|
};
|
|
tokio::spawn(fut);
|
|
}
|
|
res
|
|
}
|
|
}
|
|
|
|
// NOTE: The factory is implemented on `Arc<DcacheNetwork>` so that each connection it creates
// can keep a cheap shared handle (`owner`) to the network's HTTP client and health channel.
|
|
#[async_trait]
|
|
impl RaftNetworkFactory<DcacheTypeConfig> for Arc<DcacheNetwork> {
|
|
type Network = DcacheNetworkConnection;
|
|
|
|
async fn new_client(&mut self, target: DcacheNodeId, node: &BasicNode) -> Self::Network {
|
|
let addr = &node.addr;
|
|
let url = format!("ws://{}/{}", addr, "ws/write");
|
|
|
|
let (write, rx) = mpsc::channel(30);
|
|
let (tx, read) = mpsc::channel(30);
|
|
let ws_client = WSClient::spawn(rx, tx, url).await;
|
|
|
|
DcacheNetworkConnection {
|
|
owner: self.clone(),
|
|
target,
|
|
target_node: node.clone(),
|
|
// ws_client,
|
|
read,
|
|
write,
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct DcacheNetworkConnection {
|
|
owner: Arc<DcacheNetwork>,
|
|
target: DcacheNodeId,
|
|
target_node: BasicNode,
|
|
// ws_client: WSClient,
|
|
write: mpsc::Sender<RaftMessage>,
|
|
read: mpsc::Receiver<RaftRes>,
|
|
}
|
|
|
|
#[async_trait]
|
|
impl RaftNetwork<DcacheTypeConfig> for DcacheNetworkConnection {
|
|
async fn send_append_entries(
|
|
&mut self,
|
|
req: AppendEntriesRequest<DcacheTypeConfig>,
|
|
) -> Result<
|
|
AppendEntriesResponse<DcacheNodeId>,
|
|
RPCError<DcacheNodeId, BasicNode, RaftError<DcacheNodeId>>,
|
|
> {
|
|
self.write.send(RaftMessage::Append(req)).await.unwrap();
|
|
match self.read.recv().await.unwrap() {
|
|
RaftRes::AppendRes(res) => {
|
|
res.map_err(|e| RPCError::RemoteError(RemoteError::new(self.target, e)))
|
|
}
|
|
_ => unimplemented!(),
|
|
}
|
|
// self.owner
|
|
// .send_rpc(self.target, &self.target_node, "raft-append", req)
|
|
// .await
|
|
}
|
|
|
|
async fn send_install_snapshot(
|
|
&mut self,
|
|
req: InstallSnapshotRequest<DcacheTypeConfig>,
|
|
) -> Result<
|
|
InstallSnapshotResponse<DcacheNodeId>,
|
|
RPCError<DcacheNodeId, BasicNode, RaftError<DcacheNodeId, InstallSnapshotError>>,
|
|
> {
|
|
// self.owner
|
|
// .send_rpc(self.target, &self.target_node, "raft-snapshot", req)
|
|
// .await
|
|
self.write.send(RaftMessage::Snapshot(req)).await.unwrap();
|
|
match self.read.recv().await.unwrap() {
|
|
RaftRes::SnapshotRes(res) => {
|
|
res.map_err(|e| RPCError::RemoteError(RemoteError::new(self.target, e)))
|
|
}
|
|
_ => unimplemented!(),
|
|
}
|
|
}
|
|
|
|
async fn send_vote(
|
|
&mut self,
|
|
req: VoteRequest<DcacheNodeId>,
|
|
) -> Result<
|
|
VoteResponse<DcacheNodeId>,
|
|
RPCError<DcacheNodeId, BasicNode, RaftError<DcacheNodeId>>,
|
|
> {
|
|
// self.owner
|
|
// .send_rpc(self.target, &self.target_node, "raft-vote", req)
|
|
// .await
|
|
self.write
|
|
.send(RaftMessage::VoteRequest(req))
|
|
.await
|
|
.unwrap();
|
|
match self.read.recv().await.unwrap() {
|
|
RaftRes::VoteRes(res) => {
|
|
res.map_err(|e| RPCError::RemoteError(RemoteError::new(self.target, e)))
|
|
}
|
|
_ => unimplemented!(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Namespace for starting the websocket worker that carries Raft traffic to one peer.
#[derive(Debug, Clone, Copy, Default)]
pub struct WSClient;
|
|
|
|
impl WSClient {
|
|
pub async fn spawn(
|
|
mut rx: mpsc::Receiver<RaftMessage>,
|
|
tx: mpsc::Sender<RaftRes>,
|
|
url: String,
|
|
) {
|
|
use futures_util::SinkExt;
|
|
|
|
let (ws_stream, _) = connect_async(url).await.expect("Failed to connect");
|
|
println!("WebSocket handshake has been successfully completed");
|
|
|
|
let (mut write, mut read) = ws_stream.split();
|
|
|
|
let fut = async move {
|
|
while let Some(msg) = rx.recv().await {
|
|
write
|
|
.send(Message::Text(serde_json::to_string(&msg).unwrap()))
|
|
.await
|
|
.unwrap();
|
|
|
|
match read.next().await.unwrap().unwrap() {
|
|
Message::Text(msg) => {
|
|
tx.send(serde_json::from_str(&msg).unwrap()).await;
|
|
}
|
|
_ => (),
|
|
}
|
|
}
|
|
};
|
|
|
|
tokio::spawn(fut);
|
|
}
|
|
}
|