foundry_common/provider/
runtime_transport.rs

1//! Runtime transport that connects on first request, which can take either of an HTTP,
2//! WebSocket, or IPC transport and supports retries based on CUPS logic.
3
4use crate::{DEFAULT_USER_AGENT, REQUEST_TIMEOUT};
5use alloy_json_rpc::{RequestPacket, ResponsePacket};
6use alloy_pubsub::{PubSubConnect, PubSubFrontend};
7use alloy_rpc_types::engine::{Claims, JwtSecret};
8use alloy_transport::{
9    Authorization, BoxTransport, TransportError, TransportErrorKind, TransportFut,
10};
11use alloy_transport_http::Http;
12use alloy_transport_ipc::IpcConnect;
13use alloy_transport_ws::WsConnect;
14use reqwest::header::{HeaderName, HeaderValue};
15use std::{fmt, path::PathBuf, str::FromStr, sync::Arc};
16use thiserror::Error;
17use tokio::sync::RwLock;
18use tower::Service;
19use url::Url;
20
/// The concrete transport a [RuntimeTransport] holds once it has connected.
///
/// Which variant is used is decided from the URL scheme by [`RuntimeTransport::connect`].
/// Only meant to be used internally by [RuntimeTransport].
#[derive(Clone, Debug)]
pub enum InnerTransport {
    /// HTTP transport, backed by a `reqwest` client.
    Http(Http<reqwest::Client>),
    /// WebSocket transport, reached through a pub-sub frontend.
    Ws(PubSubFrontend),
    /// IPC transport, reached through a pub-sub frontend.
    Ipc(PubSubFrontend),
}
32
/// Error type for the runtime transport.
///
/// These errors are produced while lazily establishing the connection; when that happens
/// inside a request they are wrapped into a [`TransportError`] as a custom error kind.
#[derive(Error, Debug)]
pub enum RuntimeTransportError {
    /// Internal transport error.
    ///
    /// The second field is the target that was being connected to: the URL for WebSocket
    /// connections, or the resolved file path for IPC connections.
    #[error("Internal transport error: {0} with {1}")]
    TransportError(TransportError, String),

    /// Failed to lock the transport.
    // NOTE(review): no code visible here constructs this variant — confirm whether it is
    // still needed by callers elsewhere.
    #[error("Failed to lock the transport")]
    LockError,

    /// Invalid URL scheme; carries the offending scheme.
    #[error("URL scheme is not supported: {0}")]
    BadScheme(String),

    /// Invalid HTTP header; carries the raw header entry that could not be parsed as
    /// `name: value`.
    #[error("Invalid HTTP header: {0}")]
    BadHeader(String),

    /// Invalid file path; carries the URL that could not be converted into an IPC path.
    #[error("Invalid IPC file path: {0}")]
    BadPath(String),

