Skip to main content

anvil_server/
pubsub.rs

1use crate::{RpcHandler, error::RequestError, handler::handle_request};
2use anvil_rpc::{
3    error::RpcError,
4    request::Request,
5    response::{Response, ResponseResult},
6};
7
8use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
9use parking_lot::Mutex;
10use serde::de::DeserializeOwned;
11use std::{
12    collections::VecDeque,
13    fmt,
14    hash::Hash,
15    pin::Pin,
16    sync::Arc,
17    task::{Context, Poll},
18};
19
20/// The general purpose trait for handling RPC requests and subscriptions
21#[async_trait::async_trait]
22pub trait PubSubRpcHandler: Clone + Send + Sync + Unpin + 'static {
23    /// The request type to expect
24    type Request: DeserializeOwned + Send + Sync + fmt::Debug;
25    /// The identifier to use for subscriptions
26    type SubscriptionId: Hash + PartialEq + Eq + Send + Sync + fmt::Debug;
27    /// The subscription type this handle may create
28    type Subscription: Stream<Item = serde_json::Value> + Send + Sync + Unpin;
29
30    /// Invoked when the request was received
31    async fn on_request(&self, request: Self::Request, cx: PubSubContext<Self>) -> ResponseResult;
32}
33
34type Subscriptions<SubscriptionId, Subscription> = Arc<Mutex<Vec<(SubscriptionId, Subscription)>>>;
35
36/// Contains additional context and tracks subscriptions
37pub struct PubSubContext<Handler: PubSubRpcHandler> {
38    /// all active subscriptions `id -> Stream`
39    subscriptions: Subscriptions<Handler::SubscriptionId, Handler::Subscription>,
40}
41
42impl<Handler: PubSubRpcHandler> PubSubContext<Handler> {
43    /// Adds new active subscription
44    ///
45    /// Returns the previous subscription, if any
46    pub fn add_subscription(
47        &self,
48        id: Handler::SubscriptionId,
49        subscription: Handler::Subscription,
50    ) -> Option<Handler::Subscription> {
51        let mut subscriptions = self.subscriptions.lock();
52        let mut removed = None;
53        if let Some(idx) = subscriptions.iter().position(|(i, _)| id == *i) {
54            trace!(target: "rpc", ?id,  "removed subscription");
55            removed = Some(subscriptions.swap_remove(idx).1);
56        }
57        trace!(target: "rpc", ?id,  "added subscription");
58        subscriptions.push((id, subscription));
59        removed
60    }
61
62    /// Removes an existing subscription
63    pub fn remove_subscription(
64        &self,
65        id: &Handler::SubscriptionId,
66    ) -> Option<Handler::Subscription> {
67        let mut subscriptions = self.subscriptions.lock();
68        if let Some(idx) = subscriptions.iter().position(|(i, _)| id == i) {
69            trace!(target: "rpc", ?id,  "removed subscription");
70            return Some(subscriptions.swap_remove(idx).1);
71        }
72        None
73    }
74}
75
76impl<Handler: PubSubRpcHandler> Clone for PubSubContext<Handler> {
77    fn clone(&self) -> Self {
78        Self { subscriptions: Arc::clone(&self.subscriptions) }
79    }
80}
81
82impl<Handler: PubSubRpcHandler> Default for PubSubContext<Handler> {
83    fn default() -> Self {
84        Self { subscriptions: Arc::new(Mutex::new(Vec::new())) }
85    }
86}
87
88/// A compatibility helper type to use common `RpcHandler` functions
89struct ContextAwareHandler<Handler: PubSubRpcHandler> {
90    handler: Handler,
91    context: PubSubContext<Handler>,
92}
93
94impl<Handler: PubSubRpcHandler> Clone for ContextAwareHandler<Handler> {
95    fn clone(&self) -> Self {
96        Self { handler: self.handler.clone(), context: self.context.clone() }
97    }
98}
99
100#[async_trait::async_trait]
101impl<Handler: PubSubRpcHandler> RpcHandler for ContextAwareHandler<Handler> {
102    type Request = Handler::Request;
103
104    async fn on_request(&self, request: Self::Request) -> ResponseResult {
105        self.handler.on_request(request, self.context.clone()).await
106    }
107}
108
109/// Represents a connection to a client via websocket
110///
111/// Contains the state for the entire connection
112pub struct PubSubConnection<Handler: PubSubRpcHandler, Connection> {
113    /// the handler for the websocket connection
114    handler: Handler,
115    /// contains all the subscription related context
116    context: PubSubContext<Handler>,
117    /// The established connection
118    connection: Connection,
119    /// currently in progress requests
120    processing: Vec<Pin<Box<dyn Future<Output = Option<Response>> + Send>>>,
121    /// pending messages to send
122    pending: VecDeque<String>,
123}
124
125impl<Handler: PubSubRpcHandler, Connection> PubSubConnection<Handler, Connection> {
126    pub fn new(connection: Connection, handler: Handler) -> Self {
127        Self {
128            connection,
129            handler,
130            context: Default::default(),
131            pending: Default::default(),
132            processing: Default::default(),
133        }
134    }
135
136    /// Returns a compatibility `RpcHandler`
137    fn compat_helper(&self) -> ContextAwareHandler<Handler> {
138        ContextAwareHandler { handler: self.handler.clone(), context: self.context.clone() }
139    }
140
141    fn process_request(&mut self, req: serde_json::Result<Request>) {
142        let handler = self.compat_helper();
143        self.processing.push(Box::pin(async move {
144            match req {
145                Ok(req) => handle_request(req, handler).await,
146                Err(err) => {
147                    error!(target: "rpc", ?err, "invalid request");
148                    Some(Response::error(RpcError::invalid_request()))
149                }
150            }
151        }));
152    }
153}
154
155impl<Handler, Connection> Future for PubSubConnection<Handler, Connection>
156where
157    Handler: PubSubRpcHandler,
158    Connection: Sink<String> + Stream<Item = Result<Option<Request>, RequestError>> + Unpin,
159    <Connection as Sink<String>>::Error: fmt::Debug,
160{
161    type Output = ();
162
163    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
164        let pin = self.get_mut();
165        loop {
166            // drive the websocket
167            while matches!(pin.connection.poll_ready_unpin(cx), Poll::Ready(Ok(()))) {
168                // only start sending if socket is ready
169                if let Some(msg) = pin.pending.pop_front() {
170                    if let Err(err) = pin.connection.start_send_unpin(msg) {
171                        error!(target: "rpc", ?err, "Failed to send message");
172                    }
173                } else {
174                    break;
175                }
176            }
177
178            // Ensure any pending messages are flushed
179            // this needs to be called manually for tungsenite websocket: <https://github.com/foundry-rs/foundry/issues/6345>
180            if let Poll::Ready(Err(err)) = pin.connection.poll_flush_unpin(cx) {
181                trace!(target: "rpc", ?err, "websocket err");
182                // close the connection
183                return Poll::Ready(());
184            }
185
186            loop {
187                match pin.connection.poll_next_unpin(cx) {
188                    Poll::Ready(Some(req)) => match req {
189                        Ok(Some(req)) => {
190                            pin.process_request(Ok(req));
191                        }
192                        Err(err) => match err {
193                            RequestError::Axum(err) => {
194                                trace!(target: "rpc", ?err, "client disconnected");
195                                return Poll::Ready(());
196                            }
197                            RequestError::Io(err) => {
198                                trace!(target: "rpc", ?err, "client disconnected");
199                                return Poll::Ready(());
200                            }
201                            RequestError::Serde(err) => {
202                                pin.process_request(Err(err));
203                            }
204                            RequestError::Disconnect => {
205                                trace!(target: "rpc", "client disconnected");
206                                return Poll::Ready(());
207                            }
208                        },
209                        _ => {}
210                    },
211                    Poll::Ready(None) => {
212                        trace!(target: "rpc", "socket connection finished");
213                        return Poll::Ready(());
214                    }
215                    Poll::Pending => break,
216                }
217            }
218
219            let mut progress = false;
220            for n in (0..pin.processing.len()).rev() {
221                let mut req = pin.processing.swap_remove(n);
222                #[allow(clippy::collapsible_match)]
223                match req.poll_unpin(cx) {
224                    Poll::Ready(Some(resp)) => {
225                        if let Ok(text) = serde_json::to_string(&resp) {
226                            pin.pending.push_back(text);
227                            progress = true;
228                        }
229                    }
230                    Poll::Ready(None) => {}
231                    Poll::Pending => pin.processing.push(req),
232                }
233            }
234
235            {
236                // process subscription events
237                let mut subscriptions = pin.context.subscriptions.lock();
238                'outer: for n in (0..subscriptions.len()).rev() {
239                    let (id, mut sub) = subscriptions.swap_remove(n);
240                    'inner: loop {
241                        #[allow(clippy::collapsible_match)]
242                        match sub.poll_next_unpin(cx) {
243                            Poll::Ready(Some(res)) => {
244                                if let Ok(text) = serde_json::to_string(&res) {
245                                    pin.pending.push_back(text);
246                                    progress = true;
247                                }
248                            }
249                            Poll::Ready(None) => continue 'outer,
250                            Poll::Pending => break 'inner,
251                        }
252                    }
253
254                    subscriptions.push((id, sub));
255                }
256            }
257
258            if !progress {
259                return Poll::Pending;
260            }
261        }
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use anvil_rpc::{
        request::{RequestParams, RpcCall, RpcNotification, Version},
        response::RpcResponse,
    };

    /// Handler that answers every request with `null` and never creates subscriptions.
    #[derive(Clone)]
    struct TestHandler;

    #[async_trait::async_trait]
    impl PubSubRpcHandler for TestHandler {
        type Request = serde_json::Value;
        type SubscriptionId = u64;
        type Subscription = futures::stream::Empty<serde_json::Value>;

        async fn on_request(
            &self,
            _request: Self::Request,
            _cx: PubSubContext<Self>,
        ) -> ResponseResult {
            ResponseResult::success(serde_json::Value::Null)
        }
    }

    /// An `eth_subscribe` call sent as a notification.
    fn subscribe_notification() -> RpcCall {
        let notification = RpcNotification {
            jsonrpc: Some(Version::V2),
            method: String::from("eth_subscribe"),
            params: RequestParams::None,
        };
        RpcCall::Notification(notification)
    }

    /// Evaluates `future` with a single poll, panicking if it is not ready yet.
    fn resolve_now<F: Future>(future: F) -> F::Output {
        future.now_or_never().expect("future unexpectedly pending")
    }

    /// Queues `request` on a transport-less connection and resolves the queued future.
    fn respond_to(request: Request) -> Option<Response> {
        let mut conn = PubSubConnection::new((), TestHandler);
        conn.process_request(Ok(request));
        let task = conn.processing.pop().expect("request was queued");
        resolve_now(task)
    }

    #[test]
    fn process_request_keeps_empty_batch_invalid() {
        let expected = Response::Single(RpcResponse::from(RpcError::invalid_request()));
        assert_eq!(respond_to(Request::Batch(vec![])), Some(expected));
    }

    #[test]
    fn process_request_skips_notification_only_batch_response() {
        assert_eq!(respond_to(Request::Batch(vec![subscribe_notification()])), None);
    }
}