use crate::{error::RequestError, handler::handle_request, RpcHandler};
use anvil_rpc::{
error::RpcError,
request::Request,
response::{Response, ResponseResult},
};
use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
use parking_lot::Mutex;
use serde::de::DeserializeOwned;
use std::{
collections::VecDeque,
fmt,
future::Future,
hash::Hash,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
/// A JSON-RPC request handler for connections that support subscriptions.
///
/// Every request is handed to [`PubSubRpcHandler::on_request`] together with a
/// [`PubSubContext`], through which the handler can register or remove subscription
/// streams for the connection the request arrived on.
#[async_trait::async_trait]
pub trait PubSubRpcHandler: Send + Sync + Clone + Unpin + 'static {
    /// The typed request this handler accepts; must implement `DeserializeOwned`.
    type Request: fmt::Debug + DeserializeOwned + Send + Sync;

    /// Key a subscription is stored under; compared with `==` when a subscription is
    /// added or removed through the context.
    type SubscriptionId: fmt::Debug + Hash + PartialEq + Eq + Send + Sync;

    /// Stream of notifications for one subscription; each item is a JSON value.
    type Subscription: Stream<Item = serde_json::Value> + Unpin + Send + Sync;

    /// Handles `request`. `cx` is a handle to the connection's subscription registry.
    async fn on_request(&self, request: Self::Request, cx: PubSubContext<Self>) -> ResponseResult;
}
/// Shared, lock-protected list of a connection's active subscriptions as `(id, stream)` pairs.
type Subscriptions<Id, Sub> = Arc<Mutex<Vec<(Id, Sub)>>>;
/// Per-connection state passed to [`PubSubRpcHandler::on_request`].
///
/// Holds the connection's subscription registry behind an `Arc`, so clones of a
/// context all see and modify the same set of subscriptions.
pub struct PubSubContext<Handler>
where
    Handler: PubSubRpcHandler,
{
    /// Subscriptions registered on this connection, keyed by id.
    subscriptions: Subscriptions<Handler::SubscriptionId, Handler::Subscription>,
}
impl<Handler> PubSubContext<Handler>
where
    Handler: PubSubRpcHandler,
{
    /// Registers `subscription` under `id`.
    ///
    /// If a subscription with an equal id is already registered it is taken out of
    /// the registry and returned (`None` otherwise); the new subscription is then
    /// appended to the registry.
    pub fn add_subscription(
        &self,
        id: Handler::SubscriptionId,
        subscription: Handler::Subscription,
    ) -> Option<Handler::Subscription> {
        let mut registry = self.subscriptions.lock();
        let replaced = match registry.iter().position(|(existing, _)| id == *existing) {
            Some(slot) => {
                trace!(target: "rpc", ?id, "removed subscription");
                // Order of the registry is not significant, so O(1) removal is fine.
                Some(registry.swap_remove(slot).1)
            }
            None => None,
        };
        trace!(target: "rpc", ?id, "added subscription");
        registry.push((id, subscription));
        replaced
    }

    /// Removes and returns the subscription registered under `id`, if any.
    pub fn remove_subscription(
        &self,
        id: &Handler::SubscriptionId,
    ) -> Option<Handler::Subscription> {
        let mut registry = self.subscriptions.lock();
        let slot = registry.iter().position(|(existing, _)| id == existing)?;
        trace!(target: "rpc", ?id, "removed subscription");
        Some(registry.swap_remove(slot).1)
    }
}
impl<Handler: PubSubRpcHandler> Clone for PubSubContext<Handler> {
fn clone(&self) -> Self {
Self { subscriptions: Arc::clone(&self.subscriptions) }
}
}
impl<Handler: PubSubRpcHandler> Default for PubSubContext<Handler> {
fn default() -> Self {
Self { subscriptions: Arc::new(Mutex::new(Vec::new())) }
}
}
/// Adapter that exposes a [`PubSubRpcHandler`] through the plain [`RpcHandler`]
/// interface by carrying the connection's [`PubSubContext`] alongside it.
struct ContextAwareHandler<Handler>
where
    Handler: PubSubRpcHandler,
{
    /// The wrapped pub/sub handler.
    handler: Handler,
    /// Context passed to every `on_request` call of `handler`.
    context: PubSubContext<Handler>,
}
impl<Handler: PubSubRpcHandler> Clone for ContextAwareHandler<Handler> {
fn clone(&self) -> Self {
Self { handler: self.handler.clone(), context: self.context.clone() }
}
}
#[async_trait::async_trait]
impl<Handler> RpcHandler for ContextAwareHandler<Handler>
where
    Handler: PubSubRpcHandler,
{
    type Request = Handler::Request;

    /// Forwards `request` to the wrapped handler along with a handle to the
    /// connection's subscription context.
    async fn on_request(&self, request: Self::Request) -> ResponseResult {
        let context = self.context.clone();
        self.handler.on_request(request, context).await
    }
}
/// One pub/sub client connection, driven to completion by polling it as a [`Future`].
///
/// Incoming requests are dispatched to `handler`; responses and subscription
/// notifications are serialized to JSON text and queued for writing to `connection`.
pub struct PubSubConnection<Handler, Connection>
where
    Handler: PubSubRpcHandler,
{
    /// Handler invoked for every incoming request.
    handler: Handler,
    /// Subscription registry shared with the handler.
    context: PubSubContext<Handler>,
    /// Underlying transport; the `Future` impl requires it to be a `Sink<String>` and a
    /// `Stream` of decoded requests.
    connection: Connection,
    /// Request futures that have not produced their response yet.
    processing: Vec<Pin<Box<dyn Future<Output = Response> + Send>>>,
    /// Serialized messages waiting to be handed to `connection`.
    pending: VecDeque<String>,
}
impl<Handler: PubSubRpcHandler, Connection> PubSubConnection<Handler, Connection> {
pub fn new(connection: Connection, handler: Handler) -> Self {
Self {
connection,
handler,
context: Default::default(),
pending: Default::default(),
processing: Default::default(),
}
}
fn compat_helper(&self) -> ContextAwareHandler<Handler> {
ContextAwareHandler { handler: self.handler.clone(), context: self.context.clone() }
}
fn process_request(&mut self, req: serde_json::Result<Request>) {
let handler = self.compat_helper();
self.processing.push(Box::pin(async move {
match req {
Ok(req) => handle_request(req, handler)
.await
.unwrap_or_else(|| Response::error(RpcError::invalid_request())),
Err(err) => {
error!(target: "rpc", ?err, "invalid request");
Response::error(RpcError::invalid_request())
}
}
}));
}
}
impl<Handler, Connection> Future for PubSubConnection<Handler, Connection>
where
Handler: PubSubRpcHandler,
Connection: Sink<String> + Stream<Item = Result<Option<Request>, RequestError>> + Unpin,
<Connection as Sink<String>>::Error: fmt::Debug,
{
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let pin = self.get_mut();
loop {
while matches!(pin.connection.poll_ready_unpin(cx), Poll::Ready(Ok(()))) {
if let Some(msg) = pin.pending.pop_front() {
if let Err(err) = pin.connection.start_send_unpin(msg) {
error!(target: "rpc", ?err, "Failed to send message");
}
} else {
break
}
}
if let Poll::Ready(Err(err)) = pin.connection.poll_flush_unpin(cx) {
trace!(target: "rpc", ?err, "websocket err");
return Poll::Ready(())
}
loop {
match pin.connection.poll_next_unpin(cx) {
Poll::Ready(Some(req)) => match req {
Ok(Some(req)) => {
pin.process_request(Ok(req));
}
Err(err) => match err {
RequestError::Axum(err) => {
trace!(target: "rpc", ?err, "client disconnected");
return Poll::Ready(())
}
RequestError::Io(err) => {
trace!(target: "rpc", ?err, "client disconnected");
return Poll::Ready(())
}
RequestError::Serde(err) => {
pin.process_request(Err(err));
}
RequestError::Disconnect => {
trace!(target: "rpc", "client disconnected");
return Poll::Ready(())
}
},
_ => {}
},
Poll::Ready(None) => {
trace!(target: "rpc", "socket connection finished");
return Poll::Ready(())
}
Poll::Pending => break,
}
}
let mut progress = false;
for n in (0..pin.processing.len()).rev() {
let mut req = pin.processing.swap_remove(n);
match req.poll_unpin(cx) {
Poll::Ready(resp) => {
if let Ok(text) = serde_json::to_string(&resp) {
pin.pending.push_back(text);
progress = true;
}
}
Poll::Pending => pin.processing.push(req),
}
}
{
let mut subscriptions = pin.context.subscriptions.lock();
'outer: for n in (0..subscriptions.len()).rev() {
let (id, mut sub) = subscriptions.swap_remove(n);
'inner: loop {
match sub.poll_next_unpin(cx) {
Poll::Ready(Some(res)) => {
if let Ok(text) = serde_json::to_string(&res) {
pin.pending.push_back(text);
progress = true;
}
}
Poll::Ready(None) => continue 'outer,
Poll::Pending => break 'inner,
}
}
subscriptions.push((id, sub));
}
}
if !progress {
return Poll::Pending
}
}
}
}