Skip to main content

shared/extensions/
background_tasks.rs

1use crate::State;
2use futures_util::FutureExt;
3use std::{borrow::Cow, collections::HashMap, panic::AssertUnwindSafe, sync::Arc};
4use tokio::sync::{OwnedRwLockReadGuard, RwLock};
5
6pub struct BackgroundTask {
7    pub name: &'static str,
8    pub last_execution: std::time::Instant,
9    pub last_error: Option<anyhow::Error>,
10
11    pub task: tokio::task::JoinHandle<()>,
12}
13
14pub struct BackgroundTaskBuilder {
15    state: State,
16    tasks: Arc<RwLock<HashMap<&'static str, BackgroundTask>>>,
17}
18
19impl BackgroundTaskBuilder {
20    pub fn new(state: State) -> Self {
21        Self {
22            state,
23            tasks: Arc::new(RwLock::new(HashMap::new())),
24        }
25    }
26
27    /// Adds a background task that will be executed periodically, depending on your loop function implementation.
28    /// This will only run on primary instances, so be aware of that when implementing your task.
29    pub async fn add_task<
30        F: Fn(State) -> Fut + Send + Sync + 'static,
31        Fut: Future<Output = Result<(), anyhow::Error>> + Send + 'static,
32    >(
33        &self,
34        name: &'static str,
35        loop_fn: F,
36    ) {
37        if !self.state.env.app_primary {
38            return;
39        }
40
41        let state = self.state.clone();
42        let tasks = Arc::clone(&self.tasks);
43
44        self.tasks.write().await.insert(
45            name,
46            BackgroundTask {
47                name,
48                last_execution: std::time::Instant::now(),
49                last_error: None,
50                task: tokio::spawn(async move {
51                    loop {
52                        if let Some(task) = tasks.write().await.get_mut(name) {
53                            task.last_execution = std::time::Instant::now();
54                        }
55
56                        tracing::debug!(name, "running background task function");
57                        let result = AssertUnwindSafe(loop_fn(state.clone()))
58                            .catch_unwind()
59                            .await;
60
61                        let result = match result {
62                            Ok(result) => result,
63                            Err(err) => {
64                                let err_msg: Cow<'_, str> =
65                                    if let Some(s) = err.downcast_ref::<&str>() {
66                                        (*s).into()
67                                    } else if let Some(s) = err.downcast_ref::<String>() {
68                                        s.clone().into()
69                                    } else {
70                                        "Unknown panic".into()
71                                    };
72
73                                tracing::error!(name, "background task panicked: {}", err_msg);
74                                sentry::capture_message(
75                                    &format!("Background task '{}' panicked: {}", name, err_msg),
76                                    sentry::Level::Error,
77                                );
78
79                                if let Some(task) = tasks.write().await.get_mut(name) {
80                                    task.last_error = Some(anyhow::anyhow!(err_msg));
81                                }
82
83                                return;
84                            }
85                        };
86
87                        if let Err(err) = &result {
88                            tracing::error!(name, "a background task error occurred: {:?}", err);
89                            sentry_anyhow::capture_anyhow(err);
90                        }
91
92                        if let Some(task) = tasks.write().await.get_mut(name) {
93                            task.last_error = result.err();
94                        }
95                    }
96                }),
97            },
98        );
99    }
100
101    /// Adds a background task that will be executed periodically, depending on the cron you provide, with the UTC timezone.
102    /// This will only run on primary instances, so be aware of that when implementing your task.
103    pub async fn add_cron_task<
104        F: Fn(State) -> Fut + Send + Sync + 'static,
105        Fut: Future<Output = Result<(), anyhow::Error>> + Send + 'static,
106    >(
107        &self,
108        name: &'static str,
109        cron: cron::Schedule,
110        r#fn: F,
111    ) {
112        if !self.state.env.app_primary {
113            return;
114        }
115
116        let state = self.state.clone();
117        let tasks = Arc::clone(&self.tasks);
118
119        self.tasks.write().await.insert(
120            name,
121            BackgroundTask {
122                name,
123                last_execution: std::time::Instant::now(),
124                last_error: None,
125                task: tokio::spawn(async move {
126                    let schedule_iter = cron.upcoming(chrono::Utc);
127
128                    for target_datetime in schedule_iter {
129                        let target_timestamp = target_datetime.timestamp();
130                        let now_timestamp = chrono::Utc::now().timestamp();
131                        let sleep_duration = target_timestamp - now_timestamp;
132                        if sleep_duration <= 0 {
133                            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
134                            continue;
135                        }
136
137                        tokio::time::sleep(std::time::Duration::from_secs(sleep_duration as u64))
138                            .await;
139
140                        if let Some(task) = tasks.write().await.get_mut(name) {
141                            task.last_execution = std::time::Instant::now();
142                        }
143
144                        tracing::debug!(name, "running background task function");
145                        let result = AssertUnwindSafe(r#fn(state.clone())).catch_unwind().await;
146
147                        let result = match result {
148                            Ok(result) => result,
149                            Err(err) => {
150                                let err_msg: Cow<'_, str> =
151                                    if let Some(s) = err.downcast_ref::<&str>() {
152                                        (*s).into()
153                                    } else if let Some(s) = err.downcast_ref::<String>() {
154                                        s.clone().into()
155                                    } else {
156                                        "Unknown panic".into()
157                                    };
158
159                                tracing::error!(name, "background task panicked: {}", err_msg);
160                                sentry::capture_message(
161                                    &format!("Background task '{}' panicked: {}", name, err_msg),
162                                    sentry::Level::Error,
163                                );
164
165                                if let Some(task) = tasks.write().await.get_mut(name) {
166                                    task.last_error = Some(anyhow::anyhow!(err_msg));
167                                }
168
169                                return;
170                            }
171                        };
172
173                        if let Err(err) = &result {
174                            tracing::error!(name, "a background task error occurred: {:?}", err);
175                            sentry_anyhow::capture_anyhow(err);
176                        }
177
178                        if let Some(task) = tasks.write().await.get_mut(name) {
179                            task.last_error = result.err();
180                        }
181                    }
182                }),
183            },
184        );
185    }
186}
187
188#[derive(Default)]
189pub struct BackgroundTaskManager {
190    builder: RwLock<Option<BackgroundTaskBuilder>>,
191}
192
193impl BackgroundTaskManager {
194    pub async fn merge_builder(&self, builder: BackgroundTaskBuilder) {
195        self.builder.write().await.replace(builder);
196    }
197
198    pub async fn get_tasks(&self) -> OwnedRwLockReadGuard<HashMap<&'static str, BackgroundTask>> {
199        let inner = self.builder.read().await;
200
201        inner.as_ref().unwrap().tasks.clone().read_owned().await
202    }
203}