shared/
extract.rs

1use axum::{
2    extract::{FromRequestParts, OptionalFromRequestParts},
3    http::{Extensions, StatusCode, request::Parts},
4    response::IntoResponse,
5};
6use std::{
7    convert::Infallible,
8    ops::{Deref, DerefMut},
9};
10
/// Rejection produced when a [`ConsumingExtension`] cannot be extracted because
/// no value of the requested type is present in the request extensions.
///
/// The wrapped `String` is the human-readable message that is sent back as the
/// response body when this rejection is converted into a response.
#[derive(Debug)]
pub struct ConsumingExtensionError(String);

impl std::fmt::Display for ConsumingExtensionError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

// Lets the rejection be boxed / propagated like any other error value.
impl std::error::Error for ConsumingExtensionError {}
12
13impl IntoResponse for ConsumingExtensionError {
14    fn into_response(self) -> axum::response::Response {
15        (StatusCode::BAD_REQUEST, self.0).into_response()
16    }
17}
18
/// Extractor wrapper around a value of type `T` taken from the request
/// extensions.
///
/// Extraction *removes* the value from the request's extension map instead of
/// borrowing or cloning it, so the same `T` will not be found again by a later
/// extraction on the same request.
///
/// The inner value is public and can also be reached through `Deref` /
/// `DerefMut` when `T: Send + Sync + 'static`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ConsumingExtension<T>(pub T);
20
21impl<T> ConsumingExtension<T>
22where
23    T: Send + Sync + 'static,
24{
25    fn from_extensions(extensions: &mut Extensions) -> Option<Self> {
26        extensions.remove().map(ConsumingExtension)
27    }
28}
29
30impl<T, S> FromRequestParts<S> for ConsumingExtension<T>
31where
32    T: Send + Sync + 'static,
33    S: Send + Sync,
34{
35    type Rejection = ConsumingExtensionError;
36
37    async fn from_request_parts(req: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
38        Self::from_extensions(&mut req.extensions).ok_or_else(|| {
39            ConsumingExtensionError(format!(
40                "Extension of type `{}` was not found. Perhaps you forgot to add it? See `shared::extract::ConsumingExtension`.",
41                std::any::type_name::<T>()
42            ))
43        })
44    }
45}
46
47impl<T, S> OptionalFromRequestParts<S> for ConsumingExtension<T>
48where
49    T: Send + Sync + 'static,
50    S: Send + Sync,
51{
52    type Rejection = Infallible;
53
54    async fn from_request_parts(
55        req: &mut Parts,
56        _state: &S,
57    ) -> Result<Option<Self>, Self::Rejection> {
58        Ok(Self::from_extensions(&mut req.extensions))
59    }
60}
61
62impl<T> Deref for ConsumingExtension<T>
63where
64    T: Send + Sync + 'static,
65{
66    type Target = T;
67
68    fn deref(&self) -> &Self::Target {
69        &self.0
70    }
71}
72
73impl<T> DerefMut for ConsumingExtension<T>
74where
75    T: Send + Sync + 'static,
76{
77    fn deref_mut(&mut self) -> &mut Self::Target {
78        &mut self.0
79    }
80}