1use crate::{ApiError, database::DatabaseError};
2use accept_header::Accept;
3use axum::response::IntoResponse;
4use std::{
5 borrow::Cow,
6 fmt::{Debug, Display},
7 str::FromStr,
8};
9
10pub type ApiResponseResult = Result<ApiResponse, ApiResponse>;
11
12tokio::task_local! {
13 pub static ACCEPT_HEADER: Option<Accept>;
14 pub static APP_DEBUG: bool;
15}
16
17pub fn accept_from_headers(headers: &axum::http::HeaderMap) -> Option<Accept> {
18 let header_value = headers.get(axum::http::header::ACCEPT)?;
19 let header_str = header_value.to_str().ok()?;
20
21 Accept::from_str(header_str).ok()
22}
23
24#[derive(Debug)]
25pub struct ApiResponse {
26 pub body: axum::body::Body,
27 pub status: axum::http::StatusCode,
28 pub headers: axum::http::HeaderMap,
29}
30
31impl ApiResponse {
32 #[inline]
33 pub fn new(body: axum::body::Body) -> Self {
34 Self {
35 body,
36 status: axum::http::StatusCode::OK,
37 headers: axum::http::HeaderMap::new(),
38 }
39 }
40
41 #[inline]
42 pub fn new_stream(stream: impl tokio::io::AsyncRead + Send + 'static) -> Self {
43 Self {
44 body: axum::body::Body::from_stream(tokio_util::io::ReaderStream::with_capacity(
45 stream,
46 crate::BUFFER_SIZE,
47 )),
48 status: axum::http::StatusCode::OK,
49 headers: axum::http::HeaderMap::new(),
50 }
51 }
52
53 pub fn new_serialized(body: impl serde::Serialize) -> Self {
55 let accept_header = ACCEPT_HEADER.try_with(|h| h.clone()).ok().flatten();
56
57 static AVAILABLE_SERIALIZERS: &[mime::Mime] = &[
58 mime::APPLICATION_JSON,
59 mime::APPLICATION_MSGPACK,
60 mime::TEXT_XML,
61 ];
62
63 let negotiated = accept_header
64 .as_ref()
65 .and_then(|accept| accept.negotiate(AVAILABLE_SERIALIZERS).ok())
66 .unwrap_or(mime::APPLICATION_JSON);
67
68 let (content_type, body) = match negotiated {
69 m if m.essence_str() == mime::APPLICATION_MSGPACK.essence_str() => {
70 let mut bytes = Vec::new();
71 let mut se = rmp_serde::Serializer::new(&mut bytes)
72 .with_struct_map()
73 .with_human_readable();
74 if let Err(err) = body.serialize(&mut se) {
75 tracing::error!(
76 "failed to serialize response body to MessagePack: {:?}",
77 err
78 );
79
80 (
81 axum::http::HeaderValue::from_static("application/json"),
82 axum::body::Body::from("{}"),
83 )
84 } else {
85 (
86 axum::http::HeaderValue::from_static("application/msgpack"),
87 axum::body::Body::from(bytes),
88 )
89 }
90 }
91 m if m.essence_str() == mime::TEXT_XML.essence_str() => {
92 let string = serde_xml_rs::to_string(&body).unwrap_or_else(|err| {
93 tracing::error!("failed to serialize response body to XML: {:?}", err);
94 "<error>serialization failed</error>".to_string()
95 });
96
97 (
98 axum::http::HeaderValue::from_static("text/xml"),
99 axum::body::Body::from(string),
100 )
101 }
102 _ => {
103 let bytes = serde_json::to_vec(&body).unwrap_or_else(|err| {
104 tracing::error!("failed to serialize response body to JSON: {:?}", err);
105 b"{}".to_vec()
106 });
107
108 (
109 axum::http::HeaderValue::from_static("application/json"),
110 axum::body::Body::from(bytes),
111 )
112 }
113 };
114
115 Self {
116 body,
117 status: axum::http::StatusCode::OK,
118 headers: axum::http::HeaderMap::from_iter([
119 (axum::http::header::CONTENT_TYPE, content_type),
120 (
121 axum::http::header::VARY,
122 axum::http::HeaderValue::from_static("Accept"),
123 ),
124 ]),
125 }
126 }
127
128 #[inline]
129 pub fn error(err: impl AsRef<str>) -> Self {
130 Self::new_serialized(ApiError::new_value(&[err.as_ref()]))
131 .with_status(axum::http::StatusCode::BAD_REQUEST)
132 }
133
134 #[inline]
135 pub fn with_status(mut self, status: axum::http::StatusCode) -> Self {
136 self.status = status;
137 self
138 }
139
140 #[inline]
141 pub fn with_header(mut self, key: &'static str, value: impl AsRef<str>) -> Self {
142 if let Ok(header_value) = axum::http::HeaderValue::from_str(value.as_ref()) {
143 self.headers.insert(key, header_value);
144 }
145
146 self
147 }
148
149 #[inline]
150 pub fn with_optional_header(
151 mut self,
152 key: &'static str,
153 value: Option<impl AsRef<str>>,
154 ) -> Self {
155 let value = match value {
156 Some(value) => value,
157 None => return self,
158 };
159
160 if let Ok(header_value) = axum::http::HeaderValue::from_str(value.as_ref()) {
161 self.headers.insert(key, header_value);
162 }
163
164 self
165 }
166
167 #[inline]
168 pub fn ok(self) -> ApiResponseResult {
169 Ok(self)
170 }
171}
172
173impl<T> From<T> for ApiResponse
174where
175 T: Into<anyhow::Error>,
176{
177 fn from(err: T) -> Self {
178 let err: anyhow::Error = err.into();
179
180 if let Some(error) = err.downcast_ref::<DisplayError>() {
181 return ApiResponse::error(&error.message).with_status(error.status);
182 } else if let Some(DatabaseError::Validation(error)) = err.downcast_ref::<DatabaseError>() {
183 let error_messages = crate::utils::flatten_validation_errors(error);
184
185 return ApiResponse::new_serialized(ApiError::new_strings_value(error_messages))
186 .with_status(axum::http::StatusCode::BAD_REQUEST);
187 } else if let Some(DatabaseError::InvalidRelation(error)) =
188 err.downcast_ref::<DatabaseError>()
189 {
190 return ApiResponse::error(error.to_string())
191 .with_status(axum::http::StatusCode::BAD_REQUEST);
192 }
193
194 tracing::error!("a request error occurred: {:?}", err);
195 sentry_anyhow::capture_anyhow(&err);
196
197 let debug = APP_DEBUG.try_get().unwrap_or_default();
198
199 ApiResponse::error(if debug {
200 Cow::Owned(err.to_string())
201 } else {
202 "internal server error".into()
203 })
204 .with_status(axum::http::StatusCode::INTERNAL_SERVER_ERROR)
205 }
206}
207
208impl IntoResponse for ApiResponse {
209 #[inline]
210 fn into_response(self) -> axum::response::Response {
211 let mut response = axum::response::Response::new(self.body);
212 *response.status_mut() = self.status;
213 *response.headers_mut() = self.headers;
214
215 response
216 }
217}
218
219#[derive(Debug)]
220pub struct DisplayError<'a> {
221 status: axum::http::StatusCode,
222 message: Cow<'a, str>,
223}
224
225impl<'a> DisplayError<'a> {
226 pub fn new(message: impl Into<Cow<'a, str>>) -> Self {
227 Self {
228 status: axum::http::StatusCode::BAD_REQUEST,
229 message: message.into(),
230 }
231 }
232
233 pub fn with_status(mut self, status: axum::http::StatusCode) -> Self {
234 self.status = status;
235
236 self
237 }
238}
239
240impl<'a> Display for DisplayError<'a> {
241 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242 f.debug_struct("DisplayError")
243 .field("status", &self.status)
244 .field("message", &self.message)
245 .finish()
246 }
247}
248
249impl<'a> std::error::Error for DisplayError<'a> {}