shared/extensions/
background_tasks.rs1use crate::State;
2use futures_util::FutureExt;
3use std::{borrow::Cow, collections::HashMap, panic::AssertUnwindSafe, sync::Arc};
4use tokio::sync::{OwnedRwLockReadGuard, RwLock};
5
6pub struct BackgroundTask {
7 pub name: &'static str,
8 pub last_execution: std::time::Instant,
9 pub last_error: Option<anyhow::Error>,
10
11 pub task: tokio::task::JoinHandle<()>,
12}
13
14pub struct BackgroundTaskBuilder {
15 state: State,
16 tasks: Arc<RwLock<HashMap<&'static str, BackgroundTask>>>,
17}
18
19impl BackgroundTaskBuilder {
20 pub fn new(state: State) -> Self {
21 Self {
22 state,
23 tasks: Arc::new(RwLock::new(HashMap::new())),
24 }
25 }
26
27 pub async fn add_task<
30 F: Fn(State) -> Fut + Send + Sync + 'static,
31 Fut: Future<Output = Result<(), anyhow::Error>> + Send + 'static,
32 >(
33 &self,
34 name: &'static str,
35 loop_fn: F,
36 ) {
37 if !self.state.env.app_primary {
38 return;
39 }
40
41 let state = self.state.clone();
42 let tasks = Arc::clone(&self.tasks);
43
44 self.tasks.write().await.insert(
45 name,
46 BackgroundTask {
47 name,
48 last_execution: std::time::Instant::now(),
49 last_error: None,
50 task: tokio::spawn(async move {
51 loop {
52 if let Some(task) = tasks.write().await.get_mut(name) {
53 task.last_execution = std::time::Instant::now();
54 }
55
56 tracing::debug!(name, "running background task function");
57 let result = AssertUnwindSafe(loop_fn(state.clone()))
58 .catch_unwind()
59 .await;
60
61 let result = match result {
62 Ok(result) => result,
63 Err(err) => {
64 let err_msg: Cow<'_, str> =
65 if let Some(s) = err.downcast_ref::<&str>() {
66 (*s).into()
67 } else if let Some(s) = err.downcast_ref::<String>() {
68 s.clone().into()
69 } else {
70 "Unknown panic".into()
71 };
72
73 tracing::error!(name, "background task panicked: {}", err_msg);
74 sentry::capture_message(
75 &format!("Background task '{}' panicked: {}", name, err_msg),
76 sentry::Level::Error,
77 );
78
79 if let Some(task) = tasks.write().await.get_mut(name) {
80 task.last_error = Some(anyhow::anyhow!(err_msg));
81 }
82
83 return;
84 }
85 };
86
87 if let Err(err) = &result {
88 tracing::error!(name, "a background task error occurred: {:?}", err);
89 sentry_anyhow::capture_anyhow(err);
90 }
91
92 if let Some(task) = tasks.write().await.get_mut(name) {
93 task.last_error = result.err();
94 }
95 }
96 }),
97 },
98 );
99 }
100
101 pub async fn add_cron_task<
104 F: Fn(State) -> Fut + Send + Sync + 'static,
105 Fut: Future<Output = Result<(), anyhow::Error>> + Send + 'static,
106 >(
107 &self,
108 name: &'static str,
109 cron: cron::Schedule,
110 r#fn: F,
111 ) {
112 if !self.state.env.app_primary {
113 return;
114 }
115
116 let state = self.state.clone();
117 let tasks = Arc::clone(&self.tasks);
118
119 self.tasks.write().await.insert(
120 name,
121 BackgroundTask {
122 name,
123 last_execution: std::time::Instant::now(),
124 last_error: None,
125 task: tokio::spawn(async move {
126 let schedule_iter = cron.upcoming(chrono::Utc);
127
128 for target_datetime in schedule_iter {
129 let target_timestamp = target_datetime.timestamp();
130 let now_timestamp = chrono::Utc::now().timestamp();
131 let sleep_duration = target_timestamp - now_timestamp;
132 if sleep_duration <= 0 {
133 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
134 continue;
135 }
136
137 tokio::time::sleep(std::time::Duration::from_secs(sleep_duration as u64))
138 .await;
139
140 if let Some(task) = tasks.write().await.get_mut(name) {
141 task.last_execution = std::time::Instant::now();
142 }
143
144 tracing::debug!(name, "running background task function");
145 let result = AssertUnwindSafe(r#fn(state.clone())).catch_unwind().await;
146
147 let result = match result {
148 Ok(result) => result,
149 Err(err) => {
150 let err_msg: Cow<'_, str> =
151 if let Some(s) = err.downcast_ref::<&str>() {
152 (*s).into()
153 } else if let Some(s) = err.downcast_ref::<String>() {
154 s.clone().into()
155 } else {
156 "Unknown panic".into()
157 };
158
159 tracing::error!(name, "background task panicked: {}", err_msg);
160 sentry::capture_message(
161 &format!("Background task '{}' panicked: {}", name, err_msg),
162 sentry::Level::Error,
163 );
164
165 if let Some(task) = tasks.write().await.get_mut(name) {
166 task.last_error = Some(anyhow::anyhow!(err_msg));
167 }
168
169 return;
170 }
171 };
172
173 if let Err(err) = &result {
174 tracing::error!(name, "a background task error occurred: {:?}", err);
175 sentry_anyhow::capture_anyhow(err);
176 }
177
178 if let Some(task) = tasks.write().await.get_mut(name) {
179 task.last_error = result.err();
180 }
181 }
182 }),
183 },
184 );
185 }
186}
187
188#[derive(Default)]
189pub struct BackgroundTaskManager {
190 builder: RwLock<Option<BackgroundTaskBuilder>>,
191}
192
193impl BackgroundTaskManager {
194 pub async fn merge_builder(&self, builder: BackgroundTaskBuilder) {
195 self.builder.write().await.replace(builder);
196 }
197
198 pub async fn get_tasks(&self) -> OwnedRwLockReadGuard<HashMap<&'static str, BackgroundTask>> {
199 let inner = self.builder.read().await;
200
201 inner.as_ref().unwrap().tasks.clone().read_owned().await
202 }
203}