anvil_server/
config.rs

1use crate::HeaderValue;
2use serde::{Deserialize, Deserializer, Serialize, Serializer};
3use std::str::FromStr;
4
/// Additional server options.
// NOTE: with the `clap` feature enabled, clap's derive turns the `///` doc comments on this
// struct and its fields into `--help` text, so implementation notes here use plain `//`
// comments to keep the CLI output unchanged.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "clap", derive(clap::Parser), command(next_help_heading = "Server options"))]
pub struct ServerConfig {
    /// The cors `allow_origin` header
    // On the CLI this defaults to `*`, matching `Default for ServerConfig`; values are
    // parsed from a string through `HeaderValueWrapper`'s `FromStr` impl.
    #[cfg_attr(feature = "clap", arg(long, default_value = "*"))]
    pub allow_origin: HeaderValueWrapper,

    /// Disable CORS.
    // Stored inverted (`true` = CORS off); `ServerConfig::set_cors` flips it. Declared as
    // conflicting with `allow_origin` because an origin is meaningless without CORS.
    #[cfg_attr(feature = "clap", arg(long, conflicts_with = "allow_origin"))]
    pub no_cors: bool,

    /// Disable the default request body size limit. At time of writing the default limit is 2MB.
    // NOTE(review): the 2MB figure is not tied to any constant in this file; it is presumably
    // enforced by the server setup elsewhere — keep this help text in sync with it.
    #[cfg_attr(feature = "clap", arg(long))]
    pub no_request_size_limit: bool,
}
21
22impl ServerConfig {
23    /// Sets the "allow origin" header for CORS.
24    pub fn with_allow_origin(mut self, allow_origin: impl Into<HeaderValueWrapper>) -> Self {
25        self.allow_origin = allow_origin.into();
26        self
27    }
28
29    /// Whether to enable CORS.
30    pub fn set_cors(mut self, cors: bool) -> Self {
31        self.no_cors = !cors;
32        self
33    }
34}
35
36impl Default for ServerConfig {
37    fn default() -> Self {
38        Self {
39            allow_origin: "*".parse::<HeaderValue>().unwrap().into(),
40            no_cors: false,
41            no_request_size_limit: false,
42        }
43    }
44}
45
46#[derive(Clone, Debug)]
47pub struct HeaderValueWrapper(pub HeaderValue);
48
49impl FromStr for HeaderValueWrapper {
50    type Err = <HeaderValue as FromStr>::Err;
51
52    fn from_str(s: &str) -> Result<Self, Self::Err> {
53        Ok(Self(s.parse()?))
54    }
55}
56
57impl Serialize for HeaderValueWrapper {
58    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
59    where
60        S: Serializer,
61    {
62        serializer.serialize_str(self.0.to_str().map_err(serde::ser::Error::custom)?)
63    }
64}
65
66impl<'de> Deserialize<'de> for HeaderValueWrapper {
67    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
68    where
69        D: Deserializer<'de>,
70    {
71        let s = String::deserialize(deserializer)?;
72        Ok(Self(s.parse().map_err(serde::de::Error::custom)?))
73    }
74}
75
76impl std::ops::Deref for HeaderValueWrapper {
77    type Target = HeaderValue;
78
79    fn deref(&self) -> &Self::Target {
80        &self.0
81    }
82}
83
84impl From<HeaderValueWrapper> for HeaderValue {
85    fn from(wrapper: HeaderValueWrapper) -> Self {
86        wrapper.0
87    }
88}
89
90impl From<HeaderValue> for HeaderValueWrapper {
91    fn from(header: HeaderValue) -> Self {
92        Self(header)
93    }
94}