    /// Invalid construction of Http provider (the `reqwest` client failed to build).
    #[error(transparent)]
    HttpConstructionError(#[from] reqwest::Error),

    /// Invalid JWT; carries the stringified error from building the authorization.
    #[error("Invalid JWT: {0}")]
    InvalidJwt(String),
}
64
/// Runtime transport that only connects on first request.
///
/// A runtime transport is a custom [`alloy_transport::Transport`] that only connects when the
/// *first* request is made. When the first request is made, it will connect to the runtime using
/// either an HTTP, WebSocket, or IPC transport depending on the URL used.
///
/// Clones share the same lazily-initialized inner transport because it lives behind an `Arc`.
///
/// NOTE(review): earlier docs stated that retries for rate-limiting and timeout-related errors
/// are supported; no retry logic is visible in this type, so presumably a wrapping layer
/// provides it — confirm.
#[derive(Clone, Debug, Error)]
pub struct RuntimeTransport {
    /// The inner actual transport used; `None` until the first successful connect.
    inner: Arc<RwLock<Option<InnerTransport>>>,
    /// The URL to connect to. Its scheme selects the transport kind.
    url: Url,
    /// The headers to use for HTTP requests, each entry in `name: value` form.
    headers: Vec<String>,
    /// The JWT secret (hex-encoded) used to derive a bearer token for requests.
    jwt: Option<String>,
    /// The timeout applied to the HTTP client.
    timeout: std::time::Duration,
    /// Whether the HTTP client should accept invalid certificates.
    accept_invalid_certs: bool,
}
86
/// A builder for [RuntimeTransport].
///
/// Collects the connection settings; nothing is connected until the built transport
/// receives its first request.
#[derive(Debug)]
pub struct RuntimeTransportBuilder {
    /// The URL the transport will connect to.
    url: Url,
    /// Custom headers, each entry in `name: value` form.
    headers: Vec<String>,
    /// Optional hex-encoded JWT secret.
    jwt: Option<String>,
    /// Request timeout; defaults to `REQUEST_TIMEOUT`.
    timeout: std::time::Duration,
    /// Whether invalid certificates are accepted; defaults to `false`.
    accept_invalid_certs: bool,
}
96
97impl RuntimeTransportBuilder {
98    /// Create a new builder with the given URL.
99    pub fn new(url: Url) -> Self {
100        Self {
101            url,
102            headers: vec![],
103            jwt: None,
104            timeout: REQUEST_TIMEOUT,
105            accept_invalid_certs: false,
106        }
107    }
108
109    /// Set the URL for the transport.
110    pub fn with_headers(mut self, headers: Vec<String>) -> Self {
111        self.headers = headers;
112        self
113    }
114
115    /// Set the JWT for the transport.
116    pub fn with_jwt(mut self, jwt: Option<String>) -> Self {
117        self.jwt = jwt;
118        self
119    }
120
121    /// Set the timeout for the transport.
122    pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
123        self.timeout = timeout;
124        self
125    }
126
127    /// Set whether to accept invalid certificates.
128    pub fn accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self {
129        self.accept_invalid_certs = accept_invalid_certs;
130        self
131    }
132
133    /// Builds the [RuntimeTransport] and returns it in a disconnected state.
134    /// The runtime transport will then connect when the first request happens.
135    pub fn build(self) -> RuntimeTransport {
136        RuntimeTransport {
137            inner: Arc::new(RwLock::new(None)),
138            url: self.url,
139            headers: self.headers,
140            jwt: self.jwt,
141            timeout: self.timeout,
142            accept_invalid_certs: self.accept_invalid_certs,
143        }
144    }
145}
146
147impl fmt::Display for RuntimeTransport {
148    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149        write!(f, "RuntimeTransport {}", self.url)
150    }
151}
152
153impl RuntimeTransport {
154    /// Connects the underlying transport, depending on the URL scheme.
155    pub async fn connect(&self) -> Result<InnerTransport, RuntimeTransportError> {
156        match self.url.scheme() {
157            "http" | "https" => self.connect_http(),
158            "ws" | "wss" => self.connect_ws().await,
159            "file" => self.connect_ipc().await,
160            _ => Err(RuntimeTransportError::BadScheme(self.url.scheme().to_string())),
161        }
162    }
163
164    /// Creates a new reqwest client from this transport.
165    pub fn reqwest_client(&self) -> Result<reqwest::Client, RuntimeTransportError> {
166        let mut client_builder = reqwest::Client::builder()
167            .timeout(self.timeout)
168            .tls_built_in_root_certs(self.url.scheme() == "https")
169            .danger_accept_invalid_certs(self.accept_invalid_certs);
170        let mut headers = reqwest::header::HeaderMap::new();
171
172        // If there's a JWT, add it to the headers if we can decode it.
173        if let Some(jwt) = self.jwt.clone() {
174            let auth =
175                build_auth(jwt).map_err(|e| RuntimeTransportError::InvalidJwt(e.to_string()))?;
176
177            let mut auth_value: HeaderValue =
178                HeaderValue::from_str(&auth.to_string()).expect("Header should be valid string");
179            auth_value.set_sensitive(true);
180
181            headers.insert(reqwest::header::AUTHORIZATION, auth_value);
182        };
183
184        // Add any custom headers.
185        for header in &self.headers {
186            let make_err = || RuntimeTransportError::BadHeader(header.to_string());
187
188            let (key, val) = header.split_once(':').ok_or_else(make_err)?;
189
190            headers.insert(
191                HeaderName::from_str(key.trim()).map_err(|_| make_err())?,
192                HeaderValue::from_str(val.trim()).map_err(|_| make_err())?,
193            );
194        }
195
196        if !headers.contains_key(reqwest::header::USER_AGENT) {
197            headers.insert(
198                reqwest::header::USER_AGENT,
199                HeaderValue::from_str(DEFAULT_USER_AGENT)
200                    .expect("User-Agent should be valid string"),
201            );
202        }
203
204        client_builder = client_builder.default_headers(headers);
205
206        Ok(client_builder.build()?)
207    }
208
209    /// Connects to an HTTP [alloy_transport_http::Http] transport.
210    fn connect_http(&self) -> Result<InnerTransport, RuntimeTransportError> {
211        let client = self.reqwest_client()?;
212        Ok(InnerTransport::Http(Http::with_client(client, self.url.clone())))
213    }
214
215    /// Connects to a WS transport.
216    async fn connect_ws(&self) -> Result<InnerTransport, RuntimeTransportError> {
217        let auth = self.jwt.as_ref().and_then(|jwt| build_auth(jwt.clone()).ok());
218        let mut ws = WsConnect::new(self.url.to_string());
219        if let Some(auth) = auth {
220            ws = ws.with_auth(auth);
221        };
222        let service = ws
223            .into_service()
224            .await
225            .map_err(|e| RuntimeTransportError::TransportError(e, self.url.to_string()))?;
226        Ok(InnerTransport::Ws(service))
227    }
228
229    /// Connects to an IPC transport.
230    async fn connect_ipc(&self) -> Result<InnerTransport, RuntimeTransportError> {
231        let path = url_to_file_path(&self.url)
232            .map_err(|_| RuntimeTransportError::BadPath(self.url.to_string()))?;
233        let ipc_connector = IpcConnect::new(path.clone());
234        let ipc = ipc_connector.into_service().await.map_err(|e| {
235            RuntimeTransportError::TransportError(e, path.clone().display().to_string())
236        })?;
237        Ok(InnerTransport::Ipc(ipc))
238    }
239
240    /// Sends a request using the underlying transport.
241    /// If this is the first request, it will connect to the appropriate transport depending on the
242    /// URL scheme. When sending the request, retries will be automatically handled depending
243    /// on the parameters set on the [RuntimeTransport].
244    /// For sending the actual request, this action is delegated down to the
245    /// underlying transport through Tower's [tower::Service::call]. See tower's [tower::Service]
246    /// trait for more information.
247    pub fn request(&self, req: RequestPacket) -> TransportFut<'static> {
248        let this = self.clone();
249        Box::pin(async move {
250            let mut inner = this.inner.read().await;
251            if inner.is_none() {
252                drop(inner);
253                {
254                    let mut inner_mut = this.inner.write().await;
255                    if inner_mut.is_none() {
256                        *inner_mut =
257                            Some(this.connect().await.map_err(TransportErrorKind::custom)?);
258                    }
259                }
260                inner = this.inner.read().await;
261            }
262
263            // SAFETY: We just checked that the inner transport exists.
264            match inner.clone().expect("must've been initialized") {
265                InnerTransport::Http(mut http) => http.call(req),
266                InnerTransport::Ws(mut ws) => ws.call(req),
267                InnerTransport::Ipc(mut ipc) => ipc.call(req),
268            }
269            .await
270        })
271    }
272
273    /// Convert this transport into a boxed trait object.
274    pub fn boxed(self) -> BoxTransport
275    where
276        Self: Sized + Clone + Send + Sync + 'static,
277    {
278        BoxTransport::new(self)
279    }
280}
281
282impl tower::Service<RequestPacket> for RuntimeTransport {
283    type Response = ResponsePacket;
284    type Error = TransportError;
285    type Future = TransportFut<'static>;
286
287    #[inline]
288    fn poll_ready(
289        &mut self,
290        _cx: &mut std::task::Context<'_>,
291    ) -> std::task::Poll<Result<(), Self::Error>> {
292        std::task::Poll::Ready(Ok(()))
293    }
294
295    #[inline]
296    fn call(&mut self, req: RequestPacket) -> Self::Future {
297        self.request(req)
298    }
299}
300
301impl tower::Service<RequestPacket> for &RuntimeTransport {
302    type Response = ResponsePacket;
303    type Error = TransportError;
304    type Future = TransportFut<'static>;
305
306    #[inline]
307    fn poll_ready(
308        &mut self,
309        _cx: &mut std::task::Context<'_>,
310    ) -> std::task::Poll<Result<(), Self::Error>> {
311        std::task::Poll::Ready(Ok(()))
312    }
313
314    #[inline]
315    fn call(&mut self, req: RequestPacket) -> Self::Future {
316        self.request(req)
317    }
318}
319
320fn build_auth(jwt: String) -> eyre::Result<Authorization> {
321    // Decode jwt from hex, then generate claims (iat with current timestamp)
322    let secret = JwtSecret::from_hex(jwt)?;
323    let claims = Claims::default();
324    let token = secret.encode(&claims)?;
325
326    let auth = Authorization::Bearer(token);
327
328    Ok(auth)
329}
330
/// Converts a `file` URL into a path usable for IPC on Windows.
///
/// URLs of the form `file:///pipe/<name>` are mapped to the named pipe `\\.\pipe\<name>`;
/// every other URL goes through [`Url::to_file_path`].
#[cfg(windows)]
fn url_to_file_path(url: &Url) -> Result<PathBuf, ()> {
    match url.as_str().strip_prefix("file:///pipe/") {
        Some(pipe_name) => Ok(PathBuf::from(format!(r"\\.\pipe\{pipe_name}"))),
        None => url.to_file_path(),
    }
}
344
345#[cfg(not(windows))]
346fn url_to_file_path(url: &Url) -> Result<PathBuf, ()> {
347    url.to_file_path()
348}
349
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::HeaderMap;

    /// Checks that a `User-Agent` supplied through the custom headers is the one that is
    /// actually sent by the HTTP client built by the transport.
    #[tokio::test]
    async fn test_user_agent_header() {
        // Bind to port 0 so the OS picks a free port, then derive the URL from it.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = Url::parse(&format!("http://{}", listener.local_addr().unwrap())).unwrap();

        // The handler inspects the headers of the incoming request; the assertion therefore
        // runs on the server side.
        let http_handler = axum::routing::get(|actual_headers: HeaderMap| {
            let user_agent = HeaderName::from_str("User-Agent").unwrap();
            assert_eq!(actual_headers[user_agent], HeaderValue::from_str("test-agent").unwrap());

            async { "" }
        });

        // Serve in the background for the duration of the test.
        let server_task = tokio::spawn(async move {
            axum::serve(listener, http_handler.into_make_service()).await.unwrap()
        });

        let transport = RuntimeTransportBuilder::new(url.clone())
            .with_headers(vec!["User-Agent: test-agent".to_string()])
            .build();
        // Build the HTTP transport directly instead of going through a JSON-RPC request.
        let inner = transport.connect_http().unwrap();

        match inner {
            InnerTransport::Http(http) => {
                // Send a plain GET with the transport's client so its default headers apply.
                let _ = http.client().get(url).send().await.unwrap();

                // assert inside http_handler
            }
            // `connect_http` only ever returns the `Http` variant.
            _ => unreachable!(),
        }

        server_task.abort();
    }
}
387}