shared/extensions/
background_tasks.rs1use crate::State;
2use futures_util::FutureExt;
3use std::{borrow::Cow, collections::HashMap, panic::AssertUnwindSafe, sync::Arc};
4use tokio::sync::{OwnedRwLockReadGuard, RwLock};
5
6pub struct BackgroundTask {
7 pub name: &'static str,
8 pub last_execution: std::time::Instant,
9 pub last_error: Option<anyhow::Error>,
10
11 pub task: tokio::task::JoinHandle<()>,
12}
13
14pub struct BackgroundTaskBuilder {
15 state: State,
16 tasks: Arc<RwLock<HashMap<&'static str, BackgroundTask>>>,
17}
18
19impl BackgroundTaskBuilder {
20 pub fn new(state: State) -> Self {
21 Self {
22 state,
23 tasks: Arc::new(RwLock::new(HashMap::new())),
24 }
25 }
26
27 pub async fn add_task<
30 F: Fn(State) -> Fut + Send + Sync + 'static,
31 Fut: Future<Output = Result<(), anyhow::Error>> + Send + 'static,
32 >(
33 &self,
34 name: &'static str,
35 loop_fn: F,
36 ) {
37 if !self.state.env.app_primary {
38 return;
39 }
40
41 let state = self.state.clone();
42 let tasks = Arc::clone(&self.tasks);
43
44 self.tasks.write().await.insert(
45 name,
46 BackgroundTask {
47 name,
48 last_execution: std::time::Instant::now(),
49 last_error: None,
50 task: tokio::spawn(async move {
51 loop {
52 if let Some(task) = tasks.write().await.get_mut(name) {
53 task.last_execution = std::time::Instant::now();
54 }
55
56 tracing::debug!(name, "running background task loop function");
57 let result = AssertUnwindSafe(loop_fn(state.clone()))
58 .catch_unwind()
59 .await;
60
61 let result = match result {
62 Ok(result) => result,
63 Err(err) => {
64 let err_msg: Cow<'_, str> =
65 if let Some(s) = err.downcast_ref::<&str>() {
66 (*s).into()
67 } else if let Some(s) = err.downcast_ref::<String>() {
68 s.clone().into()
69 } else {
70 "Unknown panic".into()
71 };
72
73 tracing::error!(name, "background task panicked: {}", err_msg);
74 sentry::capture_message(
75 &format!("Background task '{}' panicked: {}", name, err_msg),
76 sentry::Level::Error,
77 );
78
79 if let Some(task) = tasks.write().await.get_mut(name) {
80 task.last_error = Some(anyhow::anyhow!(err_msg));
81 }
82
83 return;
84 }
85 };
86
87 if let Err(err) = &result {
88 tracing::error!(name, "a background task error occurred: {:?}", err);
89 sentry_anyhow::capture_anyhow(err);
90 }
91
92 if let Some(task) = tasks.write().await.get_mut(name) {
93 task.last_error = result.err();
94 }
95 }
96 }),
97 },
98 );
99 }
100}
101
102#[derive(Default)]
103pub struct BackgroundTaskManager {
104 builder: RwLock<Option<BackgroundTaskBuilder>>,
105}
106
107impl BackgroundTaskManager {
108 pub async fn merge_builder(&self, builder: BackgroundTaskBuilder) {
109 self.builder.write().await.replace(builder);
110 }
111
112 pub async fn get_tasks(&self) -> OwnedRwLockReadGuard<HashMap<&'static str, BackgroundTask>> {
113 let inner = self.builder.read().await;
114
115 inner.as_ref().unwrap().tasks.clone().read_owned().await
116 }
117}