1use crate::response::ApiResponse;
2use axum::{
3 body::Bytes,
4 extract::{FromRequest, OptionalFromRequest, Request},
5 response::IntoResponse,
6};
7use serde::de::DeserializeOwned;
8use std::{str::FromStr, sync::LazyLock};
9
10pub struct PayloadRejection(anyhow::Error);
11
12impl IntoResponse for PayloadRejection {
13 fn into_response(self) -> axum::response::Response {
14 ApiResponse::error(format!("invalid payload: {}", self.0))
15 .with_status(axum::http::StatusCode::BAD_REQUEST)
16 .into_response()
17 }
18}
19
20impl From<anyhow::Error> for PayloadRejection {
21 fn from(err: anyhow::Error) -> Self {
22 Self(err)
23 }
24}
25
26static AVAILABLE_DESERIALIZERS: LazyLock<[mime::Mime; 4]> = LazyLock::new(|| {
27 [
28 mime::APPLICATION_JSON,
29 mime::APPLICATION_MSGPACK,
30 mime::TEXT_XML,
31 mime::Mime::from_str("application/yaml").unwrap(),
32 ]
33});
34
/// Axum extractor that deserializes the request body according to its `Content-Type`.
///
/// The supported formats are JSON (`application/json`), MessagePack (`application/msgpack`),
/// XML (`text/xml`) and YAML (`application/yaml`).
///
/// As `Payload<T>`, a `Content-Type` header is required. As `Option<Payload<T>>`, a request
/// without that header yields `None`.
///
/// The `T: DeserializeOwned` requirement lives on the `impl` blocks that need it, not on the
/// type itself.
#[derive(Debug)]
pub struct Payload<T>(pub T);
37
38impl<T: DeserializeOwned> Payload<T> {
39 pub fn into_inner(self) -> T {
40 self.0
41 }
42
43 pub fn from_bytes(
44 content_type: mime::Mime,
45 mut bytes: Bytes,
46 ) -> Result<Self, PayloadRejection> {
47 match content_type.essence_str() {
48 m if m == mime::APPLICATION_JSON.essence_str() => {
49 if bytes.is_empty() {
50 bytes = Bytes::from_static(b"{}");
51 }
52
53 let value = serde_json::from_slice(&bytes).map_err(anyhow::Error::from)?;
54 Ok(Payload(value))
55 }
56 m if m == mime::APPLICATION_MSGPACK.essence_str() => {
57 if bytes.is_empty() {
58 bytes = Bytes::from_static(&[0x80]);
59 }
60
61 let mut de = rmp_serde::Deserializer::new(bytes.as_ref()).with_human_readable();
62 let value = T::deserialize(&mut de).map_err(anyhow::Error::from)?;
63 Ok(Payload(value))
64 }
65 m if m == mime::TEXT_XML.essence_str() => {
66 if bytes.is_empty() {
67 bytes = Bytes::from_static(b"<root></root>");
68 }
69
70 let value =
71 serde_xml_rs::from_reader(bytes.as_ref()).map_err(anyhow::Error::from)?;
72 Ok(Payload(value))
73 }
74 "application/yaml" => {
75 if bytes.is_empty() {
76 bytes = Bytes::from_static(b"{}");
77 }
78
79 let value = serde_norway::from_slice(&bytes).map_err(anyhow::Error::from)?;
80 Ok(Payload(value))
81 }
82 _ => Err(PayloadRejection(anyhow::anyhow!(
83 "unsupported content type"
84 ))),
85 }
86 }
87}
88
89impl<T: DeserializeOwned, S: Send + Sync> FromRequest<S> for Payload<T> {
90 type Rejection = PayloadRejection;
91
92 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
93 let content_type = req
94 .headers()
95 .get(axum::http::header::CONTENT_TYPE)
96 .and_then(|v| v.to_str().ok())
97 .and_then(|s| s.parse::<mime::Mime>().ok());
98
99 let Some(content_type) = content_type else {
100 return Err(PayloadRejection(anyhow::anyhow!(
101 "missing content type header"
102 )));
103 };
104
105 if !AVAILABLE_DESERIALIZERS.contains(&content_type) {
106 return Err(PayloadRejection(anyhow::anyhow!(
107 "unsupported content type"
108 )));
109 }
110
111 let bytes = match Bytes::from_request(req, state).await {
112 Ok(b) => b,
113 Err(_) => return Err(PayloadRejection(anyhow::anyhow!("failed to read body"))),
114 };
115 Self::from_bytes(content_type, bytes)
116 }
117}
118
119impl<T, S> OptionalFromRequest<S> for Payload<T>
120where
121 T: DeserializeOwned,
122 S: Send + Sync,
123{
124 type Rejection = PayloadRejection;
125
126 async fn from_request(req: Request, state: &S) -> Result<Option<Self>, Self::Rejection> {
127 let content_type = req
128 .headers()
129 .get(axum::http::header::CONTENT_TYPE)
130 .and_then(|v| v.to_str().ok())
131 .and_then(|s| s.parse::<mime::Mime>().ok());
132 let Some(content_type) = content_type else {
133 return Ok(None);
134 };
135
136 if !AVAILABLE_DESERIALIZERS.contains(&content_type) {
137 return Err(PayloadRejection(anyhow::anyhow!(
138 "unsupported content type"
139 )));
140 }
141
142 let bytes = match Bytes::from_request(req, state).await {
143 Ok(b) => b,
144 Err(_) => return Err(PayloadRejection(anyhow::anyhow!("failed to read body"))),
145 };
146 Self::from_bytes(content_type, bytes).map(Some)
147 }
148}