anvil_server/
ws.rs

1use crate::{error::RequestError, pubsub::PubSubConnection, PubSubRpcHandler};
2use anvil_rpc::request::Request;
3use axum::{
4    extract::{
5        ws::{Message, WebSocket},
6        State, WebSocketUpgrade,
7    },
8    response::Response,
9};
10use futures::{ready, Sink, Stream};
11use std::{
12    pin::Pin,
13    task::{Context, Poll},
14};
15
16/// Handles incoming Websocket upgrade
17///
18/// This is the entrypoint invoked by the axum server for a websocket request
19pub async fn handle_ws<Http, Ws: PubSubRpcHandler>(
20    ws: WebSocketUpgrade,
21    State((_, handler)): State<(Http, Ws)>,
22) -> Response {
23    ws.on_upgrade(|socket| PubSubConnection::new(SocketConn(socket), handler))
24}
25
26#[pin_project::pin_project]
27struct SocketConn(#[pin] WebSocket);
28
29impl Stream for SocketConn {
30    type Item = Result<Option<Request>, RequestError>;
31
32    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
33        match ready!(self.project().0.poll_next(cx)) {
34            Some(msg) => Poll::Ready(Some(on_message(msg))),
35            _ => Poll::Ready(None),
36        }
37    }
38}
39
40impl Sink<String> for SocketConn {
41    type Error = axum::Error;
42
43    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
44        self.project().0.poll_ready(cx)
45    }
46
47    fn start_send(self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
48        self.project().0.start_send(Message::Text(item))
49    }
50
51    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
52        self.project().0.poll_flush(cx)
53    }
54
55    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
56        self.project().0.poll_close(cx)
57    }
58}
59
60fn on_message(msg: Result<Message, axum::Error>) -> Result<Option<Request>, RequestError> {
61    match msg? {
62        Message::Text(text) => Ok(Some(serde_json::from_str(&text)?)),
63        Message::Binary(data) => {
64            // the binary payload type is the request as-is but as bytes, if this is a valid
65            // `Request` then we can deserialize the Json from the data Vec
66            Ok(Some(serde_json::from_slice(&data)?))
67        }
68        Message::Close(_) => {
69            trace!(target: "rpc::ws", "ws client disconnected");
70            Err(RequestError::Disconnect)
71        }
72        _ => Ok(None),
73    }
74}