shared/response.rs

1use crate::{ApiError, database::DatabaseError};
2use accept_header::Accept;
3use axum::response::IntoResponse;
4use std::{
5    borrow::Cow,
6    fmt::{Debug, Display},
7    str::FromStr,
8};
9
10pub type ApiResponseResult = Result<ApiResponse, ApiResponse>;
11
12tokio::task_local! {
13    pub static ACCEPT_HEADER: Option<Accept>;
14    pub static APP_DEBUG: bool;
15}
16
17pub fn accept_from_headers(headers: &axum::http::HeaderMap) -> Option<Accept> {
18    let header_value = headers.get(axum::http::header::ACCEPT)?;
19    let header_str = header_value.to_str().ok()?;
20
21    Accept::from_str(header_str).ok()
22}
23
24#[derive(Debug)]
25pub struct ApiResponse {
26    pub body: axum::body::Body,
27    pub status: axum::http::StatusCode,
28    pub headers: axum::http::HeaderMap,
29}
30
31impl ApiResponse {
32    #[inline]
33    pub fn new(body: axum::body::Body) -> Self {
34        Self {
35            body,
36            status: axum::http::StatusCode::OK,
37            headers: axum::http::HeaderMap::new(),
38        }
39    }
40
41    #[inline]
42    pub fn new_stream(stream: impl tokio::io::AsyncRead + Send + 'static) -> Self {
43        Self {
44            body: axum::body::Body::from_stream(tokio_util::io::ReaderStream::with_capacity(
45                stream,
46                crate::BUFFER_SIZE,
47            )),
48            status: axum::http::StatusCode::OK,
49            headers: axum::http::HeaderMap::new(),
50        }
51    }
52
53    /// Create a new API response with content negotiation based on the `Accept` header.
54    pub fn new_serialized(body: impl serde::Serialize) -> Self {
55        let accept_header = ACCEPT_HEADER.try_with(|h| h.clone()).ok().flatten();
56
57        static AVAILABLE_SERIALIZERS: &[mime::Mime] = &[
58            mime::APPLICATION_JSON,
59            mime::APPLICATION_MSGPACK,
60            mime::TEXT_XML,
61        ];
62
63        let negotiated = accept_header
64            .as_ref()
65            .and_then(|accept| accept.negotiate(AVAILABLE_SERIALIZERS).ok())
66            .unwrap_or(mime::APPLICATION_JSON);
67
68        let (content_type, body) = match negotiated {
69            m if m.essence_str() == mime::APPLICATION_MSGPACK.essence_str() => {
70                let mut bytes = Vec::new();
71                let mut se = rmp_serde::Serializer::new(&mut bytes)
72                    .with_struct_map()
73                    .with_human_readable();
74                if let Err(err) = body.serialize(&mut se) {
75                    tracing::error!(
76                        "failed to serialize response body to MessagePack: {:?}",
77                        err
78                    );
79
80                    (
81                        axum::http::HeaderValue::from_static("application/json"),
82                        axum::body::Body::from("{}"),
83                    )
84                } else {
85                    (
86                        axum::http::HeaderValue::from_static("application/msgpack"),
87                        axum::body::Body::from(bytes),
88                    )
89                }
90            }
91            m if m.essence_str() == mime::TEXT_XML.essence_str() => {
92                let string = serde_xml_rs::to_string(&body).unwrap_or_else(|err| {
93                    tracing::error!("failed to serialize response body to XML: {:?}", err);
94                    "<error>serialization failed</error>".to_string()
95                });
96
97                (
98                    axum::http::HeaderValue::from_static("text/xml"),
99                    axum::body::Body::from(string),
100                )
101            }
102            _ => {
103                let bytes = serde_json::to_vec(&body).unwrap_or_else(|err| {
104                    tracing::error!("failed to serialize response body to JSON: {:?}", err);
105                    b"{}".to_vec()
106                });
107
108                (
109                    axum::http::HeaderValue::from_static("application/json"),
110                    axum::body::Body::from(bytes),
111                )
112            }
113        };
114
115        Self {
116            body,
117            status: axum::http::StatusCode::OK,
118            headers: axum::http::HeaderMap::from_iter([
119                (axum::http::header::CONTENT_TYPE, content_type),
120                (
121                    axum::http::header::VARY,
122                    axum::http::HeaderValue::from_static("Accept"),
123                ),
124            ]),
125        }
126    }
127
128    #[inline]
129    pub fn error(err: impl AsRef<str>) -> Self {
130        Self::new_serialized(ApiError::new_value(&[err.as_ref()]))
131            .with_status(axum::http::StatusCode::BAD_REQUEST)
132    }
133
134    #[inline]
135    pub fn with_status(mut self, status: axum::http::StatusCode) -> Self {
136        self.status = status;
137        self
138    }
139
140    #[inline]
141    pub fn with_header(mut self, key: &'static str, value: impl AsRef<str>) -> Self {
142        if let Ok(header_value) = axum::http::HeaderValue::from_str(value.as_ref()) {
143            self.headers.insert(key, header_value);
144        }
145
146        self
147    }
148
149    #[inline]
150    pub fn with_optional_header(
151        mut self,
152        key: &'static str,
153        value: Option<impl AsRef<str>>,
154    ) -> Self {
155        let value = match value {
156            Some(value) => value,
157            None => return self,
158        };
159
160        if let Ok(header_value) = axum::http::HeaderValue::from_str(value.as_ref()) {
161            self.headers.insert(key, header_value);
162        }
163
164        self
165    }
166
167    #[inline]
168    pub fn ok(self) -> ApiResponseResult {
169        Ok(self)
170    }
171}
172
173impl<T> From<T> for ApiResponse
174where
175    T: Into<anyhow::Error>,
176{
177    fn from(err: T) -> Self {
178        let err: anyhow::Error = err.into();
179
180        if let Some(error) = err.downcast_ref::<DisplayError>() {
181            return ApiResponse::error(&error.message).with_status(error.status);
182        } else if let Some(DatabaseError::Validation(error)) = err.downcast_ref::<DatabaseError>() {
183            let error_messages = crate::utils::flatten_validation_errors(error);
184
185            return ApiResponse::new_serialized(ApiError::new_strings_value(error_messages))
186                .with_status(axum::http::StatusCode::BAD_REQUEST);
187        } else if let Some(DatabaseError::InvalidRelation(error)) =
188            err.downcast_ref::<DatabaseError>()
189        {
190            return ApiResponse::error(error.to_string())
191                .with_status(axum::http::StatusCode::BAD_REQUEST);
192        }
193
194        tracing::error!("a request error occurred: {:?}", err);
195        sentry_anyhow::capture_anyhow(&err);
196
197        let debug = APP_DEBUG.try_get().unwrap_or_default();
198
199        ApiResponse::error(if debug {
200            Cow::Owned(err.to_string())
201        } else {
202            "internal server error".into()
203        })
204        .with_status(axum::http::StatusCode::INTERNAL_SERVER_ERROR)
205    }
206}
207
208impl IntoResponse for ApiResponse {
209    #[inline]
210    fn into_response(self) -> axum::response::Response {
211        let mut response = axum::response::Response::new(self.body);
212        *response.status_mut() = self.status;
213        *response.headers_mut() = self.headers;
214
215        response
216    }
217}
218
219#[derive(Debug)]
220pub struct DisplayError<'a> {
221    status: axum::http::StatusCode,
222    message: Cow<'a, str>,
223}
224
225impl<'a> DisplayError<'a> {
226    pub fn new(message: impl Into<Cow<'a, str>>) -> Self {
227        Self {
228            status: axum::http::StatusCode::BAD_REQUEST,
229            message: message.into(),
230        }
231    }
232
233    pub fn with_status(mut self, status: axum::http::StatusCode) -> Self {
234        self.status = status;
235
236        self
237    }
238}
239
240impl<'a> Display for DisplayError<'a> {
241    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242        f.debug_struct("DisplayError")
243            .field("status", &self.status)
244            .field("message", &self.message)
245            .finish()
246    }
247}
248
249impl<'a> std::error::Error for DisplayError<'a> {}