1use crate::{RpcHandler, error::RequestError, handler::handle_request};
2use anvil_rpc::{
3 error::RpcError,
4 request::Request,
5 response::{Response, ResponseResult},
6};
7
8use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
9use parking_lot::Mutex;
10use serde::de::DeserializeOwned;
11use std::{
12 collections::VecDeque,
13 fmt,
14 hash::Hash,
15 pin::Pin,
16 sync::Arc,
17 task::{Context, Poll},
18};
19
20#[async_trait::async_trait]
22pub trait PubSubRpcHandler: Clone + Send + Sync + Unpin + 'static {
23 type Request: DeserializeOwned + Send + Sync + fmt::Debug;
25 type SubscriptionId: Hash + PartialEq + Eq + Send + Sync + fmt::Debug;
27 type Subscription: Stream<Item = serde_json::Value> + Send + Sync + Unpin;
29
30 async fn on_request(&self, request: Self::Request, cx: PubSubContext<Self>) -> ResponseResult;
32}
33
34type Subscriptions<SubscriptionId, Subscription> = Arc<Mutex<Vec<(SubscriptionId, Subscription)>>>;
35
36pub struct PubSubContext<Handler: PubSubRpcHandler> {
38 subscriptions: Subscriptions<Handler::SubscriptionId, Handler::Subscription>,
40}
41
42impl<Handler: PubSubRpcHandler> PubSubContext<Handler> {
43 pub fn add_subscription(
47 &self,
48 id: Handler::SubscriptionId,
49 subscription: Handler::Subscription,
50 ) -> Option<Handler::Subscription> {
51 let mut subscriptions = self.subscriptions.lock();
52 let mut removed = None;
53 if let Some(idx) = subscriptions.iter().position(|(i, _)| id == *i) {
54 trace!(target: "rpc", ?id, "removed subscription");
55 removed = Some(subscriptions.swap_remove(idx).1);
56 }
57 trace!(target: "rpc", ?id, "added subscription");
58 subscriptions.push((id, subscription));
59 removed
60 }
61
62 pub fn remove_subscription(
64 &self,
65 id: &Handler::SubscriptionId,
66 ) -> Option<Handler::Subscription> {
67 let mut subscriptions = self.subscriptions.lock();
68 if let Some(idx) = subscriptions.iter().position(|(i, _)| id == i) {
69 trace!(target: "rpc", ?id, "removed subscription");
70 return Some(subscriptions.swap_remove(idx).1);
71 }
72 None
73 }
74}
75
76impl<Handler: PubSubRpcHandler> Clone for PubSubContext<Handler> {
77 fn clone(&self) -> Self {
78 Self { subscriptions: Arc::clone(&self.subscriptions) }
79 }
80}
81
82impl<Handler: PubSubRpcHandler> Default for PubSubContext<Handler> {
83 fn default() -> Self {
84 Self { subscriptions: Arc::new(Mutex::new(Vec::new())) }
85 }
86}
87
88struct ContextAwareHandler<Handler: PubSubRpcHandler> {
90 handler: Handler,
91 context: PubSubContext<Handler>,
92}
93
94impl<Handler: PubSubRpcHandler> Clone for ContextAwareHandler<Handler> {
95 fn clone(&self) -> Self {
96 Self { handler: self.handler.clone(), context: self.context.clone() }
97 }
98}
99
100#[async_trait::async_trait]
101impl<Handler: PubSubRpcHandler> RpcHandler for ContextAwareHandler<Handler> {
102 type Request = Handler::Request;
103
104 async fn on_request(&self, request: Self::Request) -> ResponseResult {
105 self.handler.on_request(request, self.context.clone()).await
106 }
107}
108
109pub struct PubSubConnection<Handler: PubSubRpcHandler, Connection> {
113 handler: Handler,
115 context: PubSubContext<Handler>,
117 connection: Connection,
119 processing: Vec<Pin<Box<dyn Future<Output = Option<Response>> + Send>>>,
121 pending: VecDeque<String>,
123}
124
125impl<Handler: PubSubRpcHandler, Connection> PubSubConnection<Handler, Connection> {
126 pub fn new(connection: Connection, handler: Handler) -> Self {
127 Self {
128 connection,
129 handler,
130 context: Default::default(),
131 pending: Default::default(),
132 processing: Default::default(),
133 }
134 }
135
136 fn compat_helper(&self) -> ContextAwareHandler<Handler> {
138 ContextAwareHandler { handler: self.handler.clone(), context: self.context.clone() }
139 }
140
141 fn process_request(&mut self, req: serde_json::Result<Request>) {
142 let handler = self.compat_helper();
143 self.processing.push(Box::pin(async move {
144 match req {
145 Ok(req) => handle_request(req, handler).await,
146 Err(err) => {
147 error!(target: "rpc", ?err, "invalid request");
148 Some(Response::error(RpcError::invalid_request()))
149 }
150 }
151 }));
152 }
153}
154
155impl<Handler, Connection> Future for PubSubConnection<Handler, Connection>
156where
157 Handler: PubSubRpcHandler,
158 Connection: Sink<String> + Stream<Item = Result<Option<Request>, RequestError>> + Unpin,
159 <Connection as Sink<String>>::Error: fmt::Debug,
160{
161 type Output = ();
162
163 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
164 let pin = self.get_mut();
165 loop {
166 while matches!(pin.connection.poll_ready_unpin(cx), Poll::Ready(Ok(()))) {
168 if let Some(msg) = pin.pending.pop_front() {
170 if let Err(err) = pin.connection.start_send_unpin(msg) {
171 error!(target: "rpc", ?err, "Failed to send message");
172 }
173 } else {
174 break;
175 }
176 }
177
178 if let Poll::Ready(Err(err)) = pin.connection.poll_flush_unpin(cx) {
181 trace!(target: "rpc", ?err, "websocket err");
182 return Poll::Ready(());
184 }
185
186 loop {
187 match pin.connection.poll_next_unpin(cx) {
188 Poll::Ready(Some(req)) => match req {
189 Ok(Some(req)) => {
190 pin.process_request(Ok(req));
191 }
192 Err(err) => match err {
193 RequestError::Axum(err) => {
194 trace!(target: "rpc", ?err, "client disconnected");
195 return Poll::Ready(());
196 }
197 RequestError::Io(err) => {
198 trace!(target: "rpc", ?err, "client disconnected");
199 return Poll::Ready(());
200 }
201 RequestError::Serde(err) => {
202 pin.process_request(Err(err));
203 }
204 RequestError::Disconnect => {
205 trace!(target: "rpc", "client disconnected");
206 return Poll::Ready(());
207 }
208 },
209 _ => {}
210 },
211 Poll::Ready(None) => {
212 trace!(target: "rpc", "socket connection finished");
213 return Poll::Ready(());
214 }
215 Poll::Pending => break,
216 }
217 }
218
219 let mut progress = false;
220 for n in (0..pin.processing.len()).rev() {
221 let mut req = pin.processing.swap_remove(n);
222 #[allow(clippy::collapsible_match)]
223 match req.poll_unpin(cx) {
224 Poll::Ready(Some(resp)) => {
225 if let Ok(text) = serde_json::to_string(&resp) {
226 pin.pending.push_back(text);
227 progress = true;
228 }
229 }
230 Poll::Ready(None) => {}
231 Poll::Pending => pin.processing.push(req),
232 }
233 }
234
235 {
236 let mut subscriptions = pin.context.subscriptions.lock();
238 'outer: for n in (0..subscriptions.len()).rev() {
239 let (id, mut sub) = subscriptions.swap_remove(n);
240 'inner: loop {
241 #[allow(clippy::collapsible_match)]
242 match sub.poll_next_unpin(cx) {
243 Poll::Ready(Some(res)) => {
244 if let Ok(text) = serde_json::to_string(&res) {
245 pin.pending.push_back(text);
246 progress = true;
247 }
248 }
249 Poll::Ready(None) => continue 'outer,
250 Poll::Pending => break 'inner,
251 }
252 }
253
254 subscriptions.push((id, sub));
255 }
256 }
257
258 if !progress {
259 return Poll::Pending;
260 }
261 }
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use anvil_rpc::{
        request::{RequestParams, RpcCall, RpcNotification, Version},
        response::RpcResponse,
    };
    use std::{pin::pin, task::Waker};

    /// Handler that answers every request with `null`.
    #[derive(Clone)]
    struct TestHandler;

    #[async_trait::async_trait]
    impl PubSubRpcHandler for TestHandler {
        type Request = serde_json::Value;
        type SubscriptionId = u64;
        type Subscription = futures::stream::Empty<serde_json::Value>;

        async fn on_request(
            &self,
            _request: Self::Request,
            _cx: PubSubContext<Self>,
        ) -> ResponseResult {
            ResponseResult::success(serde_json::Value::Null)
        }
    }

    /// Builds a JSON-RPC 2.0 `eth_subscribe` notification without params.
    fn notification() -> RpcCall {
        let call = RpcNotification {
            jsonrpc: Some(Version::V2),
            method: String::from("eth_subscribe"),
            params: RequestParams::None,
        };
        RpcCall::Notification(call)
    }

    /// Polls `future` once and returns its output; panics if it is not ready.
    fn run_ready<F: Future>(future: F) -> F::Output {
        let mut cx = Context::from_waker(Waker::noop());
        let mut future = pin!(future);
        let Poll::Ready(output) = future.as_mut().poll(&mut cx) else {
            panic!("future unexpectedly pending");
        };
        output
    }

    #[test]
    fn process_request_keeps_empty_batch_invalid() {
        let mut connection = PubSubConnection::new((), TestHandler);
        connection.process_request(Ok(Request::Batch(Vec::new())));

        let task = connection.processing.pop().expect("request was queued");
        let expected = Response::Single(RpcResponse::from(RpcError::invalid_request()));
        assert_eq!(run_ready(task), Some(expected));
    }

    #[test]
    fn process_request_skips_notification_only_batch_response() {
        let mut connection = PubSubConnection::new((), TestHandler);
        connection.process_request(Ok(Request::Batch(vec![notification()])));

        let task = connection.processing.pop().expect("request was queued");
        assert!(run_ready(task).is_none());
    }
}