1use crate::{DEFAULT_USER_AGENT, REQUEST_TIMEOUT};
6use alloy_json_rpc::{RequestPacket, ResponsePacket};
7use alloy_pubsub::{PubSubConnect, PubSubFrontend};
8use alloy_rpc_types::engine::{Claims, JwtSecret};
9use alloy_transport::{
10 Authorization, BoxTransport, TransportError, TransportErrorKind, TransportFut,
11};
12use alloy_transport_http::Http;
13use alloy_transport_ipc::IpcConnect;
14use alloy_transport_ws::WsConnect;
15use reqwest::header::{HeaderName, HeaderValue};
16use std::{fmt, path::PathBuf, str::FromStr, sync::Arc};
17use thiserror::Error;
18use tokio::sync::RwLock;
19use tower::Service;
20use url::Url;
21
/// The concrete transport a [`RuntimeTransport`] connects with, selected from the URL scheme.
#[derive(Clone, Debug)]
pub enum InnerTransport {
    /// HTTP(S) transport backed by a `reqwest` client; used for `http`/`https` URLs.
    Http(Http<reqwest::Client>),
    /// WebSocket pub/sub frontend; used for `ws`/`wss` URLs.
    Ws(PubSubFrontend),
    /// IPC pub/sub frontend; used for `file` URLs.
    Ipc(PubSubFrontend),
}
33
/// Errors that can occur while configuring or connecting a [`RuntimeTransport`].
#[derive(Error, Debug)]
pub enum RuntimeTransportError {
    /// An underlying connector failed. In this module it is produced by the WebSocket and IPC
    /// connectors; the `String` identifies the endpoint (the URL for WebSocket, the file path
    /// for IPC).
    #[error("Internal transport error: {0} with {1}")]
    TransportError(TransportError, String),

    /// The URL scheme is not one of `http`, `https`, `ws`, `wss` or `file`.
    #[error("URL scheme is not supported: {0}")]
    BadScheme(String),

    /// A user-supplied header is not of the form `Name: value`, or its name or value is not a
    /// valid HTTP header name/value. Carries the raw header string.
    #[error("Invalid HTTP header: {0}")]
    BadHeader(String),

    /// A `file` URL could not be converted to a local file path. Carries the URL.
    #[error("Invalid IPC file path: {0}")]
    BadPath(String),

