1use crate::{DEFAULT_USER_AGENT, REQUEST_TIMEOUT};
5use alloy_json_rpc::{RequestPacket, ResponsePacket};
6use alloy_pubsub::{PubSubConnect, PubSubFrontend};
7use alloy_rpc_types::engine::{Claims, JwtSecret};
8use alloy_transport::{
9 Authorization, BoxTransport, TransportError, TransportErrorKind, TransportFut,
10};
11use alloy_transport_http::Http;
12use alloy_transport_ipc::IpcConnect;
13use alloy_transport_ws::WsConnect;
14use reqwest::header::{HeaderName, HeaderValue};
15use std::{fmt, path::PathBuf, str::FromStr, sync::Arc};
16use thiserror::Error;
17use tokio::sync::RwLock;
18use tower::Service;
19use url::Url;
20
21#[derive(Clone, Debug)]
24pub enum InnerTransport {
25 Http(Http<reqwest::Client>),
27 Ws(PubSubFrontend),
29 Ipc(PubSubFrontend),
31}
32
33#[derive(Error, Debug)]
35pub enum RuntimeTransportError {
36 #[error("Internal transport error: {0} with {1}")]
38 TransportError(TransportError, String),
39
40 #[error("Failed to lock the transport")]
42 LockError,
43
44 #[error("URL scheme is not supported: {0}")]
46 BadScheme(String),
47
48 #[error("Invalid HTTP header: {0}")]
50 BadHeader(String),
51
52 #[error("Invalid IPC file path: {0}")]
54 BadPath(String),
55
56 #[error(transparent)]
58 HttpConstructionError(#[from] reqwest::Error),
59
60 #[error("Invalid JWT: {0}")]
62 InvalidJwt(String),
63}
64
65#[derive(Clone, Debug, Error)]
72pub struct RuntimeTransport {
73 inner: Arc<RwLock<Option<InnerTransport>>>,
75 url: Url,
77 headers: Vec<String>,
79 jwt: Option<String>,
81 timeout: std::time::Duration,
83}
84
85#[derive(Debug)]
87pub struct RuntimeTransportBuilder {
88 url: Url,
89 headers: Vec<String>,
90 jwt: Option<String>,
91 timeout: std::time::Duration,
92}
93
94impl RuntimeTransportBuilder {
95 pub fn new(url: Url) -> Self {
97 Self { url, headers: vec![], jwt: None, timeout: REQUEST_TIMEOUT }
98 }
99
100 pub fn with_headers(mut self, headers: Vec<String>) -> Self {
102 self.headers = headers;
103 self
104 }
105
106 pub fn with_jwt(mut self, jwt: Option<String>) -> Self {
108 self.jwt = jwt;
109 self
110 }
111
112 pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
114 self.timeout = timeout;
115 self
116 }
117
118 pub fn build(self) -> RuntimeTransport {
121 RuntimeTransport {
122 inner: Arc::new(RwLock::new(None)),
123 url: self.url,
124 headers: self.headers,
125 jwt: self.jwt,
126 timeout: self.timeout,
127 }
128 }
129}
130
131impl fmt::Display for RuntimeTransport {
132 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133 write!(f, "RuntimeTransport {}", self.url)
134 }
135}
136
137impl RuntimeTransport {
138 pub async fn connect(&self) -> Result<InnerTransport, RuntimeTransportError> {
140 match self.url.scheme() {
141 "http" | "https" => self.connect_http(),
142 "ws" | "wss" => self.connect_ws().await,
143 "file" => self.connect_ipc().await,
144 _ => Err(RuntimeTransportError::BadScheme(self.url.scheme().to_string())),
145 }
146 }
147
148 pub fn reqwest_client(&self) -> Result<reqwest::Client, RuntimeTransportError> {
150 let mut client_builder = reqwest::Client::builder()
151 .timeout(self.timeout)
152 .tls_built_in_root_certs(self.url.scheme() == "https");
153 let mut headers = reqwest::header::HeaderMap::new();
154
155 if let Some(jwt) = self.jwt.clone() {
157 let auth =
158 build_auth(jwt).map_err(|e| RuntimeTransportError::InvalidJwt(e.to_string()))?;
159
160 let mut auth_value: HeaderValue =
161 HeaderValue::from_str(&auth.to_string()).expect("Header should be valid string");
162 auth_value.set_sensitive(true);
163
164 headers.insert(reqwest::header::AUTHORIZATION, auth_value);
165 };
166
167 for header in &self.headers {
169 let make_err = || RuntimeTransportError::BadHeader(header.to_string());
170
171 let (key, val) = header.split_once(':').ok_or_else(make_err)?;
172
173 headers.insert(
174 HeaderName::from_str(key.trim()).map_err(|_| make_err())?,
175 HeaderValue::from_str(val.trim()).map_err(|_| make_err())?,
176 );
177 }
178
179 if !headers.contains_key(reqwest::header::USER_AGENT) {
180 headers.insert(
181 reqwest::header::USER_AGENT,
182 HeaderValue::from_str(DEFAULT_USER_AGENT)
183 .expect("User-Agent should be valid string"),
184 );
185 }
186
187 client_builder = client_builder.default_headers(headers);
188
189 Ok(client_builder.build()?)
190 }
191
192 fn connect_http(&self) -> Result<InnerTransport, RuntimeTransportError> {
194 let client = self.reqwest_client()?;
195 Ok(InnerTransport::Http(Http::with_client(client, self.url.clone())))
196 }
197
198 async fn connect_ws(&self) -> Result<InnerTransport, RuntimeTransportError> {
200 let auth = self.jwt.as_ref().and_then(|jwt| build_auth(jwt.clone()).ok());
201 let mut ws = WsConnect::new(self.url.to_string());
202 if let Some(auth) = auth {
203 ws = ws.with_auth(auth);
204 };
205 let service = ws
206 .into_service()
207 .await
208 .map_err(|e| RuntimeTransportError::TransportError(e, self.url.to_string()))?;
209 Ok(InnerTransport::Ws(service))
210 }
211
212 async fn connect_ipc(&self) -> Result<InnerTransport, RuntimeTransportError> {
214 let path = url_to_file_path(&self.url)
215 .map_err(|_| RuntimeTransportError::BadPath(self.url.to_string()))?;
216 let ipc_connector = IpcConnect::new(path.clone());
217 let ipc = ipc_connector.into_service().await.map_err(|e| {
218 RuntimeTransportError::TransportError(e, path.clone().display().to_string())
219 })?;
220 Ok(InnerTransport::Ipc(ipc))
221 }
222
223 pub fn request(&self, req: RequestPacket) -> TransportFut<'static> {
231 let this = self.clone();
232 Box::pin(async move {
233 let mut inner = this.inner.read().await;
234 if inner.is_none() {
235 drop(inner);
236 {
237 let mut inner_mut = this.inner.write().await;
238 if inner_mut.is_none() {
239 *inner_mut =
240 Some(this.connect().await.map_err(TransportErrorKind::custom)?);
241 }
242 }
243 inner = this.inner.read().await;
244 }
245
246 match inner.clone().expect("must've been initialized") {
248 InnerTransport::Http(mut http) => http.call(req),
249 InnerTransport::Ws(mut ws) => ws.call(req),
250 InnerTransport::Ipc(mut ipc) => ipc.call(req),
251 }
252 .await
253 })
254 }
255
256 pub fn boxed(self) -> BoxTransport
258 where
259 Self: Sized + Clone + Send + Sync + 'static,
260 {
261 BoxTransport::new(self)
262 }
263}
264
265impl tower::Service<RequestPacket> for RuntimeTransport {
266 type Response = ResponsePacket;
267 type Error = TransportError;
268 type Future = TransportFut<'static>;
269
270 #[inline]
271 fn poll_ready(
272 &mut self,
273 _cx: &mut std::task::Context<'_>,
274 ) -> std::task::Poll<Result<(), Self::Error>> {
275 std::task::Poll::Ready(Ok(()))
276 }
277
278 #[inline]
279 fn call(&mut self, req: RequestPacket) -> Self::Future {
280 self.request(req)
281 }
282}
283
284impl tower::Service<RequestPacket> for &RuntimeTransport {
285 type Response = ResponsePacket;
286 type Error = TransportError;
287 type Future = TransportFut<'static>;
288
289 #[inline]
290 fn poll_ready(
291 &mut self,
292 _cx: &mut std::task::Context<'_>,
293 ) -> std::task::Poll<Result<(), Self::Error>> {
294 std::task::Poll::Ready(Ok(()))
295 }
296
297 #[inline]
298 fn call(&mut self, req: RequestPacket) -> Self::Future {
299 self.request(req)
300 }
301}
302
303fn build_auth(jwt: String) -> eyre::Result<Authorization> {
304 let secret = JwtSecret::from_hex(jwt)?;
306 let claims = Claims::default();
307 let token = secret.encode(&claims)?;
308
309 let auth = Authorization::Bearer(token);
310
311 Ok(auth)
312}
313
/// Converts an IPC URL into a filesystem path (Windows).
///
/// URLs of the form `file:///pipe/<name>` map to the named pipe `\\.\pipe\<name>`; anything
/// else goes through [`Url::to_file_path`].
#[cfg(windows)]
fn url_to_file_path(url: &Url) -> Result<PathBuf, ()> {
    const PREFIX: &str = "file:///pipe/";

    match url.as_str().strip_prefix(PREFIX) {
        Some(pipe_name) => Ok(PathBuf::from(format!(r"\\.\pipe\{pipe_name}"))),
        None => url.to_file_path(),
    }
}
327
328#[cfg(not(windows))]
329fn url_to_file_path(url: &Url) -> Result<PathBuf, ()> {
330 url.to_file_path()
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::HeaderMap;

    /// Spawns a local HTTP server that answers every `GET` with the request's `User-Agent`
    /// header (empty body if absent), returning its URL and the server task handle.
    async fn spawn_user_agent_echo() -> (Url, tokio::task::JoinHandle<()>) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = Url::parse(&format!("http://{}", listener.local_addr().unwrap())).unwrap();

        // Echo the header back instead of asserting inside the handler: a panic in the server
        // task would only surface as an opaque connection error on the client side.
        let handler = axum::routing::get(|headers: HeaderMap| async move {
            headers
                .get(reqwest::header::USER_AGENT)
                .and_then(|value| value.to_str().ok())
                .unwrap_or_default()
                .to_string()
        });

        let server_task = tokio::spawn(async move {
            axum::serve(listener, handler.into_make_service()).await.unwrap()
        });
        (url, server_task)
    }

    /// Sends a `GET` to `url` through the transport's HTTP client and returns the response body.
    async fn get_body(transport: &RuntimeTransport, url: Url) -> String {
        let InnerTransport::Http(http) = transport.connect_http().unwrap() else {
            unreachable!("an http URL must produce an HTTP transport")
        };
        http.client().get(url).send().await.unwrap().text().await.unwrap()
    }

    #[tokio::test]
    async fn test_user_agent_header() {
        let (url, server_task) = spawn_user_agent_echo().await;

        let transport = RuntimeTransportBuilder::new(url.clone())
            .with_headers(vec!["User-Agent: test-agent".to_string()])
            .build();

        assert_eq!(get_body(&transport, url).await, "test-agent");
        server_task.abort();
    }

    #[tokio::test]
    async fn test_default_user_agent_header() {
        let (url, server_task) = spawn_user_agent_echo().await;

        let transport = RuntimeTransportBuilder::new(url.clone()).build();

        assert_eq!(get_body(&transport, url).await, DEFAULT_USER_AGENT);
        server_task.abort();
    }

    #[test]
    fn test_malformed_headers_are_rejected() {
        let url = Url::parse("http://127.0.0.1:8545").unwrap();
        for header in ["missing-colon", "bad name: value"] {
            let transport = RuntimeTransportBuilder::new(url.clone())
                .with_headers(vec![header.to_string()])
                .build();
            match transport.reqwest_client() {
                Err(RuntimeTransportError::BadHeader(raw)) => assert_eq!(raw, header),
                other => panic!("expected BadHeader for {header:?}, got {other:?}"),
            }
        }
    }

    #[test]
    fn test_invalid_jwt_is_rejected() {
        let transport = RuntimeTransportBuilder::new(Url::parse("http://127.0.0.1:8545").unwrap())
            .with_jwt(Some("not-a-hex-secret".to_string()))
            .build();
        assert!(matches!(transport.reqwest_client(), Err(RuntimeTransportError::InvalidJwt(_))));
    }

    #[tokio::test]
    async fn test_unsupported_scheme_is_rejected() {
        let transport =
            RuntimeTransportBuilder::new(Url::parse("ftp://127.0.0.1").unwrap()).build();
        match transport.connect().await {
            Err(RuntimeTransportError::BadScheme(scheme)) => assert_eq!(scheme, "ftp"),
            other => panic!("expected BadScheme, got {other:?}"),
        }
    }
}