1use crate::{DEFAULT_USER_AGENT, REQUEST_TIMEOUT};
5use alloy_json_rpc::{RequestPacket, ResponsePacket};
6use alloy_pubsub::{PubSubConnect, PubSubFrontend};
7use alloy_rpc_types::engine::{Claims, JwtSecret};
8use alloy_transport::{
9 Authorization, BoxTransport, TransportError, TransportErrorKind, TransportFut,
10};
11use alloy_transport_http::Http;
12use alloy_transport_ipc::IpcConnect;
13use alloy_transport_ws::WsConnect;
14use reqwest::header::{HeaderName, HeaderValue};
15use std::{fmt, path::PathBuf, str::FromStr, sync::Arc};
16use thiserror::Error;
17use tokio::sync::RwLock;
18use tower::Service;
19use url::Url;
20
21#[derive(Clone, Debug)]
24pub enum InnerTransport {
25 Http(Http<reqwest::Client>),
27 Ws(PubSubFrontend),
29 Ipc(PubSubFrontend),
31}
32
33#[derive(Error, Debug)]
35pub enum RuntimeTransportError {
36 #[error("Internal transport error: {0} with {1}")]
38 TransportError(TransportError, String),
39
40 #[error("Failed to lock the transport")]
42 LockError,
43
44 #[error("URL scheme is not supported: {0}")]
46 BadScheme(String),
47
48 #[error("Invalid HTTP header: {0}")]
50 BadHeader(String),
51
52 #[error("Invalid IPC file path: {0}")]
54 BadPath(String),
55
56 #[error(transparent)]
58 HttpConstructionError(#[from] reqwest::Error),
59
60 #[error("Invalid JWT: {0}")]
62 InvalidJwt(String),
63}
64
65#[derive(Clone, Debug, Error)]
72pub struct RuntimeTransport {
73 inner: Arc<RwLock<Option<InnerTransport>>>,
75 url: Url,
77 headers: Vec<String>,
79 jwt: Option<String>,
81 timeout: std::time::Duration,
83 accept_invalid_certs: bool,
85}
86
87#[derive(Debug)]
89pub struct RuntimeTransportBuilder {
90 url: Url,
91 headers: Vec<String>,
92 jwt: Option<String>,
93 timeout: std::time::Duration,
94 accept_invalid_certs: bool,
95}
96
97impl RuntimeTransportBuilder {
98 pub fn new(url: Url) -> Self {
100 Self {
101 url,
102 headers: vec![],
103 jwt: None,
104 timeout: REQUEST_TIMEOUT,
105 accept_invalid_certs: false,
106 }
107 }
108
109 pub fn with_headers(mut self, headers: Vec<String>) -> Self {
111 self.headers = headers;
112 self
113 }
114
115 pub fn with_jwt(mut self, jwt: Option<String>) -> Self {
117 self.jwt = jwt;
118 self
119 }
120
121 pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
123 self.timeout = timeout;
124 self
125 }
126
127 pub fn accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self {
129 self.accept_invalid_certs = accept_invalid_certs;
130 self
131 }
132
133 pub fn build(self) -> RuntimeTransport {
136 RuntimeTransport {
137 inner: Arc::new(RwLock::new(None)),
138 url: self.url,
139 headers: self.headers,
140 jwt: self.jwt,
141 timeout: self.timeout,
142 accept_invalid_certs: self.accept_invalid_certs,
143 }
144 }
145}
146
147impl fmt::Display for RuntimeTransport {
148 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149 write!(f, "RuntimeTransport {}", self.url)
150 }
151}
152
153impl RuntimeTransport {
154 pub async fn connect(&self) -> Result<InnerTransport, RuntimeTransportError> {
156 match self.url.scheme() {
157 "http" | "https" => self.connect_http(),
158 "ws" | "wss" => self.connect_ws().await,
159 "file" => self.connect_ipc().await,
160 _ => Err(RuntimeTransportError::BadScheme(self.url.scheme().to_string())),
161 }
162 }
163
164 pub fn reqwest_client(&self) -> Result<reqwest::Client, RuntimeTransportError> {
166 let mut client_builder = reqwest::Client::builder()
167 .timeout(self.timeout)
168 .tls_built_in_root_certs(self.url.scheme() == "https")
169 .danger_accept_invalid_certs(self.accept_invalid_certs);
170 let mut headers = reqwest::header::HeaderMap::new();
171
172 if let Some(jwt) = self.jwt.clone() {
174 let auth =
175 build_auth(jwt).map_err(|e| RuntimeTransportError::InvalidJwt(e.to_string()))?;
176
177 let mut auth_value: HeaderValue =
178 HeaderValue::from_str(&auth.to_string()).expect("Header should be valid string");
179 auth_value.set_sensitive(true);
180
181 headers.insert(reqwest::header::AUTHORIZATION, auth_value);
182 };
183
184 for header in &self.headers {
186 let make_err = || RuntimeTransportError::BadHeader(header.to_string());
187
188 let (key, val) = header.split_once(':').ok_or_else(make_err)?;
189
190 headers.insert(
191 HeaderName::from_str(key.trim()).map_err(|_| make_err())?,
192 HeaderValue::from_str(val.trim()).map_err(|_| make_err())?,
193 );
194 }
195
196 if !headers.contains_key(reqwest::header::USER_AGENT) {
197 headers.insert(
198 reqwest::header::USER_AGENT,
199 HeaderValue::from_str(DEFAULT_USER_AGENT)
200 .expect("User-Agent should be valid string"),
201 );
202 }
203
204 client_builder = client_builder.default_headers(headers);
205
206 Ok(client_builder.build()?)
207 }
208
209 fn connect_http(&self) -> Result<InnerTransport, RuntimeTransportError> {
211 let client = self.reqwest_client()?;
212 Ok(InnerTransport::Http(Http::with_client(client, self.url.clone())))
213 }
214
215 async fn connect_ws(&self) -> Result<InnerTransport, RuntimeTransportError> {
217 let auth = self.jwt.as_ref().and_then(|jwt| build_auth(jwt.clone()).ok());
218 let mut ws = WsConnect::new(self.url.to_string());
219 if let Some(auth) = auth {
220 ws = ws.with_auth(auth);
221 };
222 let service = ws
223 .into_service()
224 .await
225 .map_err(|e| RuntimeTransportError::TransportError(e, self.url.to_string()))?;
226 Ok(InnerTransport::Ws(service))
227 }
228
229 async fn connect_ipc(&self) -> Result<InnerTransport, RuntimeTransportError> {
231 let path = url_to_file_path(&self.url)
232 .map_err(|_| RuntimeTransportError::BadPath(self.url.to_string()))?;
233 let ipc_connector = IpcConnect::new(path.clone());
234 let ipc = ipc_connector.into_service().await.map_err(|e| {
235 RuntimeTransportError::TransportError(e, path.clone().display().to_string())
236 })?;
237 Ok(InnerTransport::Ipc(ipc))
238 }
239
240 pub fn request(&self, req: RequestPacket) -> TransportFut<'static> {
248 let this = self.clone();
249 Box::pin(async move {
250 let mut inner = this.inner.read().await;
251 if inner.is_none() {
252 drop(inner);
253 {
254 let mut inner_mut = this.inner.write().await;
255 if inner_mut.is_none() {
256 *inner_mut =
257 Some(this.connect().await.map_err(TransportErrorKind::custom)?);
258 }
259 }
260 inner = this.inner.read().await;
261 }
262
263 match inner.clone().expect("must've been initialized") {
265 InnerTransport::Http(mut http) => http.call(req),
266 InnerTransport::Ws(mut ws) => ws.call(req),
267 InnerTransport::Ipc(mut ipc) => ipc.call(req),
268 }
269 .await
270 })
271 }
272
273 pub fn boxed(self) -> BoxTransport
275 where
276 Self: Sized + Clone + Send + Sync + 'static,
277 {
278 BoxTransport::new(self)
279 }
280}
281
282impl tower::Service<RequestPacket> for RuntimeTransport {
283 type Response = ResponsePacket;
284 type Error = TransportError;
285 type Future = TransportFut<'static>;
286
287 #[inline]
288 fn poll_ready(
289 &mut self,
290 _cx: &mut std::task::Context<'_>,
291 ) -> std::task::Poll<Result<(), Self::Error>> {
292 std::task::Poll::Ready(Ok(()))
293 }
294
295 #[inline]
296 fn call(&mut self, req: RequestPacket) -> Self::Future {
297 self.request(req)
298 }
299}
300
301impl tower::Service<RequestPacket> for &RuntimeTransport {
302 type Response = ResponsePacket;
303 type Error = TransportError;
304 type Future = TransportFut<'static>;
305
306 #[inline]
307 fn poll_ready(
308 &mut self,
309 _cx: &mut std::task::Context<'_>,
310 ) -> std::task::Poll<Result<(), Self::Error>> {
311 std::task::Poll::Ready(Ok(()))
312 }
313
314 #[inline]
315 fn call(&mut self, req: RequestPacket) -> Self::Future {
316 self.request(req)
317 }
318}
319
320fn build_auth(jwt: String) -> eyre::Result<Authorization> {
321 let secret = JwtSecret::from_hex(jwt)?;
323 let claims = Claims::default();
324 let token = secret.encode(&claims)?;
325
326 let auth = Authorization::Bearer(token);
327
328 Ok(auth)
329}
330
/// Converts a `file://` URL into a local path (Windows).
///
/// URLs starting with `file:///pipe/` are mapped to the named pipe `\\.\pipe\<name>`. Everything
/// else goes through [`Url::to_file_path`].
#[cfg(windows)]
fn url_to_file_path(url: &Url) -> Result<PathBuf, ()> {
    const PREFIX: &str = "file:///pipe/";

    match url.as_str().strip_prefix(PREFIX) {
        Some(pipe_name) => Ok(PathBuf::from(format!(r"\\.\pipe\{pipe_name}"))),
        None => url.to_file_path(),
    }
}
344
345#[cfg(not(windows))]
346fn url_to_file_path(url: &Url) -> Result<PathBuf, ()> {
347 url.to_file_path()
348}
349
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::HeaderMap;

    /// Spawns a local HTTP server that answers every GET with the request's `User-Agent` value
    /// (empty if absent). Returns the server URL and the server task handle.
    ///
    /// Echoing the header lets the test assert in its own body. An assertion inside the handler
    /// would instead fail as an opaque connection error on the client side.
    async fn spawn_user_agent_echo_server() -> (Url, tokio::task::JoinHandle<()>) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = Url::parse(&format!("http://{}", listener.local_addr().unwrap())).unwrap();

        let handler = axum::routing::get(|headers: HeaderMap| async move {
            headers
                .get(reqwest::header::USER_AGENT)
                .and_then(|value| value.to_str().ok())
                .unwrap_or_default()
                .to_string()
        });

        let server = tokio::spawn(async move {
            axum::serve(listener, handler.into_make_service()).await.unwrap()
        });

        (url, server)
    }

    /// Sends a GET to `url` through the transport's HTTP client and returns the response body.
    async fn fetch_via_transport(transport: &RuntimeTransport, url: Url) -> String {
        match transport.connect_http().unwrap() {
            InnerTransport::Http(http) => {
                http.client().get(url).send().await.unwrap().text().await.unwrap()
            }
            _ => unreachable!("connect_http always yields an HTTP transport"),
        }
    }

    #[tokio::test]
    async fn test_user_agent_header() {
        let (url, server) = spawn_user_agent_echo_server().await;

        let transport = RuntimeTransportBuilder::new(url.clone())
            .with_headers(vec!["User-Agent: test-agent".to_string()])
            .build();

        assert_eq!(fetch_via_transport(&transport, url).await, "test-agent");

        server.abort();
    }

    #[tokio::test]
    async fn test_default_user_agent_header() {
        let (url, server) = spawn_user_agent_echo_server().await;

        let transport = RuntimeTransportBuilder::new(url.clone()).build();

        assert_eq!(fetch_via_transport(&transport, url).await, DEFAULT_USER_AGENT);

        server.abort();
    }

    #[test]
    fn test_malformed_header_is_rejected() {
        let transport =
            RuntimeTransportBuilder::new(Url::parse("http://127.0.0.1:8545").unwrap())
                .with_headers(vec!["missing-colon".to_string()])
                .build();

        assert!(matches!(
            transport.reqwest_client(),
            Err(RuntimeTransportError::BadHeader(header)) if header == "missing-colon"
        ));
    }

    #[test]
    fn test_invalid_jwt_is_rejected_for_http() {
        let transport =
            RuntimeTransportBuilder::new(Url::parse("http://127.0.0.1:8545").unwrap())
                .with_jwt(Some("not-a-hex-secret".to_string()))
                .build();

        assert!(matches!(transport.reqwest_client(), Err(RuntimeTransportError::InvalidJwt(_))));
    }

    #[tokio::test]
    async fn test_unsupported_scheme_is_rejected() {
        let transport =
            RuntimeTransportBuilder::new(Url::parse("ftp://127.0.0.1").unwrap()).build();

        assert!(matches!(
            transport.connect().await,
            Err(RuntimeTransportError::BadScheme(scheme)) if scheme == "ftp"
        ));
    }
}