// shared/extensions/background_tasks.rs

use crate::State;
use futures_util::FutureExt;
use std::{any::Any, borrow::Cow, collections::HashMap, panic::AssertUnwindSafe, sync::Arc};
use tokio::sync::{OwnedRwLockReadGuard, RwLock};

6pub struct BackgroundTask {
7    pub name: &'static str,
8    pub last_execution: std::time::Instant,
9    pub last_error: Option<anyhow::Error>,
10
11    pub task: tokio::task::JoinHandle<()>,
12}
13
14pub struct BackgroundTaskBuilder {
15    state: State,
16    tasks: Arc<RwLock<HashMap<&'static str, BackgroundTask>>>,
17}
18
19impl BackgroundTaskBuilder {
20    pub fn new(state: State) -> Self {
21        Self {
22            state,
23            tasks: Arc::new(RwLock::new(HashMap::new())),
24        }
25    }
26
27    /// Adds a background task that will be executed periodically, depending on your loop function implementation.
28    /// This will only run on primary instances, so be aware of that when implementing your task.
29    pub async fn add_task<
30        F: Fn(State) -> Fut + Send + Sync + 'static,
31        Fut: Future<Output = Result<(), anyhow::Error>> + Send + 'static,
32    >(
33        &self,
34        name: &'static str,
35        loop_fn: F,
36    ) {
37        if !self.state.env.app_primary {
38            return;
39        }
40
41        let state = self.state.clone();
42        let tasks = Arc::clone(&self.tasks);
43
44        self.tasks.write().await.insert(
45            name,
46            BackgroundTask {
47                name,
48                last_execution: std::time::Instant::now(),
49                last_error: None,
50                task: tokio::spawn(async move {
51                    loop {
52                        if let Some(task) = tasks.write().await.get_mut(name) {
53                            task.last_execution = std::time::Instant::now();
54                        }
55
56                        tracing::debug!(name, "running background task loop function");
57                        let result = AssertUnwindSafe(loop_fn(state.clone()))
58                            .catch_unwind()
59                            .await;
60
61                        let result = match result {
62                            Ok(result) => result,
63                            Err(err) => {
64                                let err_msg: Cow<'_, str> =
65                                    if let Some(s) = err.downcast_ref::<&str>() {
66                                        (*s).into()
67                                    } else if let Some(s) = err.downcast_ref::<String>() {
68                                        s.clone().into()
69                                    } else {
70                                        "Unknown panic".into()
71                                    };
72
73                                tracing::error!(name, "background task panicked: {}", err_msg);
74                                sentry::capture_message(
75                                    &format!("Background task '{}' panicked: {}", name, err_msg),
76                                    sentry::Level::Error,
77                                );
78
79                                if let Some(task) = tasks.write().await.get_mut(name) {
80                                    task.last_error = Some(anyhow::anyhow!(err_msg));
81                                }
82
83                                return;
84                            }
85                        };
86
87                        if let Err(err) = &result {
88                            tracing::error!(name, "a background task error occurred: {:?}", err);
89                            sentry_anyhow::capture_anyhow(err);
90                        }
91
92                        if let Some(task) = tasks.write().await.get_mut(name) {
93                            task.last_error = result.err();
94                        }
95                    }
96                }),
97            },
98        );
99    }
100}
101
102#[derive(Default)]
103pub struct BackgroundTaskManager {
104    builder: RwLock<Option<BackgroundTaskBuilder>>,
105}
106
107impl BackgroundTaskManager {
108    pub async fn merge_builder(&self, builder: BackgroundTaskBuilder) {
109        self.builder.write().await.replace(builder);
110    }
111
112    pub async fn get_tasks(&self) -> OwnedRwLockReadGuard<HashMap<&'static str, BackgroundTask>> {
113        let inner = self.builder.read().await;
114
115        inner.as_ref().unwrap().tasks.clone().read_owned().await
116    }
117}