1use crate::{error::RequestError, handler::handle_request, RpcHandler};
2use anvil_rpc::{
3 error::RpcError,
4 request::Request,
5 response::{Response, ResponseResult},
6};
7
8use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
9use parking_lot::Mutex;
10use serde::de::DeserializeOwned;
11use std::{
12 collections::VecDeque,
13 fmt,
14 future::Future,
15 hash::Hash,
16 pin::Pin,
17 sync::Arc,
18 task::{Context, Poll},
19};
20
/// Handler for RPC requests that can additionally manage subscriptions.
///
/// Unlike a plain request handler, `on_request` also receives a [`PubSubContext`] through
/// which subscriptions can be added or removed.
#[async_trait::async_trait]
pub trait PubSubRpcHandler: Clone + Send + Sync + Unpin + 'static {
    /// The request type this handler accepts; it must be deserializable.
    type Request: DeserializeOwned + Send + Sync + fmt::Debug;
    /// Identifier under which a subscription is stored in the [`PubSubContext`].
    type SubscriptionId: Hash + PartialEq + Eq + Send + Sync + fmt::Debug;
    /// A stream of JSON values produced by an active subscription.
    type Subscription: Stream<Item = serde_json::Value> + Send + Sync + Unpin;

    /// Handles `request` and returns its result.
    ///
    /// `cx` gives access to the subscriptions tracked by the [`PubSubContext`].
    async fn on_request(&self, request: Self::Request, cx: PubSubContext<Self>) -> ResponseResult;
}
34
/// Shared, mutex-guarded list of `(id, subscription)` pairs.
type Subscriptions<SubscriptionId, Subscription> = Arc<Mutex<Vec<(SubscriptionId, Subscription)>>>;
36
/// Holds the subscriptions registered for a [`PubSubRpcHandler`].
///
/// Clones of a context share the same underlying subscription list.
pub struct PubSubContext<Handler: PubSubRpcHandler> {
    /// All currently registered subscriptions together with their ids.
    subscriptions: Subscriptions<Handler::SubscriptionId, Handler::Subscription>,
}
42
43impl<Handler: PubSubRpcHandler> PubSubContext<Handler> {
44 pub fn add_subscription(
48 &self,
49 id: Handler::SubscriptionId,
50 subscription: Handler::Subscription,
51 ) -> Option<Handler::Subscription> {
52 let mut subscriptions = self.subscriptions.lock();
53 let mut removed = None;
54 if let Some(idx) = subscriptions.iter().position(|(i, _)| id == *i) {
55 trace!(target: "rpc", ?id, "removed subscription");
56 removed = Some(subscriptions.swap_remove(idx).1);
57 }
58 trace!(target: "rpc", ?id, "added subscription");
59 subscriptions.push((id, subscription));
60 removed
61 }
62
63 pub fn remove_subscription(
65 &self,
66 id: &Handler::SubscriptionId,
67 ) -> Option<Handler::Subscription> {
68 let mut subscriptions = self.subscriptions.lock();
69 if let Some(idx) = subscriptions.iter().position(|(i, _)| id == i) {
70 trace!(target: "rpc", ?id, "removed subscription");
71 return Some(subscriptions.swap_remove(idx).1)
72 }
73 None
74 }
75}
76
77impl<Handler: PubSubRpcHandler> Clone for PubSubContext<Handler> {
78 fn clone(&self) -> Self {
79 Self { subscriptions: Arc::clone(&self.subscriptions) }
80 }
81}
82
83impl<Handler: PubSubRpcHandler> Default for PubSubContext<Handler> {
84 fn default() -> Self {
85 Self { subscriptions: Arc::new(Mutex::new(Vec::new())) }
86 }
87}
88
/// Pairs a [`PubSubRpcHandler`] with its [`PubSubContext`] so the pair can be used where a
/// plain [`RpcHandler`] is expected.
struct ContextAwareHandler<Handler: PubSubRpcHandler> {
    /// The wrapped pubsub handler.
    handler: Handler,
    /// The context handed to the handler on every request.
    context: PubSubContext<Handler>,
}
94
95impl<Handler: PubSubRpcHandler> Clone for ContextAwareHandler<Handler> {
96 fn clone(&self) -> Self {
97 Self { handler: self.handler.clone(), context: self.context.clone() }
98 }
99}
100
101#[async_trait::async_trait]
102impl<Handler: PubSubRpcHandler> RpcHandler for ContextAwareHandler<Handler> {
103 type Request = Handler::Request;
104
105 async fn on_request(&self, request: Self::Request) -> ResponseResult {
106 self.handler.on_request(request, self.context.clone()).await
107 }
108}
109
/// A single pubsub connection, driven by polling it as a [`Future`].
pub struct PubSubConnection<Handler: PubSubRpcHandler, Connection> {
    /// Handler that answers incoming requests.
    handler: Handler,
    /// Subscriptions registered through this connection's handler.
    context: PubSubContext<Handler>,
    /// The underlying transport: requests are read from it and serialized messages written to it.
    connection: Connection,
    /// In-flight requests, each resolving to a [`Response`].
    processing: Vec<Pin<Box<dyn Future<Output = Response> + Send>>>,
    /// Serialized messages waiting to be sent over the connection.
    pending: VecDeque<String>,
}
125
126impl<Handler: PubSubRpcHandler, Connection> PubSubConnection<Handler, Connection> {
127 pub fn new(connection: Connection, handler: Handler) -> Self {
128 Self {
129 connection,
130 handler,
131 context: Default::default(),
132 pending: Default::default(),
133 processing: Default::default(),
134 }
135 }
136
137 fn compat_helper(&self) -> ContextAwareHandler<Handler> {
139 ContextAwareHandler { handler: self.handler.clone(), context: self.context.clone() }
140 }
141
142 fn process_request(&mut self, req: serde_json::Result<Request>) {
143 let handler = self.compat_helper();
144 self.processing.push(Box::pin(async move {
145 match req {
146 Ok(req) => handle_request(req, handler)
147 .await
148 .unwrap_or_else(|| Response::error(RpcError::invalid_request())),
149 Err(err) => {
150 error!(target: "rpc", ?err, "invalid request");
151 Response::error(RpcError::invalid_request())
152 }
153 }
154 }));
155 }
156}
157
158impl<Handler, Connection> Future for PubSubConnection<Handler, Connection>
159where
160 Handler: PubSubRpcHandler,
161 Connection: Sink<String> + Stream<Item = Result<Option<Request>, RequestError>> + Unpin,
162 <Connection as Sink<String>>::Error: fmt::Debug,
163{
164 type Output = ();
165
166 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
167 let pin = self.get_mut();
168 loop {
169 while matches!(pin.connection.poll_ready_unpin(cx), Poll::Ready(Ok(()))) {
171 if let Some(msg) = pin.pending.pop_front() {
173 if let Err(err) = pin.connection.start_send_unpin(msg) {
174 error!(target: "rpc", ?err, "Failed to send message");
175 }
176 } else {
177 break
178 }
179 }
180
181 if let Poll::Ready(Err(err)) = pin.connection.poll_flush_unpin(cx) {
184 trace!(target: "rpc", ?err, "websocket err");
185 return Poll::Ready(())
187 }
188
189 loop {
190 match pin.connection.poll_next_unpin(cx) {
191 Poll::Ready(Some(req)) => match req {
192 Ok(Some(req)) => {
193 pin.process_request(Ok(req));
194 }
195 Err(err) => match err {
196 RequestError::Axum(err) => {
197 trace!(target: "rpc", ?err, "client disconnected");
198 return Poll::Ready(())
199 }
200 RequestError::Io(err) => {
201 trace!(target: "rpc", ?err, "client disconnected");
202 return Poll::Ready(())
203 }
204 RequestError::Serde(err) => {
205 pin.process_request(Err(err));
206 }
207 RequestError::Disconnect => {
208 trace!(target: "rpc", "client disconnected");
209 return Poll::Ready(())
210 }
211 },
212 _ => {}
213 },
214 Poll::Ready(None) => {
215 trace!(target: "rpc", "socket connection finished");
216 return Poll::Ready(())
217 }
218 Poll::Pending => break,
219 }
220 }
221
222 let mut progress = false;
223 for n in (0..pin.processing.len()).rev() {
224 let mut req = pin.processing.swap_remove(n);
225 match req.poll_unpin(cx) {
226 Poll::Ready(resp) => {
227 if let Ok(text) = serde_json::to_string(&resp) {
228 pin.pending.push_back(text);
229 progress = true;
230 }
231 }
232 Poll::Pending => pin.processing.push(req),
233 }
234 }
235
236 {
237 let mut subscriptions = pin.context.subscriptions.lock();
239 'outer: for n in (0..subscriptions.len()).rev() {
240 let (id, mut sub) = subscriptions.swap_remove(n);
241 'inner: loop {
242 match sub.poll_next_unpin(cx) {
243 Poll::Ready(Some(res)) => {
244 if let Ok(text) = serde_json::to_string(&res) {
245 pin.pending.push_back(text);
246 progress = true;
247 }
248 }
249 Poll::Ready(None) => continue 'outer,
250 Poll::Pending => break 'inner,
251 }
252 }
253
254 subscriptions.push((id, sub));
255 }
256 }
257
258 if !progress {
259 return Poll::Pending
260 }
261 }
262 }
263}