anvil_server/
ipc.rs

1//! IPC handling
2
3use crate::{error::RequestError, pubsub::PubSubConnection, PubSubRpcHandler};
4use anvil_rpc::request::Request;
5use bytes::{BufMut, BytesMut};
6use futures::{ready, Sink, Stream, StreamExt};
7use interprocess::local_socket::{self as ls, tokio::prelude::*};
8use std::{
9    future::Future,
10    io,
11    pin::Pin,
12    task::{Context, Poll},
13};
14
/// An IPC endpoint for anvil
///
/// Holds the RPC handler and the socket path. [`IpcEndpoint::incoming`] consumes the endpoint,
/// binds the socket and yields one handler future per accepted connection.
pub struct IpcEndpoint<Handler> {
    /// the handler that is cloned into every accepted IPC connection
    handler: Handler,
    /// The path to the socket
    path: String,
}
24
25impl<Handler: PubSubRpcHandler> IpcEndpoint<Handler> {
26    /// Creates a new endpoint with the given handler
27    pub fn new(handler: Handler, path: String) -> Self {
28        Self { handler, path }
29    }
30
31    /// Returns a stream of incoming connection handlers.
32    ///
33    /// This establishes the IPC endpoint, converts the incoming connections into handled
34    /// connections.
35    #[instrument(target = "ipc", skip_all)]
36    pub fn incoming(self) -> io::Result<impl Stream<Item = impl Future<Output = ()>>> {
37        let Self { handler, path } = self;
38
39        trace!(%path, "starting IPC server");
40
41        if cfg!(unix) {
42            // ensure the file does not exist
43            if std::fs::remove_file(&path).is_ok() {
44                warn!(%path, "removed existing file");
45            }
46        }
47
48        let name = to_name(path.as_ref())?;
49        let listener = ls::ListenerOptions::new().name(name).create_tokio()?;
50        let connections = futures::stream::unfold(listener, |listener| async move {
51            let conn = listener.accept().await;
52            Some((conn, listener))
53        });
54
55        trace!("established connection listener");
56
57        Ok(connections.filter_map(move |stream| {
58            let handler = handler.clone();
59            async move {
60                match stream {
61                    Ok(stream) => {
62                        trace!("successful incoming IPC connection");
63                        let framed = tokio_util::codec::Decoder::framed(JsonRpcCodec, stream);
64                        Some(PubSubConnection::new(IpcConn(framed), handler))
65                    }
66                    Err(err) => {
67                        trace!(%err, "unsuccessful incoming IPC connection");
68                        None
69                    }
70                }
71            }
72        }))
73    }
74}
75
/// Wrapper around a framed IPC transport.
///
/// As a `Stream` it parses every received text frame into a [`Request`]; as a `Sink` it forwards
/// outgoing `String` messages to the wrapped transport.
#[pin_project::pin_project]
struct IpcConn<T>(#[pin] T);
78
79impl<T> Stream for IpcConn<T>
80where
81    T: Stream<Item = io::Result<String>>,
82{
83    type Item = Result<Option<Request>, RequestError>;
84
85    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
86        fn on_request(msg: io::Result<String>) -> Result<Option<Request>, RequestError> {
87            let text = msg?;
88            Ok(Some(serde_json::from_str(&text)?))
89        }
90        match ready!(self.project().0.poll_next(cx)) {
91            Some(req) => Poll::Ready(Some(on_request(req))),
92            _ => Poll::Ready(None),
93        }
94    }
95}
96
97impl<T> Sink<String> for IpcConn<T>
98where
99    T: Sink<String, Error = io::Error>,
100{
101    type Error = io::Error;
102
103    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
104        // NOTE: we always flush here this prevents any backpressure buffer in the underlying
105        // `Framed` impl that would cause stalled requests
106        self.project().0.poll_flush(cx)
107    }
108
109    fn start_send(self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
110        self.project().0.start_send(item)
111    }
112
113    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
114        self.project().0.poll_flush(cx)
115    }
116
117    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
118        self.project().0.poll_close(cx)
119    }
120}
121
/// Codec that frames the IPC byte stream as JSON text.
///
/// Decoding yields one `String` per frame; encoding writes the message followed by a newline.
struct JsonRpcCodec;
123
124// Adapted from <https://github.com/paritytech/jsonrpc/blob/38af3c9439aa75481805edf6c05c6622a5ab1e70/server-utils/src/stream_codec.rs#L47-L105>
125impl tokio_util::codec::Decoder for JsonRpcCodec {
126    type Item = String;
127    type Error = io::Error;
128
129    fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<Self::Item>> {
130        const fn is_whitespace(byte: u8) -> bool {
131            matches!(byte, 0x0D | 0x0A | 0x20 | 0x09)
132        }
133
134        let mut depth = 0;
135        let mut in_str = false;
136        let mut is_escaped = false;
137        let mut start_idx = 0;
138        let mut whitespaces = 0;
139
140        for idx in 0..buf.as_ref().len() {
141            let byte = buf.as_ref()[idx];
142
143            if (byte == b'{' || byte == b'[') && !in_str {
144                if depth == 0 {
145                    start_idx = idx;
146                }
147                depth += 1;
148            } else if (byte == b'}' || byte == b']') && !in_str {
149                depth -= 1;
150            } else if byte == b'"' && !is_escaped {
151                in_str = !in_str;
152            } else if is_whitespace(byte) {
153                whitespaces += 1;
154            }
155            is_escaped = byte == b'\\' && !is_escaped && in_str;
156
157            if depth == 0 && idx != start_idx && idx - start_idx + 1 > whitespaces {
158                let bts = buf.split_to(idx + 1);
159                return match String::from_utf8(bts.as_ref().to_vec()) {
160                    Ok(val) => Ok(Some(val)),
161                    Err(_) => Ok(None),
162                }
163            }
164        }
165        Ok(None)
166    }
167}
168
169impl tokio_util::codec::Encoder<String> for JsonRpcCodec {
170    type Error = io::Error;
171
172    fn encode(&mut self, msg: String, buf: &mut BytesMut) -> io::Result<()> {
173        buf.extend_from_slice(msg.as_bytes());
174        // Add newline character
175        buf.put_u8(b'\n');
176        Ok(())
177    }
178}
179
180fn to_name(path: &std::ffi::OsStr) -> io::Result<ls::Name<'_>> {
181    if cfg!(windows) && !path.as_encoded_bytes().starts_with(br"\\.\pipe\") {
182        ls::ToNsName::to_ns_name::<ls::GenericNamespaced>(path)
183    } else {
184        ls::ToFsName::to_fs_name::<ls::GenericFilePath>(path)
185    }
186}