1use crate::{DEFAULT_USER_AGENT, REQUEST_TIMEOUT};
6use alloy_json_rpc::{RequestPacket, ResponsePacket};
7use alloy_pubsub::{PubSubConnect, PubSubFrontend};
8use alloy_rpc_types::engine::{Claims, JwtSecret};
9use alloy_transport::{
10 Authorization, BoxTransport, TransportError, TransportErrorKind, TransportFut,
11};
12use alloy_transport_http::Http;
13use alloy_transport_ipc::IpcConnect;
14use alloy_transport_ws::WsConnect;
15use reqwest::header::{HeaderName, HeaderValue};
16use std::{fmt, path::PathBuf, str::FromStr, sync::Arc};
17use thiserror::Error;
18use tokio::sync::RwLock;
19use tower::Service;
20use url::Url;
21
/// The concrete transport a [`RuntimeTransport`] holds once a connection has been set up.
///
/// [`RuntimeTransport::connect`] picks the variant from the URL scheme.
#[derive(Clone, Debug)]
pub enum InnerTransport {
    /// HTTP transport driven by a `reqwest` client.
    Http(Http<reqwest::Client>),
    /// Pub-sub frontend obtained from the WebSocket connector.
    Ws(PubSubFrontend),
    /// Pub-sub frontend obtained from the IPC connector.
    Ipc(PubSubFrontend),
}
33
/// Errors that can occur while configuring or connecting a [`RuntimeTransport`].
#[derive(Error, Debug)]
pub enum RuntimeTransportError {
    /// The underlying transport failed to connect.
    ///
    /// The second field is the endpoint that was being connected to (the URL for WS, the
    /// resolved file path for IPC).
    #[error("Internal transport error: {0} with {1}")]
    TransportError(TransportError, String),

    /// The URL scheme is not one of `http`, `https`, `ws`, `wss` or `file`; carries the scheme.
    #[error("URL scheme is not supported: {0}")]
    BadScheme(String),

    /// A user-supplied header could not be split on `:` or parsed as a header name/value;
    /// carries the raw header string.
    #[error("Invalid HTTP header: {0}")]
    BadHeader(String),

    /// The URL could not be converted into a file path for IPC; carries the URL.
    #[error("Invalid IPC file path: {0}")]
    BadPath(String),

