anvil_server/
lib.rs

1//! Bootstrap [axum] RPC servers.
2
3#![cfg_attr(not(test), warn(unused_crate_dependencies))]
4#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
5
6#[macro_use]
7extern crate tracing;
8
9use anvil_rpc::{
10    error::RpcError,
11    request::RpcMethodCall,
12    response::{ResponseResult, RpcResponse},
13};
14use axum::{
15    extract::DefaultBodyLimit,
16    http::{header, HeaderValue, Method},
17    routing::{post, MethodRouter},
18    Router,
19};
20use serde::de::DeserializeOwned;
21use std::fmt;
22use tower_http::{cors::CorsLayer, trace::TraceLayer};
23
24mod config;
25pub use config::ServerConfig;
26
27mod error;
28mod handler;
29
30mod pubsub;
31pub use pubsub::{PubSubContext, PubSubRpcHandler};
32
33mod ws;
34
35#[cfg(feature = "ipc")]
36pub mod ipc;
37
38/// Configures an [`axum::Router`] that handles JSON-RPC calls via both HTTP and WS.
39pub fn http_ws_router<Http, Ws>(config: ServerConfig, http: Http, ws: Ws) -> Router
40where
41    Http: RpcHandler,
42    Ws: PubSubRpcHandler,
43{
44    router_inner(config, post(handler::handle).get(ws::handle_ws), (http, ws))
45}
46
47/// Configures an [`axum::Router`] that handles JSON-RPC calls via HTTP.
48pub fn http_router<Http>(config: ServerConfig, http: Http) -> Router
49where
50    Http: RpcHandler,
51{
52    router_inner(config, post(handler::handle), (http, ()))
53}
54
55fn router_inner<S: Clone + Send + Sync + 'static>(
56    config: ServerConfig,
57    root_method_router: MethodRouter<S>,
58    state: S,
59) -> Router {
60    let ServerConfig { allow_origin, no_cors, no_request_size_limit } = config;
61
62    let mut router = Router::new()
63        .route("/", root_method_router)
64        .with_state(state)
65        .layer(TraceLayer::new_for_http());
66    if !no_cors {
67        // See [`tower_http::cors`](https://docs.rs/tower-http/latest/tower_http/cors/index.html)
68        // for more details.
69        router = router.layer(
70            CorsLayer::new()
71                .allow_origin(allow_origin.0)
72                .allow_headers([header::CONTENT_TYPE])
73                .allow_methods([Method::GET, Method::POST]),
74        );
75    }
76    if no_request_size_limit {
77        router = router.layer(DefaultBodyLimit::disable());
78    }
79    router
80}
81
82/// Helper trait that is used to execute ethereum rpc calls
83#[async_trait::async_trait]
84pub trait RpcHandler: Clone + Send + Sync + 'static {
85    /// The request type to expect
86    type Request: DeserializeOwned + Send + Sync + fmt::Debug;
87
88    /// Invoked when the request was received
89    async fn on_request(&self, request: Self::Request) -> ResponseResult;
90
91    /// Invoked for every incoming `RpcMethodCall`
92    ///
93    /// This will attempt to deserialize a `{ "method" : "<name>", "params": "<params>" }` message
94    /// into the `Request` type of this handler. If a `Request` instance was deserialized
95    /// successfully, [`Self::on_request`] will be invoked.
96    ///
97    /// **Note**: override this function if the expected `Request` deviates from `{ "method" :
98    /// "<name>", "params": "<params>" }`
99    async fn on_call(&self, call: RpcMethodCall) -> RpcResponse {
100        trace!(target: "rpc",  id = ?call.id , method = ?call.method, params = ?call.params, "received method call");
101        let RpcMethodCall { method, params, id, .. } = call;
102
103        let params: serde_json::Value = params.into();
104        let call = serde_json::json!({
105            "method": &method,
106            "params": params
107        });
108
109        match serde_json::from_value::<Self::Request>(call) {
110            Ok(req) => {
111                let result = self.on_request(req).await;
112                RpcResponse::new(id, result)
113            }
114            Err(err) => {
115                let err = err.to_string();
116                if err.contains("unknown variant") {
117                    error!(target: "rpc", ?method, "failed to deserialize method due to unknown variant");
118                    RpcResponse::new(id, RpcError::method_not_found())
119                } else {
120                    error!(target: "rpc", ?method, ?err, "failed to deserialize method");
121                    RpcResponse::new(id, RpcError::invalid_params(err))
122                }
123            }
124        }
125    }
126}