// anvil_server/ws.rs
use crate::{error::RequestError, pubsub::PubSubConnection, PubSubRpcHandler};
use anvil_rpc::request::Request;
use axum::{
    extract::{
        ws::{Message, WebSocket},
        State, WebSocketUpgrade,
    },
    response::Response,
};
use futures::{ready, Sink, Stream};
use std::{
    pin::Pin,
    task::{Context, Poll},
};

/// Handles an incoming WebSocket upgrade request.
///
/// This is the entrypoint the axum server invokes for a websocket request. The router state is
/// a pair whose first element (the HTTP handler) is ignored here. The second element, the pubsub
/// handler, is moved into a [`PubSubConnection`] that drives the upgraded socket.
pub async fn handle_ws<Http, Ws: PubSubRpcHandler>(
    ws: WebSocketUpgrade,
    State((_, handler)): State<(Http, Ws)>,
) -> Response {
    ws.on_upgrade(move |upgraded| {
        // Adapt the raw socket into the request stream / string sink the connection expects.
        let conn = SocketConn(upgraded);
        PubSubConnection::new(conn, handler)
    })
}

/// Thin adapter around an axum [`WebSocket`].
///
/// It exposes the socket as a [`Stream`] of decoded RPC requests and a [`Sink`] of outgoing
/// `String` messages, so it can be handed to [`PubSubConnection`] in `handle_ws`.
///
/// The inner socket is marked `#[pin]` so `project()` yields a `Pin<&mut WebSocket>`. The
/// `Stream`/`Sink` impls can then forward their `poll_*` calls to it.
#[pin_project::pin_project]
struct SocketConn(#[pin] WebSocket);

impl Stream for SocketConn {
    type Item = Result<Option<Request>, RequestError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match ready!(self.project().0.poll_next(cx)) {
            Some(msg) => Poll::Ready(Some(on_message(msg))),
            _ => Poll::Ready(None),
        }
    }
}

/// Outgoing half: every `String` pushed into this sink is sent as a WebSocket text frame.
///
/// Readiness, flushing and closing are forwarded unchanged to the pinned inner socket, and
/// its `axum::Error` is surfaced as-is.
impl Sink<String> for SocketConn {
    type Error = axum::Error;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let socket = self.project().0;
        socket.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
        // Wrap the already-serialized payload in a text frame before queueing it.
        let frame = Message::Text(item);
        let socket = self.project().0;
        socket.start_send(frame)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let socket = self.project().0;
        socket.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let socket = self.project().0;
        socket.poll_close(cx)
    }
}

/// Decodes a single received WebSocket frame into an RPC [`Request`].
///
/// Returns:
/// - `Ok(Some(request))` for text or binary frames whose payload deserializes as JSON.
/// - `Ok(None)` for any other frame type, which carries no request.
///
/// # Errors
/// - The socket error, if receiving the frame failed.
/// - The JSON deserialization error (converted into [`RequestError`]) for malformed payloads.
/// - [`RequestError::Disconnect`] when the client sends a close frame.
fn on_message(msg: Result<Message, axum::Error>) -> Result<Option<Request>, RequestError> {
    let frame = msg?;
    let request = match frame {
        // A text frame holds the JSON-encoded request as a string.
        Message::Text(text) => serde_json::from_str(&text)?,
        // A binary frame holds the same JSON-encoded request, just as raw bytes.
        Message::Binary(data) => serde_json::from_slice(&data)?,
        Message::Close(_) => {
            trace!(target: "rpc::ws", "ws client disconnected");
            return Err(RequestError::Disconnect);
        }
        // Every other frame type carries no request payload.
        _ => return Ok(None),
    };
    Ok(Some(request))
}