1use crate::{DEFAULT_USER_AGENT, REQUEST_TIMEOUT};
5use alloy_json_rpc::{RequestPacket, ResponsePacket};
6use alloy_pubsub::{PubSubConnect, PubSubFrontend};
7use alloy_rpc_types::engine::{Claims, JwtSecret};
8use alloy_transport::{
9 Authorization, BoxTransport, TransportError, TransportErrorKind, TransportFut,
10};
11use alloy_transport_http::Http;
12use alloy_transport_ipc::IpcConnect;
13use alloy_transport_ws::WsConnect;
14use reqwest::header::{HeaderName, HeaderValue};
15use std::{fmt, path::PathBuf, str::FromStr, sync::Arc};
16use thiserror::Error;
17use tokio::sync::RwLock;
18use tower::Service;
19use url::Url;
20
21#[derive(Clone, Debug)]
24pub enum InnerTransport {
25 Http(Http<reqwest::Client>),
27 Ws(PubSubFrontend),
29 Ipc(PubSubFrontend),
31}
32
33#[derive(Error, Debug)]
35pub enum RuntimeTransportError {
36 #[error("Internal transport error: {0} with {1}")]
38 TransportError(TransportError, String),
39
40 #[error("Failed to lock the transport")]
42 LockError,
43
44 #[error("URL scheme is not supported: {0}")]
46 BadScheme(String),
47
48 #[error("Invalid HTTP header: {0}")]
50 BadHeader(String),
51
52 #[error("Invalid IPC file path: {0}")]
54 BadPath(String),
55
56 #[error(transparent)]
58 HttpConstructionError(#[from] reqwest::Error),
59
60 #[error("Invalid JWT: {0}")]
62 InvalidJwt(String),
63}
64
65#[derive(Clone, Debug, Error)]
72pub struct RuntimeTransport {
73 inner: Arc<RwLock<Option<InnerTransport>>>,
75 url: Url,
77 headers: Vec<String>,
79 jwt: Option<String>,
81 timeout: std::time::Duration,
83}
84
85#[derive(Debug)]
87pub struct RuntimeTransportBuilder {
88 url: Url,
89 headers: Vec<String>,
90 jwt: Option<String>,
91 timeout: std::time::Duration,
92}
93
94impl RuntimeTransportBuilder {
95 pub fn new(url: Url) -> Self {
97 Self { url, headers: vec![], jwt: None, timeout: REQUEST_TIMEOUT }
98 }
99
100 pub fn with_headers(mut self, headers: Vec<String>) -> Self {
102 self.headers = headers;
103 self
104 }
105
106 pub fn with_jwt(mut self, jwt: Option<String>) -> Self {
108 self.jwt = jwt;
109 self
110 }
111
112 pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
114 self.timeout = timeout;
115 self
116 }
117
118 pub fn build(self) -> RuntimeTransport {
121 RuntimeTransport {
122 inner: Arc::new(RwLock::new(None)),
123 url: self.url,
124 headers: self.headers,
125 jwt: self.jwt,
126 timeout: self.timeout,
127 }
128 }
129}
130
131impl fmt::Display for RuntimeTransport {
132 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133 write!(f, "RuntimeTransport {}", self.url)
134 }
135}
136
137impl RuntimeTransport {
138 pub async fn connect(&self) -> Result<InnerTransport, RuntimeTransportError> {
140 match self.url.scheme() {
141 "http" | "https" => self.connect_http().await,
142 "ws" | "wss" => self.connect_ws().await,
143 "file" => self.connect_ipc().await,
144 _ => Err(RuntimeTransportError::BadScheme(self.url.scheme().to_string())),
145 }
146 }
147
148 pub fn reqwest_client(&self) -> Result<reqwest::Client, RuntimeTransportError> {
150 let mut client_builder = reqwest::Client::builder()
151 .timeout(self.timeout)
152 .tls_built_in_root_certs(self.url.scheme() == "https");
153 let mut headers = reqwest::header::HeaderMap::new();
154
155 if let Some(jwt) = self.jwt.clone() {
157 let auth =
158 build_auth(jwt).map_err(|e| RuntimeTransportError::InvalidJwt(e.to_string()))?;
159
160 let mut auth_value: HeaderValue =
161 HeaderValue::from_str(&auth.to_string()).expect("Header should be valid string");
162 auth_value.set_sensitive(true);
163
164 headers.insert(reqwest::header::AUTHORIZATION, auth_value);
165 };
166
167 for header in &self.headers {
169 let make_err = || RuntimeTransportError::BadHeader(header.to_string());
170
171 let (key, val) = header.split_once(':').ok_or_else(make_err)?;
172
173 headers.insert(
174 HeaderName::from_str(key.trim()).map_err(|_| make_err())?,
175 HeaderValue::from_str(val.trim()).map_err(|_| make_err())?,
176 );
177 }
178
179 if !headers.contains_key(reqwest::header::USER_AGENT) {
180 headers.insert(
181 reqwest::header::USER_AGENT,
182 HeaderValue::from_str(DEFAULT_USER_AGENT)
183 .expect("User-Agent should be valid string"),
184 );
185 }
186
187 client_builder = client_builder.default_headers(headers);
188
189 Ok(client_builder.build()?)
190 }
191
192 async fn connect_http(&self) -> Result<InnerTransport, RuntimeTransportError> {
194 let client = self.reqwest_client()?;
195 Ok(InnerTransport::Http(Http::with_client(client, self.url.clone())))
196 }
197
198 async fn connect_ws(&self) -> Result<InnerTransport, RuntimeTransportError> {
200 let auth = self.jwt.as_ref().and_then(|jwt| build_auth(jwt.clone()).ok());
201 let ws = WsConnect { url: self.url.to_string(), auth, config: None }
202 .into_service()
203 .await
204 .map_err(|e| RuntimeTransportError::TransportError(e, self.url.to_string()))?;
205 Ok(InnerTransport::Ws(ws))
206 }
207
208 async fn connect_ipc(&self) -> Result<InnerTransport, RuntimeTransportError> {
210 let path = url_to_file_path(&self.url)
211 .map_err(|_| RuntimeTransportError::BadPath(self.url.to_string()))?;
212 let ipc_connector = IpcConnect::new(path.clone());
213 let ipc = ipc_connector.into_service().await.map_err(|e| {
214 RuntimeTransportError::TransportError(e, path.clone().display().to_string())
215 })?;
216 Ok(InnerTransport::Ipc(ipc))
217 }
218
219 pub fn request(&self, req: RequestPacket) -> TransportFut<'static> {
227 let this = self.clone();
228 Box::pin(async move {
229 let mut inner = this.inner.read().await;
230 if inner.is_none() {
231 drop(inner);
232 {
233 let mut inner_mut = this.inner.write().await;
234 if inner_mut.is_none() {
235 *inner_mut =
236 Some(this.connect().await.map_err(TransportErrorKind::custom)?);
237 }
238 }
239 inner = this.inner.read().await;
240 }
241
242 match inner.clone().expect("must've been initialized") {
244 InnerTransport::Http(mut http) => http.call(req),
245 InnerTransport::Ws(mut ws) => ws.call(req),
246 InnerTransport::Ipc(mut ipc) => ipc.call(req),
247 }
248 .await
249 })
250 }
251
252 pub fn boxed(self) -> BoxTransport
254 where
255 Self: Sized + Clone + Send + Sync + 'static,
256 {
257 BoxTransport::new(self)
258 }
259}
260
261impl tower::Service<RequestPacket> for RuntimeTransport {
262 type Response = ResponsePacket;
263 type Error = TransportError;
264 type Future = TransportFut<'static>;
265
266 #[inline]
267 fn poll_ready(
268 &mut self,
269 _cx: &mut std::task::Context<'_>,
270 ) -> std::task::Poll<Result<(), Self::Error>> {
271 std::task::Poll::Ready(Ok(()))
272 }
273
274 #[inline]
275 fn call(&mut self, req: RequestPacket) -> Self::Future {
276 self.request(req)
277 }
278}
279
280impl tower::Service<RequestPacket> for &RuntimeTransport {
281 type Response = ResponsePacket;
282 type Error = TransportError;
283 type Future = TransportFut<'static>;
284
285 #[inline]
286 fn poll_ready(
287 &mut self,
288 _cx: &mut std::task::Context<'_>,
289 ) -> std::task::Poll<Result<(), Self::Error>> {
290 std::task::Poll::Ready(Ok(()))
291 }
292
293 #[inline]
294 fn call(&mut self, req: RequestPacket) -> Self::Future {
295 self.request(req)
296 }
297}
298
299fn build_auth(jwt: String) -> eyre::Result<Authorization> {
300 let secret = JwtSecret::from_hex(jwt)?;
302 let claims = Claims::default();
303 let token = secret.encode(&claims)?;
304
305 let auth = Authorization::Bearer(token);
306
307 Ok(auth)
308}
309
/// Converts a `file` URL into an IPC path (Windows).
///
/// URLs of the form `file:///pipe/<name>` map to the named pipe `\\.\pipe\<name>`. Any other
/// URL goes through [`Url::to_file_path`].
///
/// # Errors
///
/// Returns `Err(())` when [`Url::to_file_path`] rejects the URL.
#[cfg(windows)]
fn url_to_file_path(url: &Url) -> Result<PathBuf, ()> {
    const PREFIX: &str = "file:///pipe/";

    // `strip_prefix` checks and slices in one step (clippy::manual_strip).
    // NOTE(review): the pipe name is taken from the serialized URL verbatim, so any
    // percent-encoded characters are not decoded — confirm whether callers need that.
    if let Some(pipe_name) = url.as_str().strip_prefix(PREFIX) {
        return Ok(PathBuf::from(format!(r"\\.\pipe\{pipe_name}")));
    }

    url.to_file_path()
}
324
325#[cfg(not(windows))]
326fn url_to_file_path(url: &Url) -> Result<PathBuf, ()> {
327 url.to_file_path()
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::HeaderMap;

    /// A user-supplied `User-Agent` header must replace the default one on HTTP requests.
    #[tokio::test]
    async fn test_user_agent_header() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = Url::parse(&format!("http://{}", listener.local_addr().unwrap())).unwrap();

        // Echo the received `User-Agent` back in the response body. The assertion then runs in
        // this test task rather than inside the spawned server's handler, so a mismatch is
        // reported directly here.
        let http_handler = axum::routing::get(|actual_headers: HeaderMap| async move {
            actual_headers
                .get(reqwest::header::USER_AGENT)
                .and_then(|value| value.to_str().ok())
                .unwrap_or_default()
                .to_string()
        });

        let server_task = tokio::spawn(async move {
            axum::serve(listener, http_handler.into_make_service()).await.unwrap()
        });

        let transport = RuntimeTransportBuilder::new(url.clone())
            .with_headers(vec!["User-Agent: test-agent".to_string()])
            .build();
        let inner = transport.connect_http().await.unwrap();

        match inner {
            InnerTransport::Http(http) => {
                let body = http.client().get(url).send().await.unwrap().text().await.unwrap();
                assert_eq!(body, "test-agent");
            }
            _ => unreachable!(),
        }

        server_task.abort();
    }
}