1use crate::{RpcHandler, error::RequestError, handler::handle_request};
2use anvil_rpc::{
3 error::RpcError,
4 request::Request,
5 response::{Response, ResponseResult},
6};
7
8use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
9use parking_lot::Mutex;
10use serde::de::DeserializeOwned;
11use std::{
12 collections::VecDeque,
13 fmt,
14 hash::Hash,
15 pin::Pin,
16 sync::Arc,
17 task::{Context, Poll},
18};
19
20#[async_trait::async_trait]
22pub trait PubSubRpcHandler: Clone + Send + Sync + Unpin + 'static {
23 type Request: DeserializeOwned + Send + Sync + fmt::Debug;
25 type SubscriptionId: Hash + PartialEq + Eq + Send + Sync + fmt::Debug;
27 type Subscription: Stream<Item = serde_json::Value> + Send + Sync + Unpin;
29
30 async fn on_request(&self, request: Self::Request, cx: PubSubContext<Self>) -> ResponseResult;
32}
33
/// Shared, lock-protected list of `(id, subscription)` pairs.
type Subscriptions<Id, Sub> = Arc<Mutex<Vec<(Id, Sub)>>>;
35
36pub struct PubSubContext<Handler: PubSubRpcHandler> {
38 subscriptions: Subscriptions<Handler::SubscriptionId, Handler::Subscription>,
40}
41
42impl<Handler: PubSubRpcHandler> PubSubContext<Handler> {
43 pub fn add_subscription(
47 &self,
48 id: Handler::SubscriptionId,
49 subscription: Handler::Subscription,
50 ) -> Option<Handler::Subscription> {
51 let mut subscriptions = self.subscriptions.lock();
52 let mut removed = None;
53 if let Some(idx) = subscriptions.iter().position(|(i, _)| id == *i) {
54 trace!(target: "rpc", ?id, "removed subscription");
55 removed = Some(subscriptions.swap_remove(idx).1);
56 }
57 trace!(target: "rpc", ?id, "added subscription");
58 subscriptions.push((id, subscription));
59 removed
60 }
61
62 pub fn remove_subscription(
64 &self,
65 id: &Handler::SubscriptionId,
66 ) -> Option<Handler::Subscription> {
67 let mut subscriptions = self.subscriptions.lock();
68 if let Some(idx) = subscriptions.iter().position(|(i, _)| id == i) {
69 trace!(target: "rpc", ?id, "removed subscription");
70 return Some(subscriptions.swap_remove(idx).1);
71 }
72 None
73 }
74}
75
76impl<Handler: PubSubRpcHandler> Clone for PubSubContext<Handler> {
77 fn clone(&self) -> Self {
78 Self { subscriptions: Arc::clone(&self.subscriptions) }
79 }
80}
81
82impl<Handler: PubSubRpcHandler> Default for PubSubContext<Handler> {
83 fn default() -> Self {
84 Self { subscriptions: Arc::new(Mutex::new(Vec::new())) }
85 }
86}
87
88struct ContextAwareHandler<Handler: PubSubRpcHandler> {
90 handler: Handler,
91 context: PubSubContext<Handler>,
92}
93
94impl<Handler: PubSubRpcHandler> Clone for ContextAwareHandler<Handler> {
95 fn clone(&self) -> Self {
96 Self { handler: self.handler.clone(), context: self.context.clone() }
97 }
98}
99
100#[async_trait::async_trait]
101impl<Handler: PubSubRpcHandler> RpcHandler for ContextAwareHandler<Handler> {
102 type Request = Handler::Request;
103
104 async fn on_request(&self, request: Self::Request) -> ResponseResult {
105 self.handler.on_request(request, self.context.clone()).await
106 }
107}
108
109pub struct PubSubConnection<Handler: PubSubRpcHandler, Connection> {
113 handler: Handler,
115 context: PubSubContext<Handler>,
117 connection: Connection,
119 processing: Vec<Pin<Box<dyn Future<Output = Response> + Send>>>,
121 pending: VecDeque<String>,
123}
124
125impl<Handler: PubSubRpcHandler, Connection> PubSubConnection<Handler, Connection> {
126 pub fn new(connection: Connection, handler: Handler) -> Self {
127 Self {
128 connection,
129 handler,
130 context: Default::default(),
131 pending: Default::default(),
132 processing: Default::default(),
133 }
134 }
135
136 fn compat_helper(&self) -> ContextAwareHandler<Handler> {
138 ContextAwareHandler { handler: self.handler.clone(), context: self.context.clone() }
139 }
140
141 fn process_request(&mut self, req: serde_json::Result<Request>) {
142 let handler = self.compat_helper();
143 self.processing.push(Box::pin(async move {
144 match req {
145 Ok(req) => handle_request(req, handler)
146 .await
147 .unwrap_or_else(|| Response::error(RpcError::invalid_request())),
148 Err(err) => {
149 error!(target: "rpc", ?err, "invalid request");
150 Response::error(RpcError::invalid_request())
151 }
152 }
153 }));
154 }
155}
156
157impl<Handler, Connection> Future for PubSubConnection<Handler, Connection>
158where
159 Handler: PubSubRpcHandler,
160 Connection: Sink<String> + Stream<Item = Result<Option<Request>, RequestError>> + Unpin,
161 <Connection as Sink<String>>::Error: fmt::Debug,
162{
163 type Output = ();
164
165 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
166 let pin = self.get_mut();
167 loop {
168 while matches!(pin.connection.poll_ready_unpin(cx), Poll::Ready(Ok(()))) {
170 if let Some(msg) = pin.pending.pop_front() {
172 if let Err(err) = pin.connection.start_send_unpin(msg) {
173 error!(target: "rpc", ?err, "Failed to send message");
174 }
175 } else {
176 break;
177 }
178 }
179
180 if let Poll::Ready(Err(err)) = pin.connection.poll_flush_unpin(cx) {
183 trace!(target: "rpc", ?err, "websocket err");
184 return Poll::Ready(());
186 }
187
188 loop {
189 match pin.connection.poll_next_unpin(cx) {
190 Poll::Ready(Some(req)) => match req {
191 Ok(Some(req)) => {
192 pin.process_request(Ok(req));
193 }
194 Err(err) => match err {
195 RequestError::Axum(err) => {
196 trace!(target: "rpc", ?err, "client disconnected");
197 return Poll::Ready(());
198 }
199 RequestError::Io(err) => {
200 trace!(target: "rpc", ?err, "client disconnected");
201 return Poll::Ready(());
202 }
203 RequestError::Serde(err) => {
204 pin.process_request(Err(err));
205 }
206 RequestError::Disconnect => {
207 trace!(target: "rpc", "client disconnected");
208 return Poll::Ready(());
209 }
210 },
211 _ => {}
212 },
213 Poll::Ready(None) => {
214 trace!(target: "rpc", "socket connection finished");
215 return Poll::Ready(());
216 }
217 Poll::Pending => break,
218 }
219 }
220
221 let mut progress = false;
222 for n in (0..pin.processing.len()).rev() {
223 let mut req = pin.processing.swap_remove(n);
224 match req.poll_unpin(cx) {
225 Poll::Ready(resp) => {
226 if let Ok(text) = serde_json::to_string(&resp) {
227 pin.pending.push_back(text);
228 progress = true;
229 }
230 }
231 Poll::Pending => pin.processing.push(req),
232 }
233 }
234
235 {
236 let mut subscriptions = pin.context.subscriptions.lock();
238 'outer: for n in (0..subscriptions.len()).rev() {
239 let (id, mut sub) = subscriptions.swap_remove(n);
240 'inner: loop {
241 match sub.poll_next_unpin(cx) {
242 Poll::Ready(Some(res)) => {
243 if let Ok(text) = serde_json::to_string(&res) {
244 pin.pending.push_back(text);
245 progress = true;
246 }
247 }
248 Poll::Ready(None) => continue 'outer,
249 Poll::Pending => break 'inner,
250 }
251 }
252
253 subscriptions.push((id, sub));
254 }
255 }
256
257 if !progress {
258 return Poll::Pending;
259 }
260 }
261 }
262}