Skip to main content

shared/
payload.rs

1use crate::response::ApiResponse;
2use axum::{
3    body::Bytes,
4    extract::{FromRequest, OptionalFromRequest, Request},
5    response::IntoResponse,
6};
7use serde::de::DeserializeOwned;
8use std::{str::FromStr, sync::LazyLock};
9
10pub struct PayloadRejection(anyhow::Error);
11
12impl IntoResponse for PayloadRejection {
13    fn into_response(self) -> axum::response::Response {
14        ApiResponse::error(format!("invalid payload: {}", self.0))
15            .with_status(axum::http::StatusCode::BAD_REQUEST)
16            .into_response()
17    }
18}
19
20impl From<anyhow::Error> for PayloadRejection {
21    fn from(err: anyhow::Error) -> Self {
22        Self(err)
23    }
24}
25
26static AVAILABLE_DESERIALIZERS: LazyLock<[mime::Mime; 4]> = LazyLock::new(|| {
27    [
28        mime::APPLICATION_JSON,
29        mime::APPLICATION_MSGPACK,
30        mime::TEXT_XML,
31        mime::Mime::from_str("application/yaml").unwrap(),
32    ]
33});
34
/// An axum extractor that deserializes the request body into `T`, choosing the wire format
/// from the request's `Content-Type` header (not `Accept`).
///
/// Supported formats are JSON (`application/json`), MessagePack (`application/msgpack`),
/// XML (`text/xml`) and YAML (`application/yaml`). An empty body is treated as an empty
/// document of the selected format.
///
/// Use `Payload<T>` for a required body. Use `Option<Payload<T>>` to treat a request without a
/// `Content-Type` header as carrying no payload.
// Trait bounds live on the impl blocks that need them, not on the type definition.
pub struct Payload<T>(pub T);
37
38impl<T: DeserializeOwned> Payload<T> {
39    pub fn into_inner(self) -> T {
40        self.0
41    }
42
43    pub fn from_bytes(
44        content_type: mime::Mime,
45        mut bytes: Bytes,
46    ) -> Result<Self, PayloadRejection> {
47        match content_type.essence_str() {
48            m if m == mime::APPLICATION_JSON.essence_str() => {
49                if bytes.is_empty() {
50                    bytes = Bytes::from_static(b"{}");
51                }
52
53                let value = serde_json::from_slice(&bytes).map_err(anyhow::Error::from)?;
54                Ok(Payload(value))
55            }
56            m if m == mime::APPLICATION_MSGPACK.essence_str() => {
57                if bytes.is_empty() {
58                    bytes = Bytes::from_static(&[0x80]);
59                }
60
61                let mut de = rmp_serde::Deserializer::new(bytes.as_ref()).with_human_readable();
62                let value = T::deserialize(&mut de).map_err(anyhow::Error::from)?;
63                Ok(Payload(value))
64            }
65            m if m == mime::TEXT_XML.essence_str() => {
66                if bytes.is_empty() {
67                    bytes = Bytes::from_static(b"<root></root>");
68                }
69
70                let value =
71                    serde_xml_rs::from_reader(bytes.as_ref()).map_err(anyhow::Error::from)?;
72                Ok(Payload(value))
73            }
74            "application/yaml" => {
75                if bytes.is_empty() {
76                    bytes = Bytes::from_static(b"{}");
77                }
78
79                let value = serde_norway::from_slice(&bytes).map_err(anyhow::Error::from)?;
80                Ok(Payload(value))
81            }
82            _ => Err(PayloadRejection(anyhow::anyhow!(
83                "unsupported content type"
84            ))),
85        }
86    }
87}
88
89impl<T: DeserializeOwned, S: Send + Sync> FromRequest<S> for Payload<T> {
90    type Rejection = PayloadRejection;
91
92    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
93        let content_type = req
94            .headers()
95            .get(axum::http::header::CONTENT_TYPE)
96            .and_then(|v| v.to_str().ok())
97            .and_then(|s| s.parse::<mime::Mime>().ok());
98
99        let Some(content_type) = content_type else {
100            return Err(PayloadRejection(anyhow::anyhow!(
101                "missing content type header"
102            )));
103        };
104
105        if !AVAILABLE_DESERIALIZERS.contains(&content_type) {
106            return Err(PayloadRejection(anyhow::anyhow!(
107                "unsupported content type"
108            )));
109        }
110
111        let bytes = match Bytes::from_request(req, state).await {
112            Ok(b) => b,
113            Err(_) => return Err(PayloadRejection(anyhow::anyhow!("failed to read body"))),
114        };
115        Self::from_bytes(content_type, bytes)
116    }
117}
118
119impl<T, S> OptionalFromRequest<S> for Payload<T>
120where
121    T: DeserializeOwned,
122    S: Send + Sync,
123{
124    type Rejection = PayloadRejection;
125
126    async fn from_request(req: Request, state: &S) -> Result<Option<Self>, Self::Rejection> {
127        let content_type = req
128            .headers()
129            .get(axum::http::header::CONTENT_TYPE)
130            .and_then(|v| v.to_str().ok())
131            .and_then(|s| s.parse::<mime::Mime>().ok());
132        let Some(content_type) = content_type else {
133            return Ok(None);
134        };
135
136        if !AVAILABLE_DESERIALIZERS.contains(&content_type) {
137            return Err(PayloadRejection(anyhow::anyhow!(
138                "unsupported content type"
139            )));
140        }
141
142        let bytes = match Bytes::from_request(req, state).await {
143            Ok(b) => b,
144            Err(_) => return Err(PayloadRejection(anyhow::anyhow!("failed to read body"))),
145        };
146        Self::from_bytes(content_type, bytes).map(Some)
147    }
148}