#![cfg_attr(not(test), warn(unused_crate_dependencies))]
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
#[macro_use]
extern crate tracing;
use anvil_rpc::{
error::RpcError,
request::RpcMethodCall,
response::{ResponseResult, RpcResponse},
};
use axum::{
extract::DefaultBodyLimit,
http::{header, HeaderValue, Method},
routing::{post, MethodRouter},
Router,
};
use serde::de::DeserializeOwned;
use std::fmt;
use tower_http::{cors::CorsLayer, trace::TraceLayer};
mod config;
pub use config::ServerConfig;
mod error;
mod handler;
mod pubsub;
pub use pubsub::{PubSubContext, PubSubRpcHandler};
mod ws;
#[cfg(feature = "ipc")]
pub mod ipc;
/// Builds a [`Router`] that serves both plain HTTP and the `ws` handler on `/`.
///
/// `POST` requests are dispatched to `handler::handle` using `http`, and `GET`
/// requests are dispatched to `ws::handle_ws` using `ws`. Both handlers share the
/// `(http, ws)` tuple as router state. CORS and body-limit behavior follow `config`.
pub fn http_ws_router<Http, Ws>(config: ServerConfig, http: Http, ws: Ws) -> Router
where
    Http: RpcHandler,
    Ws: PubSubRpcHandler,
{
    let root = post(handler::handle).get(ws::handle_ws);
    let shared_state = (http, ws);
    router_inner(config, root, shared_state)
}
/// Builds a [`Router`] that serves only HTTP `POST` requests on `/`.
///
/// Requests are dispatched to `handler::handle` with `(http, ())` as router state.
/// The unit slot stands in for the absent pub/sub handler. CORS and body-limit
/// behavior follow `config`.
pub fn http_router<Http>(config: ServerConfig, http: Http) -> Router
where
    Http: RpcHandler,
{
    let root = post(handler::handle);
    let shared_state = (http, ());
    router_inner(config, root, shared_state)
}
/// Assembles the [`Router`] shared by [`http_router`] and [`http_ws_router`].
///
/// The router is built in these steps:
/// - `root_method_router` is mounted at `/`, bound to `state`, and wrapped in a
///   [`TraceLayer`].
/// - Unless `config.no_cors` is set, a [`CorsLayer`] is added. It allows the
///   configured origin, the `Content-Type` header, and the `GET`/`POST` methods.
/// - When `config.no_request_size_limit` is set, axum's default request body
///   limit is disabled.
fn router_inner<S: Clone + Send + Sync + 'static>(
    config: ServerConfig,
    root_method_router: MethodRouter<S>,
    state: S,
) -> Router {
    let ServerConfig { allow_origin, no_cors, no_request_size_limit } = config;

    let traced = Router::new()
        .route("/", root_method_router)
        .with_state(state)
        .layer(TraceLayer::new_for_http());

    let cors_applied = if no_cors {
        traced
    } else {
        let cors = CorsLayer::new()
            .allow_origin(allow_origin.0)
            .allow_headers([header::CONTENT_TYPE])
            .allow_methods([Method::GET, Method::POST]);
        traced.layer(cors)
    };

    if no_request_size_limit {
        cors_applied.layer(DefaultBodyLimit::disable())
    } else {
        cors_applied
    }
}
/// Handles JSON-RPC method calls.
///
/// Implementors provide [`RpcHandler::on_request`]. The default
/// [`RpcHandler::on_call`] converts a raw [`RpcMethodCall`] into
/// [`RpcHandler::Request`] and maps deserialization failures to JSON-RPC errors.
#[async_trait::async_trait]
pub trait RpcHandler: Clone + Send + Sync + 'static {
    /// The typed request this handler understands.
    ///
    /// [`RpcHandler::on_call`] deserializes it from a JSON object of the form
    /// `{"method": <method>, "params": <params>}`.
    type Request: DeserializeOwned + Send + Sync + fmt::Debug;

    /// Processes a successfully deserialized request and returns its result.
    async fn on_request(&self, request: Self::Request) -> ResponseResult;

    /// Processes a raw method call and always produces a response carrying the call's `id`.
    ///
    /// - On successful deserialization, the request is forwarded to
    ///   [`RpcHandler::on_request`].
    /// - If the *method name itself* is rejected as an unknown enum variant, a
    ///   method-not-found error is returned.
    /// - Any other deserialization failure, including an unknown variant nested
    ///   inside `params`, yields an invalid-params error with the serde message.
    async fn on_call(&self, call: RpcMethodCall) -> RpcResponse {
        trace!(target: "rpc", id = ?call.id , method = ?call.method, params = ?call.params, "received method call");
        let RpcMethodCall { method, params, id, .. } = call;
        let params: serde_json::Value = params.into();
        let call = serde_json::json!({
            "method": &method,
            "params": params
        });

        match serde_json::from_value::<Self::Request>(call) {
            Ok(req) => {
                let result = self.on_request(req).await;
                RpcResponse::new(id, result)
            }
            Err(err) => {
                let err = err.to_string();
                // serde reports an unrecognized variant as
                // "unknown variant `<name>`, expected ...". A bare substring check for
                // "unknown variant" would also match enum arguments inside `params`,
                // misreporting bad params as an unsupported method. So only the rejected
                // method name counts as "method not found".
                let unknown_method = format!("unknown variant `{method}`");
                if err.contains(&unknown_method) {
                    error!(target: "rpc", ?method, "failed to deserialize method due to unknown variant");
                    RpcResponse::new(id, RpcError::method_not_found())
                } else {
                    error!(target: "rpc", ?method, ?err, "failed to deserialize method");
                    RpcResponse::new(id, RpcError::invalid_params(err))
                }
            }
        }
    }
}