1use axum::{
2 extract::{FromRequestParts, OptionalFromRequestParts},
3 http::{Extensions, StatusCode, request::Parts},
4 response::IntoResponse,
5};
6use std::{
7 convert::Infallible,
8 ops::{Deref, DerefMut},
9};
10
/// Rejection returned when a [`ConsumingExtension`] cannot find its value in
/// the request extensions.
///
/// The wrapped `String` is the human-readable diagnostic that is sent back as
/// the response body.
#[derive(Debug)]
pub struct ConsumingExtensionError(String);
12
13impl IntoResponse for ConsumingExtensionError {
14 fn into_response(self) -> axum::response::Response {
15 (StatusCode::BAD_REQUEST, self.0).into_response()
16 }
17}
18
/// Extractor that *removes* a value of type `T` from the request extensions
/// and hands ownership of it to the handler.
///
/// Unlike a cloning extension extractor, the value is taken out of the
/// request, so it can only be extracted once per request.
#[derive(Debug, Clone, Copy, Default)]
pub struct ConsumingExtension<T>(pub T);
20
21impl<T> ConsumingExtension<T>
22where
23 T: Send + Sync + 'static,
24{
25 fn from_extensions(extensions: &mut Extensions) -> Option<Self> {
26 extensions.remove().map(ConsumingExtension)
27 }
28}
29
30impl<T, S> FromRequestParts<S> for ConsumingExtension<T>
31where
32 T: Send + Sync + 'static,
33 S: Send + Sync,
34{
35 type Rejection = ConsumingExtensionError;
36
37 async fn from_request_parts(req: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
38 Self::from_extensions(&mut req.extensions).ok_or_else(|| {
39 ConsumingExtensionError(format!(
40 "Extension of type `{}` was not found. Perhaps you forgot to add it? See `shared::extract::ConsumingExtension`.",
41 std::any::type_name::<T>()
42 ))
43 })
44 }
45}
46
47impl<T, S> OptionalFromRequestParts<S> for ConsumingExtension<T>
48where
49 T: Send + Sync + 'static,
50 S: Send + Sync,
51{
52 type Rejection = Infallible;
53
54 async fn from_request_parts(
55 req: &mut Parts,
56 _state: &S,
57 ) -> Result<Option<Self>, Self::Rejection> {
58 Ok(Self::from_extensions(&mut req.extensions))
59 }
60}
61
62impl<T> Deref for ConsumingExtension<T>
63where
64 T: Send + Sync + 'static,
65{
66 type Target = T;
67
68 fn deref(&self) -> &Self::Target {
69 &self.0
70 }
71}
72
73impl<T> DerefMut for ConsumingExtension<T>
74where
75 T: Send + Sync + 'static,
76{
77 fn deref_mut(&mut self) -> &mut Self::Target {
78 &mut self.0
79 }
80}