    /// Building the `reqwest` client failed.
    #[error(transparent)]
    HttpConstructionError(#[from] reqwest::Error),

    /// The JWT secret could not be turned into an authorization token; carries the error text.
    #[error("Invalid JWT: {0}")]
    InvalidJwt(String),
}
61
/// A transport whose concrete kind (HTTP, WS or IPC) is decided at runtime from its URL.
///
/// No connection is made on construction: the inner transport starts out as `None` and is
/// filled in by the first call to [`RuntimeTransport::request`]. Clones share the same inner
/// transport through the `Arc`.
#[derive(Clone, Debug)]
pub struct RuntimeTransport {
    /// The connected transport, or `None` until the first request connects it.
    inner: Arc<RwLock<Option<InnerTransport>>>,
    /// Endpoint to connect to; its scheme selects the transport kind.
    url: Url,
    /// Extra HTTP headers, each given as a single `name: value` string.
    headers: Vec<String>,
    /// Optional hex-encoded JWT secret used to build a bearer authorization token.
    jwt: Option<String>,
    /// Timeout applied to the `reqwest` client.
    timeout: std::time::Duration,
    /// Whether the `reqwest` client should accept invalid TLS certificates.
    accept_invalid_certs: bool,
}
84
/// Builder for [`RuntimeTransport`].
///
/// Collects the configuration and hands it over unchanged in
/// [`RuntimeTransportBuilder::build`].
#[derive(Debug)]
pub struct RuntimeTransportBuilder {
    /// Endpoint the transport will connect to.
    url: Url,
    /// Extra HTTP headers, each given as a single `name: value` string.
    headers: Vec<String>,
    /// Optional hex-encoded JWT secret.
    jwt: Option<String>,
    /// Request timeout; defaults to `REQUEST_TIMEOUT`.
    timeout: std::time::Duration,
    /// Whether invalid TLS certificates are accepted; defaults to `false`.
    accept_invalid_certs: bool,
}
94
95impl RuntimeTransportBuilder {
96 pub fn new(url: Url) -> Self {
98 Self {
99 url,
100 headers: vec![],
101 jwt: None,
102 timeout: REQUEST_TIMEOUT,
103 accept_invalid_certs: false,
104 }
105 }
106
107 pub fn with_headers(mut self, headers: Vec<String>) -> Self {
109 self.headers = headers;
110 self
111 }
112
113 pub fn with_jwt(mut self, jwt: Option<String>) -> Self {
115 self.jwt = jwt;
116 self
117 }
118
119 pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
121 self.timeout = timeout;
122 self
123 }
124
125 pub fn accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self {
127 self.accept_invalid_certs = accept_invalid_certs;
128 self
129 }
130
131 pub fn build(self) -> RuntimeTransport {
134 RuntimeTransport {
135 inner: Arc::new(RwLock::new(None)),
136 url: self.url,
137 headers: self.headers,
138 jwt: self.jwt,
139 timeout: self.timeout,
140 accept_invalid_certs: self.accept_invalid_certs,
141 }
142 }
143}
144
145impl fmt::Display for RuntimeTransport {
146 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
147 write!(f, "RuntimeTransport {}", self.url)
148 }
149}
150
151impl RuntimeTransport {
152 pub async fn connect(&self) -> Result<InnerTransport, RuntimeTransportError> {
154 match self.url.scheme() {
155 "http" | "https" => self.connect_http(),
156 "ws" | "wss" => self.connect_ws().await,
157 "file" => self.connect_ipc().await,
158 _ => Err(RuntimeTransportError::BadScheme(self.url.scheme().to_string())),
159 }
160 }
161
162 pub fn reqwest_client(&self) -> Result<reqwest::Client, RuntimeTransportError> {
164 let mut client_builder = reqwest::Client::builder()
165 .timeout(self.timeout)
166 .tls_built_in_root_certs(self.url.scheme() == "https")
167 .danger_accept_invalid_certs(self.accept_invalid_certs);
168 let mut headers = reqwest::header::HeaderMap::new();
169
170 if let Some(jwt) = self.jwt.clone() {
172 let auth =
173 build_auth(jwt).map_err(|e| RuntimeTransportError::InvalidJwt(e.to_string()))?;
174
175 let mut auth_value: HeaderValue =
176 HeaderValue::from_str(&auth.to_string()).expect("Header should be valid string");
177 auth_value.set_sensitive(true);
178
179 headers.insert(reqwest::header::AUTHORIZATION, auth_value);
180 };
181
182 for header in &self.headers {
184 let make_err = || RuntimeTransportError::BadHeader(header.to_string());
185
186 let (key, val) = header.split_once(':').ok_or_else(make_err)?;
187
188 headers.insert(
189 HeaderName::from_str(key.trim()).map_err(|_| make_err())?,
190 HeaderValue::from_str(val.trim()).map_err(|_| make_err())?,
191 );
192 }
193
194 if !headers.contains_key(reqwest::header::USER_AGENT) {
195 headers.insert(
196 reqwest::header::USER_AGENT,
197 HeaderValue::from_str(DEFAULT_USER_AGENT)
198 .expect("User-Agent should be valid string"),
199 );
200 }
201
202 client_builder = client_builder.default_headers(headers);
203
204 Ok(client_builder.build()?)
205 }
206
207 fn connect_http(&self) -> Result<InnerTransport, RuntimeTransportError> {
209 let client = self.reqwest_client()?;
210 Ok(InnerTransport::Http(Http::with_client(client, self.url.clone())))
211 }
212
213 async fn connect_ws(&self) -> Result<InnerTransport, RuntimeTransportError> {
215 let auth = self.jwt.as_ref().and_then(|jwt| build_auth(jwt.clone()).ok());
216 let mut ws = WsConnect::new(self.url.to_string());
217 if let Some(auth) = auth {
218 ws = ws.with_auth(auth);
219 };
220 let service = ws
221 .into_service()
222 .await
223 .map_err(|e| RuntimeTransportError::TransportError(e, self.url.to_string()))?;
224 Ok(InnerTransport::Ws(service))
225 }
226
227 async fn connect_ipc(&self) -> Result<InnerTransport, RuntimeTransportError> {
229 let path = url_to_file_path(&self.url)
230 .map_err(|_| RuntimeTransportError::BadPath(self.url.to_string()))?;
231 let ipc_connector = IpcConnect::new(path.clone());
232 let ipc = ipc_connector.into_service().await.map_err(|e| {
233 RuntimeTransportError::TransportError(e, path.clone().display().to_string())
234 })?;
235 Ok(InnerTransport::Ipc(ipc))
236 }
237
238 pub fn request(&self, req: RequestPacket) -> TransportFut<'static> {
246 let this = self.clone();
247 Box::pin(async move {
248 let mut inner = this.inner.read().await;
249 if inner.is_none() {
250 drop(inner);
251 {
252 let mut inner_mut = this.inner.write().await;
253 if inner_mut.is_none() {
254 *inner_mut =
255 Some(this.connect().await.map_err(TransportErrorKind::custom)?);
256 }
257 }
258 inner = this.inner.read().await;
259 }
260
261 match inner.clone().expect("must've been initialized") {
263 InnerTransport::Http(mut http) => http.call(req),
264 InnerTransport::Ws(mut ws) => ws.call(req),
265 InnerTransport::Ipc(mut ipc) => ipc.call(req),
266 }
267 .await
268 })
269 }
270
271 pub fn boxed(self) -> BoxTransport
273 where
274 Self: Sized + Clone + Send + Sync + 'static,
275 {
276 BoxTransport::new(self)
277 }
278}
279
280impl tower::Service<RequestPacket> for RuntimeTransport {
281 type Response = ResponsePacket;
282 type Error = TransportError;
283 type Future = TransportFut<'static>;
284
285 #[inline]
286 fn poll_ready(
287 &mut self,
288 _cx: &mut std::task::Context<'_>,
289 ) -> std::task::Poll<Result<(), Self::Error>> {
290 std::task::Poll::Ready(Ok(()))
291 }
292
293 #[inline]
294 fn call(&mut self, req: RequestPacket) -> Self::Future {
295 self.request(req)
296 }
297}
298
299impl tower::Service<RequestPacket> for &RuntimeTransport {
300 type Response = ResponsePacket;
301 type Error = TransportError;
302 type Future = TransportFut<'static>;
303
304 #[inline]
305 fn poll_ready(
306 &mut self,
307 _cx: &mut std::task::Context<'_>,
308 ) -> std::task::Poll<Result<(), Self::Error>> {
309 std::task::Poll::Ready(Ok(()))
310 }
311
312 #[inline]
313 fn call(&mut self, req: RequestPacket) -> Self::Future {
314 self.request(req)
315 }
316}
317
318fn build_auth(jwt: String) -> eyre::Result<Authorization> {
319 let secret = JwtSecret::from_hex(jwt)?;
321 let claims = Claims::default();
322 let token = secret.encode(&claims)?;
323
324 let auth = Authorization::Bearer(token);
325
326 Ok(auth)
327}
328
/// Converts a `file` URL into a path (Windows build).
///
/// URLs starting with `file:///pipe/` map to `\\.\pipe\<rest>`; everything else goes through
/// `Url::to_file_path`.
#[cfg(windows)]
fn url_to_file_path(url: &Url) -> Result<PathBuf, ()> {
    const PREFIX: &str = "file:///pipe/";

    match url.as_str().strip_prefix(PREFIX) {
        Some(pipe_name) => Ok(PathBuf::from(format!(r"\\.\pipe\{pipe_name}"))),
        None => url.to_file_path(),
    }
}
342
343#[cfg(not(windows))]
344fn url_to_file_path(url: &Url) -> Result<PathBuf, ()> {
345 url.to_file_path()
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::HeaderMap;

    /// A `User-Agent` passed through `with_headers` must reach the server as-is instead of
    /// being replaced by the default one.
    #[tokio::test]
    async fn test_user_agent_header() {
        // Bind to port 0 so the OS picks a free port.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let url = Url::parse(&format!("http://{}", local_addr)).unwrap();

        // The handler checks the header the server actually received.
        let http_handler = axum::routing::get(|actual_headers: HeaderMap| {
            let user_agent = HeaderName::from_str("User-Agent").unwrap();
            assert_eq!(actual_headers[user_agent], HeaderValue::from_str("test-agent").unwrap());

            async { "" }
        });

        let server_task = tokio::spawn(async move {
            axum::serve(listener, http_handler.into_make_service()).await.unwrap()
        });

        let custom_headers = vec!["User-Agent: test-agent".to_string()];
        let transport =
            RuntimeTransportBuilder::new(url.clone()).with_headers(custom_headers).build();

        if let InnerTransport::Http(http) = transport.connect_http().unwrap() {
            let _ = http.client().get(url).send().await.unwrap();
        } else {
            unreachable!()
        }

        server_task.abort();
    }
}