1use crate::{ApiError, database::DatabaseError};
2use accept_header::Accept;
3use axum::response::IntoResponse;
4use std::{
5 borrow::Cow,
6 fmt::{Debug, Display},
7 str::FromStr,
8};
9
/// Handler result type in which both the success and the error variant are an
/// [`ApiResponse`], so either side can be turned into an HTTP response.
pub type ApiResponseResult = Result<ApiResponse, ApiResponse>;
11
// Task-local request context consulted while building responses.
// NOTE(review): the code that scopes these values is not in this file —
// presumably a per-request middleware; confirm where they are set.
tokio::task_local! {
    /// Parsed `Accept` header of the current task, if any. Read by
    /// `ApiResponse::new_serialized` to choose the response encoding.
    pub static ACCEPT_HEADER: Option<Accept>;
    /// Debug flag. Read by the `From<T>` error conversion to decide whether
    /// the text of an unexpected error is included in the response body.
    pub static APP_DEBUG: bool;
}
16
17pub fn accept_from_headers(headers: &axum::http::HeaderMap) -> Option<Accept> {
18 let header_value = headers.get(axum::http::header::ACCEPT)?;
19 let header_str = header_value.to_str().ok()?;
20
21 Accept::from_str(header_str).ok()
22}
23
/// An HTTP response assembled from its three parts; converted into an axum
/// response through its `IntoResponse` implementation.
#[derive(Debug)]
pub struct ApiResponse {
    /// Response body.
    pub body: axum::body::Body,
    /// HTTP status code; the constructors in this file start at `200 OK`.
    pub status: axum::http::StatusCode,
    /// Response headers.
    pub headers: axum::http::HeaderMap,
}
30
31impl ApiResponse {
32 #[inline]
33 pub fn new(body: axum::body::Body) -> Self {
34 Self {
35 body,
36 status: axum::http::StatusCode::OK,
37 headers: axum::http::HeaderMap::new(),
38 }
39 }
40
41 #[inline]
42 pub fn new_stream(stream: impl tokio::io::AsyncRead + Send + 'static) -> Self {
43 Self {
44 body: axum::body::Body::from_stream(tokio_util::io::ReaderStream::with_capacity(
45 stream,
46 crate::BUFFER_SIZE,
47 )),
48 status: axum::http::StatusCode::OK,
49 headers: axum::http::HeaderMap::new(),
50 }
51 }
52
53 pub fn new_serialized(body: impl serde::Serialize) -> Self {
55 let accept_header = ACCEPT_HEADER.try_with(|h| h.clone()).ok().flatten();
56
57 static AVAILABLE_SERIALIZERS: &[mime::Mime] = &[
58 mime::APPLICATION_JSON,
59 mime::APPLICATION_MSGPACK,
60 mime::TEXT_XML,
61 ];
62
63 let negotiated = accept_header
64 .as_ref()
65 .and_then(|accept| accept.negotiate(AVAILABLE_SERIALIZERS).ok())
66 .unwrap_or(mime::APPLICATION_JSON);
67
68 let (content_type, body) = match negotiated {
69 m if m.essence_str() == mime::APPLICATION_MSGPACK.essence_str() => {
70 let mut bytes = Vec::new();
71 let mut se = rmp_serde::Serializer::new(&mut bytes)
72 .with_struct_map()
73 .with_human_readable();
74 if let Err(err) = body.serialize(&mut se) {
75 tracing::error!(
76 "failed to serialize response body to MessagePack: {:?}",
77 err
78 );
79
80 (
81 axum::http::HeaderValue::from_static("application/json"),
82 axum::body::Body::from("{}"),
83 )
84 } else {
85 (
86 axum::http::HeaderValue::from_static("application/msgpack"),
87 axum::body::Body::from(bytes),
88 )
89 }
90 }
91 m if m.essence_str() == mime::TEXT_XML.essence_str() => {
92 let string = serde_xml_rs::to_string(&body).unwrap_or_else(|err| {
93 tracing::error!("failed to serialize response body to XML: {:?}", err);
94 "<error>serialization failed</error>".to_string()
95 });
96
97 (
98 axum::http::HeaderValue::from_static("text/xml"),
99 axum::body::Body::from(string),
100 )
101 }
102 _ => {
103 let bytes = serde_json::to_vec(&body).unwrap_or_else(|err| {
104 tracing::error!("failed to serialize response body to JSON: {:?}", err);
105 b"{}".to_vec()
106 });
107
108 (
109 axum::http::HeaderValue::from_static("application/json"),
110 axum::body::Body::from(bytes),
111 )
112 }
113 };
114
115 Self {
116 body,
117 status: axum::http::StatusCode::OK,
118 headers: axum::http::HeaderMap::from_iter([
119 (axum::http::header::CONTENT_TYPE, content_type),
120 (
121 axum::http::header::VARY,
122 axum::http::HeaderValue::from_static("Accept"),
123 ),
124 ]),
125 }
126 }
127
128 #[inline]
129 pub fn error(err: impl AsRef<str>) -> Self {
130 Self::new_serialized(ApiError::new_value(&[err.as_ref()]))
131 .with_status(axum::http::StatusCode::BAD_REQUEST)
132 }
133
134 #[inline]
135 pub fn with_status(mut self, status: axum::http::StatusCode) -> Self {
136 self.status = status;
137 self
138 }
139
140 #[inline]
141 pub fn with_header(mut self, key: &'static str, value: impl AsRef<str>) -> Self {
142 if let Ok(header_value) = axum::http::HeaderValue::from_str(value.as_ref()) {
143 self.headers.insert(key, header_value);
144 }
145
146 self
147 }
148
149 #[inline]
150 pub fn with_optional_header(
151 mut self,
152 key: &'static str,
153 value: Option<impl AsRef<str>>,
154 ) -> Self {
155 let value = match value {
156 Some(value) => value,
157 None => return self,
158 };
159
160 if let Ok(header_value) = axum::http::HeaderValue::from_str(value.as_ref()) {
161 self.headers.insert(key, header_value);
162 }
163
164 self
165 }
166
167 #[inline]
168 pub fn with_headers(mut self, headers: &axum::http::HeaderMap) -> Self {
169 for (key, value) in headers.iter() {
170 self.headers.insert(key, value.clone());
171 }
172
173 self
174 }
175
176 #[inline]
177 pub fn ok(self) -> ApiResponseResult {
178 Ok(self)
179 }
180}
181
182impl<T> From<T> for ApiResponse
183where
184 T: Into<anyhow::Error>,
185{
186 fn from(err: T) -> Self {
187 let err: anyhow::Error = err.into();
188
189 if let Some(error) = err.downcast_ref::<DisplayError>() {
190 return ApiResponse::error(&error.message).with_status(error.status);
191 } else if let Some(DatabaseError::Validation(error)) = err.downcast_ref::<DatabaseError>() {
192 let error_messages = crate::utils::flatten_validation_errors(error);
193
194 return ApiResponse::new_serialized(ApiError::new_strings_value(error_messages))
195 .with_status(axum::http::StatusCode::BAD_REQUEST);
196 } else if let Some(DatabaseError::InvalidRelation(error)) =
197 err.downcast_ref::<DatabaseError>()
198 {
199 return ApiResponse::error(error.to_string())
200 .with_status(axum::http::StatusCode::BAD_REQUEST);
201 }
202
203 tracing::error!("a request error occurred: {:?}", err);
204 sentry_anyhow::capture_anyhow(&err);
205
206 let debug = APP_DEBUG.try_get().unwrap_or_default();
207
208 ApiResponse::error(if debug {
209 Cow::Owned(err.to_string())
210 } else {
211 "internal server error".into()
212 })
213 .with_status(axum::http::StatusCode::INTERNAL_SERVER_ERROR)
214 }
215}
216
217impl IntoResponse for ApiResponse {
218 #[inline]
219 fn into_response(self) -> axum::response::Response {
220 let mut response = axum::response::Response::new(self.body);
221 *response.status_mut() = self.status;
222 *response.headers_mut() = self.headers;
223
224 response
225 }
226}
227
/// An error that carries its own message and HTTP status. The `From<T>`
/// conversion into `ApiResponse` reports it with exactly these two values.
#[derive(Debug)]
pub struct DisplayError<'a> {
    /// Status code used for the resulting response.
    status: axum::http::StatusCode,
    /// Message placed in the response's error payload.
    message: Cow<'a, str>,
}
233
234impl<'a> DisplayError<'a> {
235 pub fn new(message: impl Into<Cow<'a, str>>) -> Self {
236 Self {
237 status: axum::http::StatusCode::BAD_REQUEST,
238 message: message.into(),
239 }
240 }
241
242 pub fn with_status(mut self, status: axum::http::StatusCode) -> Self {
243 self.status = status;
244
245 self
246 }
247}
248
249impl<'a> Display for DisplayError<'a> {
250 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251 f.debug_struct("DisplayError")
252 .field("status", &self.status)
253 .field("message", &self.message)
254 .finish()
255 }
256}
257
// Marker impl with no overridden methods; it makes `DisplayError` a standard
// error type.
impl<'a> std::error::Error for DisplayError<'a> {}