use crate::{DEFAULT_USER_AGENT, REQUEST_TIMEOUT};
use alloy_json_rpc::{RequestPacket, ResponsePacket};
use alloy_pubsub::{PubSubConnect, PubSubFrontend};
use alloy_rpc_types::engine::{Claims, JwtSecret};
use alloy_transport::{
Authorization, BoxTransport, TransportError, TransportErrorKind, TransportFut,
};
use alloy_transport_http::Http;
use alloy_transport_ipc::IpcConnect;
use alloy_transport_ws::WsConnect;
use reqwest::header::{HeaderName, HeaderValue};
use std::{fmt, path::PathBuf, str::FromStr, sync::Arc};
use thiserror::Error;
use tokio::sync::RwLock;
use tower::Service;
use url::Url;
#[derive(Clone, Debug)]
pub enum InnerTransport {
Http(Http<reqwest::Client>),
Ws(PubSubFrontend),
Ipc(PubSubFrontend),
}
#[derive(Error, Debug)]
pub enum RuntimeTransportError {
#[error("Internal transport error: {0} with {1}")]
TransportError(TransportError, String),
#[error("Failed to lock the transport")]
LockError,
#[error("URL scheme is not supported: {0}")]
BadScheme(String),
#[error("Invalid HTTP header: {0}")]
BadHeader(String),
#[error("Invalid IPC file path: {0}")]
BadPath(String),
#[error(transparent)]
HttpConstructionError(#[from] reqwest::Error),
#[error("Invalid JWT: {0}")]
InvalidJwt(String),
}
#[derive(Clone, Debug, Error)]
pub struct RuntimeTransport {
inner: Arc<RwLock<Option<InnerTransport>>>,
url: Url,
headers: Vec<String>,
jwt: Option<String>,
timeout: std::time::Duration,
}
#[derive(Debug)]
pub struct RuntimeTransportBuilder {
url: Url,
headers: Vec<String>,
jwt: Option<String>,
timeout: std::time::Duration,
}
impl RuntimeTransportBuilder {
pub fn new(url: Url) -> Self {
Self { url, headers: vec![], jwt: None, timeout: REQUEST_TIMEOUT }
}
pub fn with_headers(mut self, headers: Vec<String>) -> Self {
self.headers = headers;
self
}
pub fn with_jwt(mut self, jwt: Option<String>) -> Self {
self.jwt = jwt;
self
}
pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
self.timeout = timeout;
self
}
pub fn build(self) -> RuntimeTransport {
RuntimeTransport {
inner: Arc::new(RwLock::new(None)),
url: self.url,
headers: self.headers,
jwt: self.jwt,
timeout: self.timeout,
}
}
}
impl fmt::Display for RuntimeTransport {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "RuntimeTransport {}", self.url)
}
}
impl RuntimeTransport {
pub async fn connect(&self) -> Result<InnerTransport, RuntimeTransportError> {
match self.url.scheme() {
"http" | "https" => self.connect_http().await,
"ws" | "wss" => self.connect_ws().await,
"file" => self.connect_ipc().await,
_ => Err(RuntimeTransportError::BadScheme(self.url.scheme().to_string())),
}
}
async fn connect_http(&self) -> Result<InnerTransport, RuntimeTransportError> {
let mut client_builder = reqwest::Client::builder()
.timeout(self.timeout)
.tls_built_in_root_certs(self.url.scheme() == "https");
let mut headers = reqwest::header::HeaderMap::new();
if let Some(jwt) = self.jwt.clone() {
let auth =
build_auth(jwt).map_err(|e| RuntimeTransportError::InvalidJwt(e.to_string()))?;
let mut auth_value: HeaderValue =
HeaderValue::from_str(&auth.to_string()).expect("Header should be valid string");
auth_value.set_sensitive(true);
headers.insert(reqwest::header::AUTHORIZATION, auth_value);
};
for header in self.headers.iter() {
let make_err = || RuntimeTransportError::BadHeader(header.to_string());
let (key, val) = header.split_once(':').ok_or_else(make_err)?;
headers.insert(
HeaderName::from_str(key.trim()).map_err(|_| make_err())?,
HeaderValue::from_str(val.trim()).map_err(|_| make_err())?,
);
}
if !headers.iter().any(|(k, _v)| k.as_str().starts_with("User-Agent:")) {
headers.insert(
reqwest::header::USER_AGENT,
HeaderValue::from_str(DEFAULT_USER_AGENT)
.expect("User-Agent should be valid string"),
);
}
client_builder = client_builder.default_headers(headers);
let client =
client_builder.build().map_err(RuntimeTransportError::HttpConstructionError)?;
Ok(InnerTransport::Http(Http::with_client(client, self.url.clone())))
}
async fn connect_ws(&self) -> Result<InnerTransport, RuntimeTransportError> {
let auth = self.jwt.as_ref().and_then(|jwt| build_auth(jwt.clone()).ok());
let ws = WsConnect { url: self.url.to_string(), auth, config: None }
.into_service()
.await
.map_err(|e| RuntimeTransportError::TransportError(e, self.url.to_string()))?;
Ok(InnerTransport::Ws(ws))
}
async fn connect_ipc(&self) -> Result<InnerTransport, RuntimeTransportError> {
let path = url_to_file_path(&self.url)
.map_err(|_| RuntimeTransportError::BadPath(self.url.to_string()))?;
let ipc_connector = IpcConnect::new(path.clone());
let ipc = ipc_connector.into_service().await.map_err(|e| {
RuntimeTransportError::TransportError(e, path.clone().display().to_string())
})?;
Ok(InnerTransport::Ipc(ipc))
}
pub fn request(&self, req: RequestPacket) -> TransportFut<'static> {
let this = self.clone();
Box::pin(async move {
let mut inner = this.inner.read().await;
if inner.is_none() {
drop(inner);
{
let mut inner_mut = this.inner.write().await;
if inner_mut.is_none() {
*inner_mut =
Some(this.connect().await.map_err(TransportErrorKind::custom)?);
}
}
inner = this.inner.read().await;
}
match inner.as_ref().expect("must've been initialized") {
InnerTransport::Http(http) => {
let mut http = http;
http.call(req)
}
InnerTransport::Ws(ws) => {
let mut ws = ws;
ws.call(req)
}
InnerTransport::Ipc(ipc) => {
let mut ipc = ipc;
ipc.call(req)
}
}
.await
})
}
pub fn boxed(self) -> BoxTransport
where
Self: Sized + Clone + Send + Sync + 'static,
{
BoxTransport::new(self)
}
}
impl tower::Service<RequestPacket> for RuntimeTransport {
type Response = ResponsePacket;
type Error = TransportError;
type Future = TransportFut<'static>;
#[inline]
fn poll_ready(
&mut self,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
std::task::Poll::Ready(Ok(()))
}
#[inline]
fn call(&mut self, req: RequestPacket) -> Self::Future {
self.request(req)
}
}
impl tower::Service<RequestPacket> for &RuntimeTransport {
type Response = ResponsePacket;
type Error = TransportError;
type Future = TransportFut<'static>;
#[inline]
fn poll_ready(
&mut self,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
std::task::Poll::Ready(Ok(()))
}
#[inline]
fn call(&mut self, req: RequestPacket) -> Self::Future {
self.request(req)
}
}
fn build_auth(jwt: String) -> eyre::Result<Authorization> {
let secret = JwtSecret::from_hex(jwt)?;
let claims = Claims::default();
let token = secret.encode(&claims)?;
let auth = Authorization::Bearer(token);
Ok(auth)
}
#[cfg(windows)]
fn url_to_file_path(url: &Url) -> Result<PathBuf, ()> {
const PREFIX: &str = "file:///pipe/";
let url_str = url.as_str();
if url_str.starts_with(PREFIX) {
let pipe_name = &url_str[PREFIX.len()..];
let pipe_path = format!(r"\\.\pipe\{}", pipe_name);
return Ok(PathBuf::from(pipe_path));
}
url.to_file_path()
}
#[cfg(not(windows))]
fn url_to_file_path(url: &Url) -> Result<PathBuf, ()> {
url.to_file_path()
}