1use crate::response::ApiResponse;
2use axum::{
3 body::Bytes,
4 extract::{FromRequest, OptionalFromRequest, Request},
5 response::IntoResponse,
6};
7use serde::de::DeserializeOwned;
8use std::{str::FromStr, sync::LazyLock};
9
10pub struct PayloadRejection(anyhow::Error);
11
12impl IntoResponse for PayloadRejection {
13 fn into_response(self) -> axum::response::Response {
14 ApiResponse::error(format!("invalid payload: {}", self.0))
15 .with_status(axum::http::StatusCode::BAD_REQUEST)
16 .into_response()
17 }
18}
19
20impl From<anyhow::Error> for PayloadRejection {
21 fn from(err: anyhow::Error) -> Self {
22 Self(err)
23 }
24}
25
26static AVAILABLE_DESERIALIZERS: LazyLock<[mime::Mime; 4]> = LazyLock::new(|| {
27 [
28 mime::APPLICATION_JSON,
29 mime::APPLICATION_MSGPACK,
30 mime::TEXT_XML,
31 mime::Mime::from_str("application/yaml").unwrap(),
32 ]
33});
34
/// Extractor that deserializes the request body into `T`, choosing the format
/// (JSON, MessagePack, XML or YAML) from the request's `Content-Type` header.
///
/// The `DeserializeOwned` requirement lives on the `impl` blocks that need it
/// rather than on the struct, so `Payload<T>` can be named in generic code
/// without repeating the bound.
pub struct Payload<T>(pub T);
36
37impl<T: DeserializeOwned> Payload<T> {
38 pub fn into_inner(self) -> T {
39 self.0
40 }
41
42 pub fn from_bytes(content_type: mime::Mime, bytes: &Bytes) -> Result<Self, PayloadRejection> {
43 match content_type.essence_str() {
44 m if m == mime::APPLICATION_JSON.essence_str() => {
45 let value = serde_json::from_slice(bytes).map_err(anyhow::Error::from)?;
46 Ok(Payload(value))
47 }
48 m if m == mime::APPLICATION_MSGPACK.essence_str() => {
49 let mut de = rmp_serde::Deserializer::new(bytes.as_ref()).with_human_readable();
50 let value = T::deserialize(&mut de).map_err(anyhow::Error::from)?;
51 Ok(Payload(value))
52 }
53 m if m == mime::TEXT_XML.essence_str() => {
54 let value =
55 serde_xml_rs::from_reader(bytes.as_ref()).map_err(anyhow::Error::from)?;
56 Ok(Payload(value))
57 }
58 "application/yaml" => {
59 let value = serde_norway::from_slice(bytes).map_err(anyhow::Error::from)?;
60 Ok(Payload(value))
61 }
62 _ => Err(PayloadRejection(anyhow::anyhow!(
63 "unsupported content type"
64 ))),
65 }
66 }
67}
68
69impl<T: DeserializeOwned, S: Send + Sync> FromRequest<S> for Payload<T> {
70 type Rejection = PayloadRejection;
71
72 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
73 let content_type = req
74 .headers()
75 .get(axum::http::header::CONTENT_TYPE)
76 .and_then(|v| v.to_str().ok())
77 .and_then(|s| s.parse::<mime::Mime>().ok());
78
79 let Some(content_type) = content_type else {
80 return Err(PayloadRejection(anyhow::anyhow!("missing content type")));
81 };
82
83 if !AVAILABLE_DESERIALIZERS.contains(&content_type) {
84 return Err(PayloadRejection(anyhow::anyhow!(
85 "unsupported content type"
86 )));
87 }
88
89 let bytes = match Bytes::from_request(req, state).await {
90 Ok(b) => b,
91 Err(_) => return Err(PayloadRejection(anyhow::anyhow!("failed to read body"))),
92 };
93 Self::from_bytes(content_type, &bytes)
94 }
95}
96
97impl<T, S> OptionalFromRequest<S> for Payload<T>
98where
99 T: DeserializeOwned,
100 S: Send + Sync,
101{
102 type Rejection = PayloadRejection;
103
104 async fn from_request(req: Request, state: &S) -> Result<Option<Self>, Self::Rejection> {
105 let content_type = req
106 .headers()
107 .get(axum::http::header::CONTENT_TYPE)
108 .and_then(|v| v.to_str().ok())
109 .and_then(|s| s.parse::<mime::Mime>().ok());
110 let Some(content_type) = content_type else {
111 return Ok(None);
112 };
113
114 if !AVAILABLE_DESERIALIZERS.contains(&content_type) {
115 return Err(PayloadRejection(anyhow::anyhow!(
116 "unsupported content type"
117 )));
118 }
119
120 let bytes = match Bytes::from_request(req, state).await {
121 Ok(b) => b,
122 Err(_) => return Err(PayloadRejection(anyhow::anyhow!("failed to read body"))),
123 };
124 Self::from_bytes(content_type, &bytes).map(Some)
125 }
126}