1use crate::{DEFAULT_USER_AGENT, REQUEST_TIMEOUT};
6use alloy_json_rpc::{RequestPacket, ResponsePacket};
7use alloy_pubsub::{PubSubConnect, PubSubFrontend};
8use alloy_rpc_types::engine::{Claims, JwtSecret};
9use alloy_transport::{
10 Authorization, BoxTransport, TransportError, TransportErrorKind, TransportFut,
11};
12use alloy_transport_http::Http;
13use alloy_transport_ipc::IpcConnect;
14use alloy_transport_ws::WsConnect;
15use reqwest::header::{HeaderName, HeaderValue};
16use std::{fmt, path::PathBuf, str::FromStr, sync::Arc};
17use thiserror::Error;
18use tokio::sync::RwLock;
19use tower::Service;
20use url::Url;
21
22#[derive(Clone, Debug)]
25pub enum InnerTransport {
26 Http(Http<reqwest::Client>),
28 Ws(PubSubFrontend),
30 Ipc(PubSubFrontend),
32}
33
34#[derive(Error, Debug)]
36pub enum RuntimeTransportError {
37 #[error("Internal transport error: {0} with {1}")]
39 TransportError(TransportError, String),
40
41 #[error("URL scheme is not supported: {0}")]
43 BadScheme(String),
44
45 #[error("Invalid HTTP header: {0}")]
47 BadHeader(String),
48
49 #[error("Invalid IPC file path: {0}")]
51 BadPath(String),
52
53 #[error(transparent)]
55 HttpConstructionError(#[from] reqwest::Error),
56
57 #[error("Invalid JWT: {0}")]
59 InvalidJwt(String),
60}
61
62#[derive(Clone, Debug)]
70pub struct RuntimeTransport {
71 inner: Arc<RwLock<Option<InnerTransport>>>,
73 url: Url,
75 headers: Vec<String>,
77 jwt: Option<String>,
79 timeout: std::time::Duration,
81 accept_invalid_certs: bool,
83 no_proxy: bool,
85}
86
87#[derive(Debug)]
89pub struct RuntimeTransportBuilder {
90 url: Url,
91 headers: Vec<String>,
92 jwt: Option<String>,
93 timeout: std::time::Duration,
94 accept_invalid_certs: bool,
95 no_proxy: bool,
96}
97
98impl RuntimeTransportBuilder {
99 pub fn new(url: Url) -> Self {
101 Self {
102 url,
103 headers: vec![],
104 jwt: None,
105 timeout: REQUEST_TIMEOUT,
106 accept_invalid_certs: false,
107 no_proxy: false,
108 }
109 }
110
111 pub fn with_headers(mut self, headers: Vec<String>) -> Self {
113 self.headers = headers;
114 self
115 }
116
117 pub fn with_jwt(mut self, jwt: Option<String>) -> Self {
119 self.jwt = jwt;
120 self
121 }
122
123 pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
125 self.timeout = timeout;
126 self
127 }
128
129 pub fn accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self {
131 self.accept_invalid_certs = accept_invalid_certs;
132 self
133 }
134
135 pub fn no_proxy(mut self, no_proxy: bool) -> Self {
140 self.no_proxy = no_proxy;
141 self
142 }
143
144 pub fn build(self) -> RuntimeTransport {
147 RuntimeTransport {
148 inner: Arc::new(RwLock::new(None)),
149 url: self.url,
150 headers: self.headers,
151 jwt: self.jwt,
152 timeout: self.timeout,
153 accept_invalid_certs: self.accept_invalid_certs,
154 no_proxy: self.no_proxy,
155 }
156 }
157}
158
159impl fmt::Display for RuntimeTransport {
160 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161 write!(f, "RuntimeTransport {}", self.url)
162 }
163}
164
165impl RuntimeTransport {
166 pub async fn connect(&self) -> Result<InnerTransport, RuntimeTransportError> {
168 match self.url.scheme() {
169 "http" | "https" => self.connect_http(),
170 "ws" | "wss" => self.connect_ws().await,
171 "file" => self.connect_ipc().await,
172 _ => Err(RuntimeTransportError::BadScheme(self.url.scheme().to_string())),
173 }
174 }
175
176 pub fn reqwest_client(&self) -> Result<reqwest::Client, RuntimeTransportError> {
178 let mut client_builder = reqwest::Client::builder()
179 .timeout(self.timeout)
180 .tls_built_in_root_certs(self.url.scheme() == "https")
181 .danger_accept_invalid_certs(self.accept_invalid_certs);
182
183 if self.no_proxy {
187 client_builder = client_builder.no_proxy();
188 }
189
190 let mut headers = reqwest::header::HeaderMap::new();
191
192 if let Some(jwt) = self.jwt.clone() {
194 let auth =
195 build_auth(jwt).map_err(|e| RuntimeTransportError::InvalidJwt(e.to_string()))?;
196
197 let mut auth_value: HeaderValue =
198 HeaderValue::from_str(&auth.to_string()).expect("Header should be valid string");
199 auth_value.set_sensitive(true);
200
201 headers.insert(reqwest::header::AUTHORIZATION, auth_value);
202 };
203
204 for header in &self.headers {
206 let make_err = || RuntimeTransportError::BadHeader(header.to_string());
207
208 let (key, val) = header.split_once(':').ok_or_else(make_err)?;
209
210 headers.insert(
211 HeaderName::from_str(key.trim()).map_err(|_| make_err())?,
212 HeaderValue::from_str(val.trim()).map_err(|_| make_err())?,
213 );
214 }
215
216 if !headers.contains_key(reqwest::header::USER_AGENT) {
217 headers.insert(
218 reqwest::header::USER_AGENT,
219 HeaderValue::from_str(DEFAULT_USER_AGENT)
220 .expect("User-Agent should be valid string"),
221 );
222 }
223
224 client_builder = client_builder.default_headers(headers);
225
226 Ok(client_builder.build()?)
227 }
228
229 fn connect_http(&self) -> Result<InnerTransport, RuntimeTransportError> {
231 let client = self.reqwest_client()?;
232 Ok(InnerTransport::Http(Http::with_client(client, self.url.clone())))
233 }
234
235 async fn connect_ws(&self) -> Result<InnerTransport, RuntimeTransportError> {
237 let auth = self.jwt.as_ref().and_then(|jwt| build_auth(jwt.clone()).ok());
238 let mut ws = WsConnect::new(self.url.to_string());
239 if let Some(auth) = auth {
240 ws = ws.with_auth(auth);
241 };
242 let service = ws
243 .into_service()
244 .await
245 .map_err(|e| RuntimeTransportError::TransportError(e, self.url.to_string()))?;
246 Ok(InnerTransport::Ws(service))
247 }
248
249 async fn connect_ipc(&self) -> Result<InnerTransport, RuntimeTransportError> {
251 let path = url_to_file_path(&self.url)
252 .map_err(|_| RuntimeTransportError::BadPath(self.url.to_string()))?;
253 let ipc_connector = IpcConnect::new(path.clone());
254 let ipc = ipc_connector.into_service().await.map_err(|e| {
255 RuntimeTransportError::TransportError(e, path.clone().display().to_string())
256 })?;
257 Ok(InnerTransport::Ipc(ipc))
258 }
259
260 pub fn request(&self, req: RequestPacket) -> TransportFut<'static> {
268 let this = self.clone();
269 Box::pin(async move {
270 let mut inner = this.inner.read().await;
271 if inner.is_none() {
272 drop(inner);
273 {
274 let mut inner_mut = this.inner.write().await;
275 if inner_mut.is_none() {
276 *inner_mut =
277 Some(this.connect().await.map_err(TransportErrorKind::custom)?);
278 }
279 }
280 inner = this.inner.read().await;
281 }
282
283 match inner.clone().expect("must've been initialized") {
285 InnerTransport::Http(mut http) => http.call(req),
286 InnerTransport::Ws(mut ws) => ws.call(req),
287 InnerTransport::Ipc(mut ipc) => ipc.call(req),
288 }
289 .await
290 })
291 }
292
293 pub fn boxed(self) -> BoxTransport
295 where
296 Self: Sized + Clone + Send + Sync + 'static,
297 {
298 BoxTransport::new(self)
299 }
300}
301
302impl tower::Service<RequestPacket> for RuntimeTransport {
303 type Response = ResponsePacket;
304 type Error = TransportError;
305 type Future = TransportFut<'static>;
306
307 #[inline]
308 fn poll_ready(
309 &mut self,
310 _cx: &mut std::task::Context<'_>,
311 ) -> std::task::Poll<Result<(), Self::Error>> {
312 std::task::Poll::Ready(Ok(()))
313 }
314
315 #[inline]
316 fn call(&mut self, req: RequestPacket) -> Self::Future {
317 self.request(req)
318 }
319}
320
321impl tower::Service<RequestPacket> for &RuntimeTransport {
322 type Response = ResponsePacket;
323 type Error = TransportError;
324 type Future = TransportFut<'static>;
325
326 #[inline]
327 fn poll_ready(
328 &mut self,
329 _cx: &mut std::task::Context<'_>,
330 ) -> std::task::Poll<Result<(), Self::Error>> {
331 std::task::Poll::Ready(Ok(()))
332 }
333
334 #[inline]
335 fn call(&mut self, req: RequestPacket) -> Self::Future {
336 self.request(req)
337 }
338}
339
340fn build_auth(jwt: String) -> eyre::Result<Authorization> {
341 let secret = JwtSecret::from_hex(jwt)?;
343 let claims = Claims::default();
344 let token = secret.encode(&claims)?;
345
346 let auth = Authorization::Bearer(token);
347
348 Ok(auth)
349}
350
/// Converts a `file://` URL into the path used for IPC.
///
/// `file:///pipe/<name>` maps to the named pipe `\\.\pipe\<name>`. The name is taken verbatim
/// from the serialized URL and is not percent-decoded. Every other URL goes through
/// [`Url::to_file_path`].
#[cfg(windows)]
fn url_to_file_path(url: &Url) -> Result<PathBuf, ()> {
    const PREFIX: &str = "file:///pipe/";

    match url.as_str().strip_prefix(PREFIX) {
        Some(pipe_name) => Ok(PathBuf::from(format!(r"\\.\pipe\{pipe_name}"))),
        None => url.to_file_path(),
    }
}
364
365#[cfg(not(windows))]
366fn url_to_file_path(url: &Url) -> Result<PathBuf, ()> {
367 url.to_file_path()
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::HeaderMap;

    /// Sends one GET request through an HTTP [`RuntimeTransport`] configured with `headers`.
    ///
    /// The request goes to a local server that echoes back the `User-Agent` it received, and
    /// this helper returns that value. Echoing keeps the assertion in the test task, so a
    /// mismatch fails the test directly rather than showing up as a dropped server connection.
    async fn echoed_user_agent(headers: Vec<String>) -> String {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = Url::parse(&format!("http://{}", listener.local_addr().unwrap())).unwrap();

        let http_handler = axum::routing::get(|actual_headers: HeaderMap| async move {
            actual_headers
                .get(reqwest::header::USER_AGENT)
                .and_then(|value| value.to_str().ok())
                .unwrap_or_default()
                .to_string()
        });

        let server_task = tokio::spawn(async move {
            axum::serve(listener, http_handler.into_make_service()).await.unwrap()
        });

        let transport = RuntimeTransportBuilder::new(url.clone()).with_headers(headers).build();
        let InnerTransport::Http(http) = transport.connect_http().unwrap() else {
            unreachable!("an http:// URL must produce an HTTP transport")
        };

        let body = http.client().get(url).send().await.unwrap().text().await.unwrap();
        server_task.abort();
        body
    }

    #[tokio::test]
    async fn test_user_agent_header() {
        let user_agent = echoed_user_agent(vec!["User-Agent: test-agent".to_string()]).await;
        assert_eq!(user_agent, "test-agent");
    }

    #[tokio::test]
    async fn test_default_user_agent_header() {
        let user_agent = echoed_user_agent(vec![]).await;
        assert_eq!(user_agent, DEFAULT_USER_AGENT);
    }
}