anvil_server/
ws.rs
1use crate::{error::RequestError, pubsub::PubSubConnection, PubSubRpcHandler};
2use anvil_rpc::request::Request;
3use axum::{
4 extract::{
5 ws::{Message, WebSocket},
6 State, WebSocketUpgrade,
7 },
8 response::Response,
9};
10use futures::{ready, Sink, Stream};
11use std::{
12 pin::Pin,
13 task::{Context, Poll},
14};
15
16pub async fn handle_ws<Http, Ws: PubSubRpcHandler>(
20 ws: WebSocketUpgrade,
21 State((_, handler)): State<(Http, Ws)>,
22) -> Response {
23 ws.on_upgrade(|socket| PubSubConnection::new(SocketConn(socket), handler))
24}
25
26#[pin_project::pin_project]
27struct SocketConn(#[pin] WebSocket);
28
29impl Stream for SocketConn {
30 type Item = Result<Option<Request>, RequestError>;
31
32 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
33 match ready!(self.project().0.poll_next(cx)) {
34 Some(msg) => Poll::Ready(Some(on_message(msg))),
35 _ => Poll::Ready(None),
36 }
37 }
38}
39
40impl Sink<String> for SocketConn {
41 type Error = axum::Error;
42
43 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
44 self.project().0.poll_ready(cx)
45 }
46
47 fn start_send(self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
48 self.project().0.start_send(Message::Text(item))
49 }
50
51 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
52 self.project().0.poll_flush(cx)
53 }
54
55 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
56 self.project().0.poll_close(cx)
57 }
58}
59
60fn on_message(msg: Result<Message, axum::Error>) -> Result<Option<Request>, RequestError> {
61 match msg? {
62 Message::Text(text) => Ok(Some(serde_json::from_str(&text)?)),
63 Message::Binary(data) => {
64 Ok(Some(serde_json::from_slice(&data)?))
67 }
68 Message::Close(_) => {
69 trace!(target: "rpc::ws", "ws client disconnected");
70 Err(RequestError::Disconnect)
71 }
72 _ => Ok(None),
73 }
74}