shared/
payload.rs

1use crate::response::ApiResponse;
2use axum::{
3    body::Bytes,
4    extract::{FromRequest, OptionalFromRequest, Request},
5    response::IntoResponse,
6};
7use serde::de::DeserializeOwned;
8use std::{str::FromStr, sync::LazyLock};
9
10pub struct PayloadRejection(anyhow::Error);
11
12impl IntoResponse for PayloadRejection {
13    fn into_response(self) -> axum::response::Response {
14        ApiResponse::error(format!("invalid payload: {}", self.0))
15            .with_status(axum::http::StatusCode::BAD_REQUEST)
16            .into_response()
17    }
18}
19
20impl From<anyhow::Error> for PayloadRejection {
21    fn from(err: anyhow::Error) -> Self {
22        Self(err)
23    }
24}
25
26static AVAILABLE_DESERIALIZERS: LazyLock<[mime::Mime; 4]> = LazyLock::new(|| {
27    [
28        mime::APPLICATION_JSON,
29        mime::APPLICATION_MSGPACK,
30        mime::TEXT_XML,
31        mime::Mime::from_str("application/yaml").unwrap(),
32    ]
33});
34
/// Body extractor that deserializes the request body according to its
/// `Content-Type` (JSON, MessagePack, XML or YAML).
///
/// The `DeserializeOwned` bound is placed on the `impl` blocks that need it
/// rather than on the type itself, so `Payload<T>` can be named for any `T`.
#[derive(Debug)]
pub struct Payload<T>(pub T);
36
37impl<T: DeserializeOwned> Payload<T> {
38    pub fn into_inner(self) -> T {
39        self.0
40    }
41
42    pub fn from_bytes(content_type: mime::Mime, bytes: &Bytes) -> Result<Self, PayloadRejection> {
43        match content_type.essence_str() {
44            m if m == mime::APPLICATION_JSON.essence_str() => {
45                let value = serde_json::from_slice(bytes).map_err(anyhow::Error::from)?;
46                Ok(Payload(value))
47            }
48            m if m == mime::APPLICATION_MSGPACK.essence_str() => {
49                let mut de = rmp_serde::Deserializer::new(bytes.as_ref()).with_human_readable();
50                let value = T::deserialize(&mut de).map_err(anyhow::Error::from)?;
51                Ok(Payload(value))
52            }
53            m if m == mime::TEXT_XML.essence_str() => {
54                let value =
55                    serde_xml_rs::from_reader(bytes.as_ref()).map_err(anyhow::Error::from)?;
56                Ok(Payload(value))
57            }
58            "application/yaml" => {
59                let value = serde_norway::from_slice(bytes).map_err(anyhow::Error::from)?;
60                Ok(Payload(value))
61            }
62            _ => Err(PayloadRejection(anyhow::anyhow!(
63                "unsupported content type"
64            ))),
65        }
66    }
67}
68
69impl<T: DeserializeOwned, S: Send + Sync> FromRequest<S> for Payload<T> {
70    type Rejection = PayloadRejection;
71
72    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
73        let content_type = req
74            .headers()
75            .get(axum::http::header::CONTENT_TYPE)
76            .and_then(|v| v.to_str().ok())
77            .and_then(|s| s.parse::<mime::Mime>().ok());
78
79        let Some(content_type) = content_type else {
80            return Err(PayloadRejection(anyhow::anyhow!("missing content type")));
81        };
82
83        if !AVAILABLE_DESERIALIZERS.contains(&content_type) {
84            return Err(PayloadRejection(anyhow::anyhow!(
85                "unsupported content type"
86            )));
87        }
88
89        let bytes = match Bytes::from_request(req, state).await {
90            Ok(b) => b,
91            Err(_) => return Err(PayloadRejection(anyhow::anyhow!("failed to read body"))),
92        };
93        Self::from_bytes(content_type, &bytes)
94    }
95}
96
97impl<T, S> OptionalFromRequest<S> for Payload<T>
98where
99    T: DeserializeOwned,
100    S: Send + Sync,
101{
102    type Rejection = PayloadRejection;
103
104    async fn from_request(req: Request, state: &S) -> Result<Option<Self>, Self::Rejection> {
105        let content_type = req
106            .headers()
107            .get(axum::http::header::CONTENT_TYPE)
108            .and_then(|v| v.to_str().ok())
109            .and_then(|s| s.parse::<mime::Mime>().ok());
110        let Some(content_type) = content_type else {
111            return Ok(None);
112        };
113
114        if !AVAILABLE_DESERIALIZERS.contains(&content_type) {
115            return Err(PayloadRejection(anyhow::anyhow!(
116                "unsupported content type"
117            )));
118        }
119
120        let bytes = match Bytes::from_request(req, state).await {
121            Ok(b) => b,
122            Err(_) => return Err(PayloadRejection(anyhow::anyhow!("failed to read body"))),
123        };
124        Self::from_bytes(content_type, &bytes).map(Some)
125    }
126}