anvil_server/
pubsub.rs

1use crate::{RpcHandler, error::RequestError, handler::handle_request};
2use anvil_rpc::{
3    error::RpcError,
4    request::Request,
5    response::{Response, ResponseResult},
6};
7
8use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
9use parking_lot::Mutex;
10use serde::de::DeserializeOwned;
11use std::{
12    collections::VecDeque,
13    fmt,
14    hash::Hash,
15    pin::Pin,
16    sync::Arc,
17    task::{Context, Poll},
18};
19
20/// The general purpose trait for handling RPC requests and subscriptions
21#[async_trait::async_trait]
22pub trait PubSubRpcHandler: Clone + Send + Sync + Unpin + 'static {
23    /// The request type to expect
24    type Request: DeserializeOwned + Send + Sync + fmt::Debug;
25    /// The identifier to use for subscriptions
26    type SubscriptionId: Hash + PartialEq + Eq + Send + Sync + fmt::Debug;
27    /// The subscription type this handle may create
28    type Subscription: Stream<Item = serde_json::Value> + Send + Sync + Unpin;
29
30    /// Invoked when the request was received
31    async fn on_request(&self, request: Self::Request, cx: PubSubContext<Self>) -> ResponseResult;
32}
33
34type Subscriptions<SubscriptionId, Subscription> = Arc<Mutex<Vec<(SubscriptionId, Subscription)>>>;
35
36/// Contains additional context and tracks subscriptions
37pub struct PubSubContext<Handler: PubSubRpcHandler> {
38    /// all active subscriptions `id -> Stream`
39    subscriptions: Subscriptions<Handler::SubscriptionId, Handler::Subscription>,
40}
41
42impl<Handler: PubSubRpcHandler> PubSubContext<Handler> {
43    /// Adds new active subscription
44    ///
45    /// Returns the previous subscription, if any
46    pub fn add_subscription(
47        &self,
48        id: Handler::SubscriptionId,
49        subscription: Handler::Subscription,
50    ) -> Option<Handler::Subscription> {
51        let mut subscriptions = self.subscriptions.lock();
52        let mut removed = None;
53        if let Some(idx) = subscriptions.iter().position(|(i, _)| id == *i) {
54            trace!(target: "rpc", ?id,  "removed subscription");
55            removed = Some(subscriptions.swap_remove(idx).1);
56        }
57        trace!(target: "rpc", ?id,  "added subscription");
58        subscriptions.push((id, subscription));
59        removed
60    }
61
62    /// Removes an existing subscription
63    pub fn remove_subscription(
64        &self,
65        id: &Handler::SubscriptionId,
66    ) -> Option<Handler::Subscription> {
67        let mut subscriptions = self.subscriptions.lock();
68        if let Some(idx) = subscriptions.iter().position(|(i, _)| id == i) {
69            trace!(target: "rpc", ?id,  "removed subscription");
70            return Some(subscriptions.swap_remove(idx).1);
71        }
72        None
73    }
74}
75
76impl<Handler: PubSubRpcHandler> Clone for PubSubContext<Handler> {
77    fn clone(&self) -> Self {
78        Self { subscriptions: Arc::clone(&self.subscriptions) }
79    }
80}
81
82impl<Handler: PubSubRpcHandler> Default for PubSubContext<Handler> {
83    fn default() -> Self {
84        Self { subscriptions: Arc::new(Mutex::new(Vec::new())) }
85    }
86}
87
88/// A compatibility helper type to use common `RpcHandler` functions
89struct ContextAwareHandler<Handler: PubSubRpcHandler> {
90    handler: Handler,
91    context: PubSubContext<Handler>,
92}
93
94impl<Handler: PubSubRpcHandler> Clone for ContextAwareHandler<Handler> {
95    fn clone(&self) -> Self {
96        Self { handler: self.handler.clone(), context: self.context.clone() }
97    }
98}
99
100#[async_trait::async_trait]
101impl<Handler: PubSubRpcHandler> RpcHandler for ContextAwareHandler<Handler> {
102    type Request = Handler::Request;
103
104    async fn on_request(&self, request: Self::Request) -> ResponseResult {
105        self.handler.on_request(request, self.context.clone()).await
106    }
107}
108
109/// Represents a connection to a client via websocket
110///
111/// Contains the state for the entire connection
112pub struct PubSubConnection<Handler: PubSubRpcHandler, Connection> {
113    /// the handler for the websocket connection
114    handler: Handler,
115    /// contains all the subscription related context
116    context: PubSubContext<Handler>,
117    /// The established connection
118    connection: Connection,
119    /// currently in progress requests
120    processing: Vec<Pin<Box<dyn Future<Output = Response> + Send>>>,
121    /// pending messages to send
122    pending: VecDeque<String>,
123}
124
125impl<Handler: PubSubRpcHandler, Connection> PubSubConnection<Handler, Connection> {
126    pub fn new(connection: Connection, handler: Handler) -> Self {
127        Self {
128            connection,
129            handler,
130            context: Default::default(),
131            pending: Default::default(),
132            processing: Default::default(),
133        }
134    }
135
136    /// Returns a compatibility `RpcHandler`
137    fn compat_helper(&self) -> ContextAwareHandler<Handler> {
138        ContextAwareHandler { handler: self.handler.clone(), context: self.context.clone() }
139    }
140
141    fn process_request(&mut self, req: serde_json::Result<Request>) {
142        let handler = self.compat_helper();
143        self.processing.push(Box::pin(async move {
144            match req {
145                Ok(req) => handle_request(req, handler)
146                    .await
147                    .unwrap_or_else(|| Response::error(RpcError::invalid_request())),
148                Err(err) => {
149                    error!(target: "rpc", ?err, "invalid request");
150                    Response::error(RpcError::invalid_request())
151                }
152            }
153        }));
154    }
155}
156
157impl<Handler, Connection> Future for PubSubConnection<Handler, Connection>
158where
159    Handler: PubSubRpcHandler,
160    Connection: Sink<String> + Stream<Item = Result<Option<Request>, RequestError>> + Unpin,
161    <Connection as Sink<String>>::Error: fmt::Debug,
162{
163    type Output = ();
164
165    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
166        let pin = self.get_mut();
167        loop {
168            // drive the websocket
169            while matches!(pin.connection.poll_ready_unpin(cx), Poll::Ready(Ok(()))) {
170                // only start sending if socket is ready
171                if let Some(msg) = pin.pending.pop_front() {
172                    if let Err(err) = pin.connection.start_send_unpin(msg) {
173                        error!(target: "rpc", ?err, "Failed to send message");
174                    }
175                } else {
176                    break;
177                }
178            }
179
180            // Ensure any pending messages are flushed
181            // this needs to be called manually for tungsenite websocket: <https://github.com/foundry-rs/foundry/issues/6345>
182            if let Poll::Ready(Err(err)) = pin.connection.poll_flush_unpin(cx) {
183                trace!(target: "rpc", ?err, "websocket err");
184                // close the connection
185                return Poll::Ready(());
186            }
187
188            loop {
189                match pin.connection.poll_next_unpin(cx) {
190                    Poll::Ready(Some(req)) => match req {
191                        Ok(Some(req)) => {
192                            pin.process_request(Ok(req));
193                        }
194                        Err(err) => match err {
195                            RequestError::Axum(err) => {
196                                trace!(target: "rpc", ?err, "client disconnected");
197                                return Poll::Ready(());
198                            }
199                            RequestError::Io(err) => {
200                                trace!(target: "rpc", ?err, "client disconnected");
201                                return Poll::Ready(());
202                            }
203                            RequestError::Serde(err) => {
204                                pin.process_request(Err(err));
205                            }
206                            RequestError::Disconnect => {
207                                trace!(target: "rpc", "client disconnected");
208                                return Poll::Ready(());
209                            }
210                        },
211                        _ => {}
212                    },
213                    Poll::Ready(None) => {
214                        trace!(target: "rpc", "socket connection finished");
215                        return Poll::Ready(());
216                    }
217                    Poll::Pending => break,
218                }
219            }
220
221            let mut progress = false;
222            for n in (0..pin.processing.len()).rev() {
223                let mut req = pin.processing.swap_remove(n);
224                match req.poll_unpin(cx) {
225                    Poll::Ready(resp) => {
226                        if let Ok(text) = serde_json::to_string(&resp) {
227                            pin.pending.push_back(text);
228                            progress = true;
229                        }
230                    }
231                    Poll::Pending => pin.processing.push(req),
232                }
233            }
234
235            {
236                // process subscription events
237                let mut subscriptions = pin.context.subscriptions.lock();
238                'outer: for n in (0..subscriptions.len()).rev() {
239                    let (id, mut sub) = subscriptions.swap_remove(n);
240                    'inner: loop {
241                        match sub.poll_next_unpin(cx) {
242                            Poll::Ready(Some(res)) => {
243                                if let Ok(text) = serde_json::to_string(&res) {
244                                    pin.pending.push_back(text);
245                                    progress = true;
246                                }
247                            }
248                            Poll::Ready(None) => continue 'outer,
249                            Poll::Pending => break 'inner,
250                        }
251                    }
252
253                    subscriptions.push((id, sub));
254                }
255            }
256
257            if !progress {
258                return Poll::Pending;
259            }
260        }
261    }
262}