// anvil_server/lib.rs
1#![cfg_attr(not(test), warn(unused_crate_dependencies))]
4#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
5
6#[macro_use]
7extern crate tracing;
8
9use anvil_rpc::{
10 error::RpcError,
11 request::RpcMethodCall,
12 response::{ResponseResult, RpcResponse},
13};
14use axum::{
15 extract::DefaultBodyLimit,
16 http::{header, HeaderValue, Method},
17 routing::{post, MethodRouter},
18 Router,
19};
20use serde::de::DeserializeOwned;
21use std::fmt;
22use tower_http::{cors::CorsLayer, trace::TraceLayer};
23
24mod config;
25pub use config::ServerConfig;
26
27mod error;
28mod handler;
29
30mod pubsub;
31pub use pubsub::{PubSubContext, PubSubRpcHandler};
32
33mod ws;
34
35#[cfg(feature = "ipc")]
36pub mod ipc;
37
38pub fn http_ws_router<Http, Ws>(config: ServerConfig, http: Http, ws: Ws) -> Router
40where
41 Http: RpcHandler,
42 Ws: PubSubRpcHandler,
43{
44 router_inner(config, post(handler::handle).get(ws::handle_ws), (http, ws))
45}
46
47pub fn http_router<Http>(config: ServerConfig, http: Http) -> Router
49where
50 Http: RpcHandler,
51{
52 router_inner(config, post(handler::handle), (http, ()))
53}
54
55fn router_inner<S: Clone + Send + Sync + 'static>(
56 config: ServerConfig,
57 root_method_router: MethodRouter<S>,
58 state: S,
59) -> Router {
60 let ServerConfig { allow_origin, no_cors, no_request_size_limit } = config;
61
62 let mut router = Router::new()
63 .route("/", root_method_router)
64 .with_state(state)
65 .layer(TraceLayer::new_for_http());
66 if !no_cors {
67 router = router.layer(
70 CorsLayer::new()
71 .allow_origin(allow_origin.0)
72 .allow_headers([header::CONTENT_TYPE])
73 .allow_methods([Method::GET, Method::POST]),
74 );
75 }
76 if no_request_size_limit {
77 router = router.layer(DefaultBodyLimit::disable());
78 }
79 router
80}
81
82#[async_trait::async_trait]
84pub trait RpcHandler: Clone + Send + Sync + 'static {
85 type Request: DeserializeOwned + Send + Sync + fmt::Debug;
87
88 async fn on_request(&self, request: Self::Request) -> ResponseResult;
90
91 async fn on_call(&self, call: RpcMethodCall) -> RpcResponse {
100 trace!(target: "rpc", id = ?call.id , method = ?call.method, params = ?call.params, "received method call");
101 let RpcMethodCall { method, params, id, .. } = call;
102
103 let params: serde_json::Value = params.into();
104 let call = serde_json::json!({
105 "method": &method,
106 "params": params
107 });
108
109 match serde_json::from_value::<Self::Request>(call) {
110 Ok(req) => {
111 let result = self.on_request(req).await;
112 RpcResponse::new(id, result)
113 }
114 Err(err) => {
115 let err = err.to_string();
116 if err.contains("unknown variant") {
117 error!(target: "rpc", ?method, "failed to deserialize method due to unknown variant");
118 RpcResponse::new(id, RpcError::method_not_found())
119 } else {
120 error!(target: "rpc", ?method, ?err, "failed to deserialize method");
121 RpcResponse::new(id, RpcError::invalid_params(err))
122 }
123 }
124 }
125 }
126}