    /// `reqwest` failed to build the HTTP client.
    #[error(transparent)]
    HttpConstructionError(#[from] reqwest::Error),

    /// The configured JWT secret could not be turned into a bearer token.
    #[error("Invalid JWT: {0}")]
    InvalidJwt(String),
}
61
/// A transport that chooses HTTP, WebSocket or IPC from its URL scheme and connects lazily on
/// the first request.
///
/// Clones share the same connection slot (`inner` lives behind an `Arc`), so a connection
/// established through one clone is reused by all of them.
#[derive(Clone, Debug)]
pub struct RuntimeTransport {
    /// Lazily-initialised connection; `None` until a request successfully connects.
    inner: Arc<RwLock<Option<InnerTransport>>>,
    /// Endpoint URL; its scheme selects the transport kind.
    url: Url,
    /// Extra headers as raw `Name: value` strings; only consulted when building the HTTP client.
    headers: Vec<String>,
    /// Hex-encoded JWT secret used to mint a bearer token, if any.
    jwt: Option<String>,
    /// Request timeout applied to the HTTP client.
    timeout: std::time::Duration,
    /// Whether the HTTP client accepts invalid TLS certificates.
    accept_invalid_certs: bool,
    /// Whether the HTTP client is built with all proxies disabled.
    no_proxy: bool,
}
86
/// Builder for [`RuntimeTransport`].
#[derive(Debug)]
pub struct RuntimeTransportBuilder {
    /// Endpoint URL; its scheme selects the transport kind.
    url: Url,
    /// Extra HTTP headers as raw `Name: value` strings (default: none).
    headers: Vec<String>,
    /// Hex-encoded JWT secret (default: none).
    jwt: Option<String>,
    /// HTTP request timeout (default: `REQUEST_TIMEOUT`).
    timeout: std::time::Duration,
    /// Accept invalid TLS certificates on the HTTP client (default: `false`).
    accept_invalid_certs: bool,
    /// Disable all proxies on the HTTP client (default: `false`).
    no_proxy: bool,
}
97
98impl RuntimeTransportBuilder {
99 pub fn new(url: Url) -> Self {
101 Self {
102 url,
103 headers: vec![],
104 jwt: None,
105 timeout: REQUEST_TIMEOUT,
106 accept_invalid_certs: false,
107 no_proxy: false,
108 }
109 }
110
111 pub fn with_headers(mut self, headers: Vec<String>) -> Self {
113 self.headers = headers;
114 self
115 }
116
117 pub fn with_jwt(mut self, jwt: Option<String>) -> Self {
119 self.jwt = jwt;
120 self
121 }
122
123 pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
125 self.timeout = timeout;
126 self
127 }
128
129 pub fn accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self {
131 self.accept_invalid_certs = accept_invalid_certs;
132 self
133 }
134
135 pub fn no_proxy(mut self, no_proxy: bool) -> Self {
140 self.no_proxy = no_proxy;
141 self
142 }
143
144 pub fn build(self) -> RuntimeTransport {
147 RuntimeTransport {
148 inner: Arc::new(RwLock::new(None)),
149 url: self.url,
150 headers: self.headers,
151 jwt: self.jwt,
152 timeout: self.timeout,
153 accept_invalid_certs: self.accept_invalid_certs,
154 no_proxy: self.no_proxy,
155 }
156 }
157}
158
159impl fmt::Display for RuntimeTransport {
160 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161 write!(f, "RuntimeTransport {}", self.url)
162 }
163}
164
165impl RuntimeTransport {
166 pub async fn connect(&self) -> Result<InnerTransport, RuntimeTransportError> {
168 match self.url.scheme() {
169 "http" | "https" => self.connect_http(),
170 "ws" | "wss" => self.connect_ws().await,
171 "file" => self.connect_ipc().await,
172 _ => Err(RuntimeTransportError::BadScheme(self.url.scheme().to_string())),
173 }
174 }
175
176 pub fn reqwest_client(&self) -> Result<reqwest::Client, RuntimeTransportError> {
178 let mut client_builder = reqwest::Client::builder()
179 .timeout(self.timeout)
180 .danger_accept_invalid_certs(self.accept_invalid_certs);
181
182 if self.no_proxy {
186 client_builder = client_builder.no_proxy();
187 }
188
189 let mut headers = reqwest::header::HeaderMap::new();
190
191 if let Some(jwt) = self.jwt.clone() {
193 let auth =
194 build_auth(jwt).map_err(|e| RuntimeTransportError::InvalidJwt(e.to_string()))?;
195
196 let mut auth_value: HeaderValue =
197 HeaderValue::from_str(&auth.to_string()).expect("Header should be valid string");
198 auth_value.set_sensitive(true);
199
200 headers.insert(reqwest::header::AUTHORIZATION, auth_value);
201 };
202
203 for header in &self.headers {
205 let make_err = || RuntimeTransportError::BadHeader(header.to_string());
206
207 let (key, val) = header.split_once(':').ok_or_else(make_err)?;
208
209 headers.insert(
210 HeaderName::from_str(key.trim()).map_err(|_| make_err())?,
211 HeaderValue::from_str(val.trim()).map_err(|_| make_err())?,
212 );
213 }
214
215 if !headers.contains_key(reqwest::header::USER_AGENT) {
216 headers.insert(
217 reqwest::header::USER_AGENT,
218 HeaderValue::from_str(DEFAULT_USER_AGENT)
219 .expect("User-Agent should be valid string"),
220 );
221 }
222
223 client_builder = client_builder.default_headers(headers);
224
225 Ok(client_builder.build()?)
226 }
227
228 fn connect_http(&self) -> Result<InnerTransport, RuntimeTransportError> {
230 let client = self.reqwest_client()?;
231 Ok(InnerTransport::Http(Http::with_client(client, self.url.clone())))
232 }
233
234 async fn connect_ws(&self) -> Result<InnerTransport, RuntimeTransportError> {
236 let auth = self.jwt.as_ref().and_then(|jwt| build_auth(jwt.clone()).ok());
237 let mut ws = WsConnect::new(self.url.to_string());
238 if let Some(auth) = auth {
239 ws = ws.with_auth(auth);
240 };
241 let service = ws
242 .into_service()
243 .await
244 .map_err(|e| RuntimeTransportError::TransportError(e, self.url.to_string()))?;
245 Ok(InnerTransport::Ws(service))
246 }
247
248 async fn connect_ipc(&self) -> Result<InnerTransport, RuntimeTransportError> {
250 let path = url_to_file_path(&self.url)
251 .map_err(|_| RuntimeTransportError::BadPath(self.url.to_string()))?;
252 let ipc_connector = IpcConnect::new(path.clone());
253 let ipc = ipc_connector.into_service().await.map_err(|e| {
254 RuntimeTransportError::TransportError(e, path.clone().display().to_string())
255 })?;
256 Ok(InnerTransport::Ipc(ipc))
257 }
258
259 pub fn request(&self, req: RequestPacket) -> TransportFut<'static> {
267 let this = self.clone();
268 Box::pin(async move {
269 let mut inner = this.inner.read().await;
270 if inner.is_none() {
271 drop(inner);
272 {
273 let mut inner_mut = this.inner.write().await;
274 if inner_mut.is_none() {
275 *inner_mut =
276 Some(this.connect().await.map_err(TransportErrorKind::custom)?);
277 }
278 }
279 inner = this.inner.read().await;
280 }
281
282 match inner.clone().expect("must've been initialized") {
284 InnerTransport::Http(mut http) => http.call(req),
285 InnerTransport::Ws(mut ws) => ws.call(req),
286 InnerTransport::Ipc(mut ipc) => ipc.call(req),
287 }
288 .await
289 })
290 }
291
292 pub fn boxed(self) -> BoxTransport
294 where
295 Self: Sized + Clone + Send + Sync + 'static,
296 {
297 BoxTransport::new(self)
298 }
299}
300
301impl tower::Service<RequestPacket> for RuntimeTransport {
302 type Response = ResponsePacket;
303 type Error = TransportError;
304 type Future = TransportFut<'static>;
305
306 #[inline]
307 fn poll_ready(
308 &mut self,
309 _cx: &mut std::task::Context<'_>,
310 ) -> std::task::Poll<Result<(), Self::Error>> {
311 std::task::Poll::Ready(Ok(()))
312 }
313
314 #[inline]
315 fn call(&mut self, req: RequestPacket) -> Self::Future {
316 self.request(req)
317 }
318}
319
320impl tower::Service<RequestPacket> for &RuntimeTransport {
321 type Response = ResponsePacket;
322 type Error = TransportError;
323 type Future = TransportFut<'static>;
324
325 #[inline]
326 fn poll_ready(
327 &mut self,
328 _cx: &mut std::task::Context<'_>,
329 ) -> std::task::Poll<Result<(), Self::Error>> {
330 std::task::Poll::Ready(Ok(()))
331 }
332
333 #[inline]
334 fn call(&mut self, req: RequestPacket) -> Self::Future {
335 self.request(req)
336 }
337}
338
339fn build_auth(jwt: String) -> eyre::Result<Authorization> {
340 let secret = JwtSecret::from_hex(jwt)?;
342 let claims = Claims::default();
343 let token = secret.encode(&claims)?;
344
345 let auth = Authorization::Bearer(token);
346
347 Ok(auth)
348}
349
/// Converts a `file` URL into an IPC path.
///
/// On Windows, URLs beginning with `file:///pipe/` map to the named pipe `\\.\pipe\<name>`.
/// Any other URL goes through `Url::to_file_path`.
#[cfg(windows)]
fn url_to_file_path(url: &Url) -> Result<PathBuf, ()> {
    const PREFIX: &str = "file:///pipe/";

    match url.as_str().strip_prefix(PREFIX) {
        Some(pipe_name) => Ok(PathBuf::from(format!(r"\\.\pipe\{pipe_name}"))),
        None => url.to_file_path(),
    }
}
363
364#[cfg(not(windows))]
365fn url_to_file_path(url: &Url) -> Result<PathBuf, ()> {
366 url.to_file_path()
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::HeaderMap;

    /// Spawns an HTTP server on an ephemeral local port that answers every `GET` with the
    /// request's `User-Agent` header (an empty body if the header is absent).
    ///
    /// Echoing the header lets tests assert on the client side. A mismatch then shows up as a
    /// normal assertion failure, not as a panic inside the server's connection task that the
    /// client only sees as a dropped connection.
    async fn spawn_user_agent_echo() -> (Url, tokio::task::JoinHandle<()>) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = Url::parse(&format!("http://{}", listener.local_addr().unwrap())).unwrap();

        let http_handler = axum::routing::get(|headers: HeaderMap| async move {
            headers
                .get(reqwest::header::USER_AGENT)
                .and_then(|value| value.to_str().ok())
                .unwrap_or_default()
                .to_string()
        });

        let server_task = tokio::spawn(async move {
            axum::serve(listener, http_handler.into_make_service()).await.unwrap()
        });

        (url, server_task)
    }

    /// Sends a `GET` through the transport's HTTP client and returns the echoed `User-Agent`.
    async fn echoed_user_agent(transport: &RuntimeTransport, url: Url) -> String {
        let InnerTransport::Http(http) = transport.connect_http().unwrap() else {
            panic!("expected an HTTP transport for an http:// URL");
        };
        http.client().get(url).send().await.unwrap().text().await.unwrap()
    }

    #[tokio::test]
    async fn test_user_agent_header() {
        let (url, server_task) = spawn_user_agent_echo().await;

        let transport = RuntimeTransportBuilder::new(url.clone())
            .with_headers(vec!["User-Agent: test-agent".to_string()])
            .build();

        assert_eq!(echoed_user_agent(&transport, url).await, "test-agent");

        server_task.abort();
    }

    #[tokio::test]
    async fn test_default_user_agent_header() {
        let (url, server_task) = spawn_user_agent_echo().await;

        let transport = RuntimeTransportBuilder::new(url.clone()).build();

        assert_eq!(echoed_user_agent(&transport, url).await, DEFAULT_USER_AGENT);

        server_task.abort();
    }

    #[test]
    fn test_malformed_header_is_rejected() {
        let url = Url::parse("http://127.0.0.1:1").unwrap();
        let transport =
            RuntimeTransportBuilder::new(url).with_headers(vec!["no-colon".to_string()]).build();

        let err = transport.reqwest_client().unwrap_err();
        assert!(matches!(err, RuntimeTransportError::BadHeader(ref h) if h == "no-colon"));
    }

    #[tokio::test]
    async fn test_unsupported_scheme_is_rejected() {
        let url = Url::parse("ftp://127.0.0.1").unwrap();

        let err = RuntimeTransportBuilder::new(url).build().connect().await.unwrap_err();
        assert!(matches!(err, RuntimeTransportError::BadScheme(ref s) if s == "ftp"));
    }
}