anvil_server/
pubsub.rs

1use crate::{error::RequestError, handler::handle_request, RpcHandler};
2use anvil_rpc::{
3    error::RpcError,
4    request::Request,
5    response::{Response, ResponseResult},
6};
7
8use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
9use parking_lot::Mutex;
10use serde::de::DeserializeOwned;
11use std::{
12    collections::VecDeque,
13    fmt,
14    future::Future,
15    hash::Hash,
16    pin::Pin,
17    sync::Arc,
18    task::{Context, Poll},
19};
20
21/// The general purpose trait for handling RPC requests and subscriptions
22#[async_trait::async_trait]
23pub trait PubSubRpcHandler: Clone + Send + Sync + Unpin + 'static {
24    /// The request type to expect
25    type Request: DeserializeOwned + Send + Sync + fmt::Debug;
26    /// The identifier to use for subscriptions
27    type SubscriptionId: Hash + PartialEq + Eq + Send + Sync + fmt::Debug;
28    /// The subscription type this handle may create
29    type Subscription: Stream<Item = serde_json::Value> + Send + Sync + Unpin;
30
31    /// Invoked when the request was received
32    async fn on_request(&self, request: Self::Request, cx: PubSubContext<Self>) -> ResponseResult;
33}
34
35type Subscriptions<SubscriptionId, Subscription> = Arc<Mutex<Vec<(SubscriptionId, Subscription)>>>;
36
37/// Contains additional context and tracks subscriptions
38pub struct PubSubContext<Handler: PubSubRpcHandler> {
39    /// all active subscriptions `id -> Stream`
40    subscriptions: Subscriptions<Handler::SubscriptionId, Handler::Subscription>,
41}
42
43impl<Handler: PubSubRpcHandler> PubSubContext<Handler> {
44    /// Adds new active subscription
45    ///
46    /// Returns the previous subscription, if any
47    pub fn add_subscription(
48        &self,
49        id: Handler::SubscriptionId,
50        subscription: Handler::Subscription,
51    ) -> Option<Handler::Subscription> {
52        let mut subscriptions = self.subscriptions.lock();
53        let mut removed = None;
54        if let Some(idx) = subscriptions.iter().position(|(i, _)| id == *i) {
55            trace!(target: "rpc", ?id,  "removed subscription");
56            removed = Some(subscriptions.swap_remove(idx).1);
57        }
58        trace!(target: "rpc", ?id,  "added subscription");
59        subscriptions.push((id, subscription));
60        removed
61    }
62
63    /// Removes an existing subscription
64    pub fn remove_subscription(
65        &self,
66        id: &Handler::SubscriptionId,
67    ) -> Option<Handler::Subscription> {
68        let mut subscriptions = self.subscriptions.lock();
69        if let Some(idx) = subscriptions.iter().position(|(i, _)| id == i) {
70            trace!(target: "rpc", ?id,  "removed subscription");
71            return Some(subscriptions.swap_remove(idx).1)
72        }
73        None
74    }
75}
76
77impl<Handler: PubSubRpcHandler> Clone for PubSubContext<Handler> {
78    fn clone(&self) -> Self {
79        Self { subscriptions: Arc::clone(&self.subscriptions) }
80    }
81}
82
83impl<Handler: PubSubRpcHandler> Default for PubSubContext<Handler> {
84    fn default() -> Self {
85        Self { subscriptions: Arc::new(Mutex::new(Vec::new())) }
86    }
87}
88
89/// A compatibility helper type to use common `RpcHandler` functions
90struct ContextAwareHandler<Handler: PubSubRpcHandler> {
91    handler: Handler,
92    context: PubSubContext<Handler>,
93}
94
95impl<Handler: PubSubRpcHandler> Clone for ContextAwareHandler<Handler> {
96    fn clone(&self) -> Self {
97        Self { handler: self.handler.clone(), context: self.context.clone() }
98    }
99}
100
101#[async_trait::async_trait]
102impl<Handler: PubSubRpcHandler> RpcHandler for ContextAwareHandler<Handler> {
103    type Request = Handler::Request;
104
105    async fn on_request(&self, request: Self::Request) -> ResponseResult {
106        self.handler.on_request(request, self.context.clone()).await
107    }
108}
109
110/// Represents a connection to a client via websocket
111///
112/// Contains the state for the entire connection
113pub struct PubSubConnection<Handler: PubSubRpcHandler, Connection> {
114    /// the handler for the websocket connection
115    handler: Handler,
116    /// contains all the subscription related context
117    context: PubSubContext<Handler>,
118    /// The established connection
119    connection: Connection,
120    /// currently in progress requests
121    processing: Vec<Pin<Box<dyn Future<Output = Response> + Send>>>,
122    /// pending messages to send
123    pending: VecDeque<String>,
124}
125
126impl<Handler: PubSubRpcHandler, Connection> PubSubConnection<Handler, Connection> {
127    pub fn new(connection: Connection, handler: Handler) -> Self {
128        Self {
129            connection,
130            handler,
131            context: Default::default(),
132            pending: Default::default(),
133            processing: Default::default(),
134        }
135    }
136
137    /// Returns a compatibility `RpcHandler`
138    fn compat_helper(&self) -> ContextAwareHandler<Handler> {
139        ContextAwareHandler { handler: self.handler.clone(), context: self.context.clone() }
140    }
141
142    fn process_request(&mut self, req: serde_json::Result<Request>) {
143        let handler = self.compat_helper();
144        self.processing.push(Box::pin(async move {
145            match req {
146                Ok(req) => handle_request(req, handler)
147                    .await
148                    .unwrap_or_else(|| Response::error(RpcError::invalid_request())),
149                Err(err) => {
150                    error!(target: "rpc", ?err, "invalid request");
151                    Response::error(RpcError::invalid_request())
152                }
153            }
154        }));
155    }
156}
157
158impl<Handler, Connection> Future for PubSubConnection<Handler, Connection>
159where
160    Handler: PubSubRpcHandler,
161    Connection: Sink<String> + Stream<Item = Result<Option<Request>, RequestError>> + Unpin,
162    <Connection as Sink<String>>::Error: fmt::Debug,
163{
164    type Output = ();
165
166    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
167        let pin = self.get_mut();
168        loop {
169            // drive the websocket
170            while matches!(pin.connection.poll_ready_unpin(cx), Poll::Ready(Ok(()))) {
171                // only start sending if socket is ready
172                if let Some(msg) = pin.pending.pop_front() {
173                    if let Err(err) = pin.connection.start_send_unpin(msg) {
174                        error!(target: "rpc", ?err, "Failed to send message");
175                    }
176                } else {
177                    break
178                }
179            }
180
181            // Ensure any pending messages are flushed
182            // this needs to be called manually for tungsenite websocket: <https://github.com/foundry-rs/foundry/issues/6345>
183            if let Poll::Ready(Err(err)) = pin.connection.poll_flush_unpin(cx) {
184                trace!(target: "rpc", ?err, "websocket err");
185                // close the connection
186                return Poll::Ready(())
187            }
188
189            loop {
190                match pin.connection.poll_next_unpin(cx) {
191                    Poll::Ready(Some(req)) => match req {
192                        Ok(Some(req)) => {
193                            pin.process_request(Ok(req));
194                        }
195                        Err(err) => match err {
196                            RequestError::Axum(err) => {
197                                trace!(target: "rpc", ?err, "client disconnected");
198                                return Poll::Ready(())
199                            }
200                            RequestError::Io(err) => {
201                                trace!(target: "rpc", ?err, "client disconnected");
202                                return Poll::Ready(())
203                            }
204                            RequestError::Serde(err) => {
205                                pin.process_request(Err(err));
206                            }
207                            RequestError::Disconnect => {
208                                trace!(target: "rpc", "client disconnected");
209                                return Poll::Ready(())
210                            }
211                        },
212                        _ => {}
213                    },
214                    Poll::Ready(None) => {
215                        trace!(target: "rpc", "socket connection finished");
216                        return Poll::Ready(())
217                    }
218                    Poll::Pending => break,
219                }
220            }
221
222            let mut progress = false;
223            for n in (0..pin.processing.len()).rev() {
224                let mut req = pin.processing.swap_remove(n);
225                match req.poll_unpin(cx) {
226                    Poll::Ready(resp) => {
227                        if let Ok(text) = serde_json::to_string(&resp) {
228                            pin.pending.push_back(text);
229                            progress = true;
230                        }
231                    }
232                    Poll::Pending => pin.processing.push(req),
233                }
234            }
235
236            {
237                // process subscription events
238                let mut subscriptions = pin.context.subscriptions.lock();
239                'outer: for n in (0..subscriptions.len()).rev() {
240                    let (id, mut sub) = subscriptions.swap_remove(n);
241                    'inner: loop {
242                        match sub.poll_next_unpin(cx) {
243                            Poll::Ready(Some(res)) => {
244                                if let Ok(text) = serde_json::to_string(&res) {
245                                    pin.pending.push_back(text);
246                                    progress = true;
247                                }
248                            }
249                            Poll::Ready(None) => continue 'outer,
250                            Poll::Pending => break 'inner,
251                        }
252                    }
253
254                    subscriptions.push((id, sub));
255                }
256            }
257
258            if !progress {
259                return Poll::Pending
260            }
261        }
262    }
263}