anvil_server/
ipc.rs

1//! IPC handling
2
3use crate::{PubSubRpcHandler, error::RequestError, pubsub::PubSubConnection};
4use anvil_rpc::request::Request;
5use bytes::{BufMut, BytesMut};
6use futures::{Sink, Stream, StreamExt, ready};
7use interprocess::local_socket::{self as ls, tokio::prelude::*};
8use std::{
9    io,
10    pin::Pin,
11    task::{Context, Poll},
12};
13
/// An IPC endpoint for anvil
///
/// Holds the RPC handler and the socket path. [`IpcEndpoint::incoming`] creates the listener and
/// yields one connection future per accepted client.
pub struct IpcEndpoint<Handler> {
    /// The handler used for every accepted IPC connection
    handler: Handler,
    /// The path to the socket
    path: String,
}
23
24impl<Handler: PubSubRpcHandler> IpcEndpoint<Handler> {
25    /// Creates a new endpoint with the given handler
26    pub fn new(handler: Handler, path: String) -> Self {
27        Self { handler, path }
28    }
29
30    /// Returns a stream of incoming connection handlers.
31    ///
32    /// This establishes the IPC endpoint, converts the incoming connections into handled
33    /// connections.
34    #[instrument(target = "ipc", skip_all)]
35    pub fn incoming(self) -> io::Result<impl Stream<Item = impl Future<Output = ()>>> {
36        let Self { handler, path } = self;
37
38        trace!(%path, "starting IPC server");
39
40        if cfg!(unix) {
41            // ensure the file does not exist
42            if std::fs::remove_file(&path).is_ok() {
43                warn!(%path, "removed existing file");
44            }
45        }
46
47        let name = to_name(path.as_ref())?;
48        let listener = ls::ListenerOptions::new().name(name).create_tokio()?;
49        let connections = futures::stream::unfold(listener, |listener| async move {
50            let conn = listener.accept().await;
51            Some((conn, listener))
52        });
53
54        trace!("established connection listener");
55
56        Ok(connections.filter_map(move |stream| {
57            let handler = handler.clone();
58            async move {
59                match stream {
60                    Ok(stream) => {
61                        trace!("successful incoming IPC connection");
62                        let framed = tokio_util::codec::Decoder::framed(JsonRpcCodec, stream);
63                        Some(PubSubConnection::new(IpcConn(framed), handler))
64                    }
65                    Err(err) => {
66                        trace!(%err, "unsuccessful incoming IPC connection");
67                        None
68                    }
69                }
70            }
71        }))
72    }
73}
74
/// Newtype around the underlying transport `T`.
///
/// Its `Stream` impl parses each incoming text frame into a [`Request`], and its `Sink` impl
/// forwards outgoing strings to `T`.
#[pin_project::pin_project]
struct IpcConn<T>(#[pin] T);
77
78impl<T> Stream for IpcConn<T>
79where
80    T: Stream<Item = io::Result<String>>,
81{
82    type Item = Result<Option<Request>, RequestError>;
83
84    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
85        fn on_request(msg: io::Result<String>) -> Result<Option<Request>, RequestError> {
86            let text = msg?;
87            Ok(Some(serde_json::from_str(&text)?))
88        }
89        match ready!(self.project().0.poll_next(cx)) {
90            Some(req) => Poll::Ready(Some(on_request(req))),
91            _ => Poll::Ready(None),
92        }
93    }
94}
95
96impl<T> Sink<String> for IpcConn<T>
97where
98    T: Sink<String, Error = io::Error>,
99{
100    type Error = io::Error;
101
102    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
103        // NOTE: we always flush here this prevents any backpressure buffer in the underlying
104        // `Framed` impl that would cause stalled requests
105        self.project().0.poll_flush(cx)
106    }
107
108    fn start_send(self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
109        self.project().0.start_send(item)
110    }
111
112    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
113        self.project().0.poll_flush(cx)
114    }
115
116    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
117        self.project().0.poll_close(cx)
118    }
119}
120
/// Codec for JSON text messages on a byte stream.
///
/// Decoding splits the input on balanced top-level brackets; encoding writes the message followed
/// by a newline.
struct JsonRpcCodec;
122
123// Adapted from <https://github.com/paritytech/jsonrpc/blob/38af3c9439aa75481805edf6c05c6622a5ab1e70/server-utils/src/stream_codec.rs#L47-L105>
124impl tokio_util::codec::Decoder for JsonRpcCodec {
125    type Item = String;
126    type Error = io::Error;
127
128    fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<Self::Item>> {
129        const fn is_whitespace(byte: u8) -> bool {
130            matches!(byte, 0x0D | 0x0A | 0x20 | 0x09)
131        }
132
133        let mut depth = 0;
134        let mut in_str = false;
135        let mut is_escaped = false;
136        let mut start_idx = 0;
137        let mut whitespaces = 0;
138
139        for idx in 0..buf.as_ref().len() {
140            let byte = buf.as_ref()[idx];
141
142            if (byte == b'{' || byte == b'[') && !in_str {
143                if depth == 0 {
144                    start_idx = idx;
145                }
146                depth += 1;
147            } else if (byte == b'}' || byte == b']') && !in_str {
148                depth -= 1;
149            } else if byte == b'"' && !is_escaped {
150                in_str = !in_str;
151            } else if is_whitespace(byte) {
152                whitespaces += 1;
153            }
154            is_escaped = byte == b'\\' && !is_escaped && in_str;
155
156            if depth == 0 && idx != start_idx && idx - start_idx + 1 > whitespaces {
157                let bts = buf.split_to(idx + 1);
158                return match String::from_utf8(bts.as_ref().to_vec()) {
159                    Ok(val) => Ok(Some(val)),
160                    Err(_) => Ok(None),
161                };
162            }
163        }
164        Ok(None)
165    }
166}
167
168impl tokio_util::codec::Encoder<String> for JsonRpcCodec {
169    type Error = io::Error;
170
171    fn encode(&mut self, msg: String, buf: &mut BytesMut) -> io::Result<()> {
172        buf.extend_from_slice(msg.as_bytes());
173        // Add newline character
174        buf.put_u8(b'\n');
175        Ok(())
176    }
177}
178
179fn to_name(path: &std::ffi::OsStr) -> io::Result<ls::Name<'_>> {
180    if cfg!(windows) && !path.as_encoded_bytes().starts_with(br"\\.\pipe\") {
181        ls::ToNsName::to_ns_name::<ls::GenericNamespaced>(path)
182    } else {
183        ls::ToFsName::to_fs_name::<ls::GenericFilePath>(path)
184    }
185}