shared/extensions/shutdown_handlers.rs

1use crate::State;
2use futures_util::FutureExt;
3use std::{borrow::Cow, collections::HashMap, panic::AssertUnwindSafe, pin::Pin, sync::Arc};
4use tokio::sync::{OwnedRwLockReadGuard, RwLock};
5
6pub type ShutdownFunc =
7    dyn Fn(State) -> Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send>> + Send + Sync;
8
9pub struct ShutdownHandler {
10    pub name: &'static str,
11    pub task: Box<ShutdownFunc>,
12}
13
14pub struct ShutdownHandlerBuilder {
15    state: State,
16    tasks: Arc<RwLock<HashMap<&'static str, ShutdownHandler>>>,
17}
18
19impl ShutdownHandlerBuilder {
20    pub fn new(state: State) -> Self {
21        Self {
22            state,
23            tasks: Arc::new(RwLock::new(HashMap::new())),
24        }
25    }
26
27    /// Adds a shutdown handler that will be executed when a gradual shutdown is initiated.
28    /// This will run on primary and on non-primary instances, so be aware of that when implementing your handler and
29    /// perhaps check `state.env.app_primary`.
30    pub async fn add_handler<
31        F: Fn(State) -> Fut + Send + Sync + 'static,
32        Fut: Future<Output = Result<(), anyhow::Error>> + Send + 'static,
33    >(
34        &self,
35        name: &'static str,
36        shutdown_fn: F,
37    ) {
38        let state = self.state.clone();
39        let tasks = Arc::clone(&self.tasks);
40
41        self.tasks.write().await.insert(
42            name,
43            ShutdownHandler {
44                name,
45                task: Box::new(move |state: State| Box::pin(shutdown_fn(state))),
46            },
47        );
48    }
49}
50
51#[derive(Default)]
52pub struct ShutdownHandlerManager {
53    builder: RwLock<Option<ShutdownHandlerBuilder>>,
54}
55
56impl ShutdownHandlerManager {
57    pub async fn merge_builder(&self, builder: ShutdownHandlerBuilder) {
58        self.builder.write().await.replace(builder);
59    }
60
61    pub async fn get_handlers(
62        &self,
63    ) -> OwnedRwLockReadGuard<HashMap<&'static str, ShutdownHandler>> {
64        let inner = self.builder.read().await;
65
66        inner.as_ref().unwrap().tasks.clone().read_owned().await
67    }
68
69    pub async fn handle_shutdown(&self) {
70        let handlers = self.get_handlers().await;
71
72        for (name, handler) in handlers.iter() {
73            tracing::info!(name, "running shutdown task");
74            let result = AssertUnwindSafe((handler.task)(
75                self.builder.read().await.as_ref().unwrap().state.clone(),
76            ))
77            .catch_unwind()
78            .await;
79
80            match result {
81                Ok(result) => {
82                    if let Err(err) = result {
83                        tracing::error!(name, %err, "shutdown task failed");
84                        sentry_anyhow::capture_anyhow(&err);
85                    }
86                }
87                Err(err) => {
88                    let err_msg: Cow<'_, str> = if let Some(s) = err.downcast_ref::<&str>() {
89                        (*s).into()
90                    } else if let Some(s) = err.downcast_ref::<String>() {
91                        s.as_str().into()
92                    } else {
93                        "unknown panic".into()
94                    };
95
96                    tracing::error!(name, %err_msg, "shutdown task panicked");
97                    sentry::capture_message(
98                        &format!("Shutdown handler '{}' panicked: {}", name, err_msg),
99                        sentry::Level::Error,
100                    );
101                }
102            }
103        }
104    }
105}