// shared/extensions/shutdown_handlers.rs
use crate::State;
use futures_util::FutureExt;
use std::{
    any::Any, borrow::Cow, collections::HashMap, panic::AssertUnwindSafe, pin::Pin, sync::Arc,
};
use tokio::sync::{OwnedRwLockReadGuard, RwLock};

6pub type ShutdownFunc =
7 dyn Fn(State) -> Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send>> + Send + Sync;
8
9pub struct ShutdownHandler {
10 pub name: &'static str,
11 pub task: Box<ShutdownFunc>,
12}
13
14pub struct ShutdownHandlerBuilder {
15 state: State,
16 tasks: Arc<RwLock<HashMap<&'static str, ShutdownHandler>>>,
17}
18
19impl ShutdownHandlerBuilder {
20 pub fn new(state: State) -> Self {
21 Self {
22 state,
23 tasks: Arc::new(RwLock::new(HashMap::new())),
24 }
25 }
26
27 pub async fn add_handler<
31 F: Fn(State) -> Fut + Send + Sync + 'static,
32 Fut: Future<Output = Result<(), anyhow::Error>> + Send + 'static,
33 >(
34 &self,
35 name: &'static str,
36 shutdown_fn: F,
37 ) {
38 let state = self.state.clone();
39 let tasks = Arc::clone(&self.tasks);
40
41 self.tasks.write().await.insert(
42 name,
43 ShutdownHandler {
44 name,
45 task: Box::new(move |state: State| Box::pin(shutdown_fn(state))),
46 },
47 );
48 }
49}
50
51#[derive(Default)]
52pub struct ShutdownHandlerManager {
53 builder: RwLock<Option<ShutdownHandlerBuilder>>,
54}
55
56impl ShutdownHandlerManager {
57 pub async fn merge_builder(&self, builder: ShutdownHandlerBuilder) {
58 self.builder.write().await.replace(builder);
59 }
60
61 pub async fn get_handlers(
62 &self,
63 ) -> OwnedRwLockReadGuard<HashMap<&'static str, ShutdownHandler>> {
64 let inner = self.builder.read().await;
65
66 inner.as_ref().unwrap().tasks.clone().read_owned().await
67 }
68
69 pub async fn handle_shutdown(&self) {
70 let handlers = self.get_handlers().await;
71
72 for (name, handler) in handlers.iter() {
73 tracing::info!(name, "running shutdown task");
74 let result = AssertUnwindSafe((handler.task)(
75 self.builder.read().await.as_ref().unwrap().state.clone(),
76 ))
77 .catch_unwind()
78 .await;
79
80 match result {
81 Ok(result) => {
82 if let Err(err) = result {
83 tracing::error!(name, %err, "shutdown task failed");
84 sentry_anyhow::capture_anyhow(&err);
85 }
86 }
87 Err(err) => {
88 let err_msg: Cow<'_, str> = if let Some(s) = err.downcast_ref::<&str>() {
89 (*s).into()
90 } else if let Some(s) = err.downcast_ref::<String>() {
91 s.as_str().into()
92 } else {
93 "unknown panic".into()
94 };
95
96 tracing::error!(name, %err_msg, "shutdown task panicked");
97 sentry::capture_message(
98 &format!("Shutdown handler '{}' panicked: {}", name, err_msg),
99 sentry::Level::Error,
100 );
101 }
102 }
103 }
104 }
105}