1use sqlx::postgres::PgPoolOptions;
2use std::{collections::HashMap, fmt::Display, pin::Pin, sync::Arc};
3use tokio::sync::Mutex;
4
5type BatchFuture = Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send>>;
6
7pub struct Database {
8 pub cache: Arc<crate::cache::Cache>,
9
10 write: sqlx::PgPool,
11 read: Option<sqlx::PgPool>,
12
13 encryption_key: Arc<str>,
14 use_decryption_cache: bool,
15 batch_actions: Arc<Mutex<HashMap<(&'static str, uuid::Uuid), BatchFuture>>>,
16}
17
18impl Database {
19 pub async fn new(env: &crate::env::Env, cache: Arc<crate::cache::Cache>) -> Self {
20 let start = std::time::Instant::now();
21
22 let instance = Self {
23 cache,
24
25 write: match &env.database_url_primary {
26 Some(url) => PgPoolOptions::new()
27 .min_connections(10)
28 .max_connections(20)
29 .test_before_acquire(false)
30 .connect(url)
31 .await
32 .unwrap(),
33
34 None => PgPoolOptions::new()
35 .min_connections(10)
36 .max_connections(50)
37 .test_before_acquire(false)
38 .connect(&env.database_url)
39 .await
40 .unwrap(),
41 },
42 read: if env.database_url_primary.is_some() {
43 Some(
44 PgPoolOptions::new()
45 .min_connections(10)
46 .max_connections(50)
47 .test_before_acquire(false)
48 .connect(&env.database_url)
49 .await
50 .unwrap(),
51 )
52 } else {
53 None
54 },
55
56 encryption_key: env.app_encryption_key.clone().into(),
57 use_decryption_cache: env.app_use_decryption_cache,
58 batch_actions: Arc::new(Mutex::new(HashMap::new())),
59 };
60
61 let version = instance
62 .version()
63 .await
64 .unwrap_or_else(|_| "unknown".into());
65
66 tracing::info!(
67 "database connected (postgres@{}, {}ms)",
68 version,
69 start.elapsed().as_millis()
70 );
71
72 tokio::spawn({
73 let batch_actions = instance.batch_actions.clone();
74
75 async move {
76 loop {
77 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
78
79 let mut actions = batch_actions.lock().await;
80 for (key, action) in actions.drain() {
81 tracing::debug!("executing batch action for {}:{}", key.0, key.1);
82 if let Err(err) = action.await {
83 tracing::error!(
84 "error executing batch action for {}:{} - {:?}",
85 key.0,
86 key.1,
87 err
88 );
89 sentry_anyhow::capture_anyhow(&err);
90 }
91 }
92 }
93 }
94 });
95
96 instance
97 }
98
99 pub async fn flush_batch_actions(&self) {
100 let mut actions = self.batch_actions.lock().await;
101 for (key, action) in actions.drain() {
102 tracing::debug!("executing batch action for {}:{}", key.0, key.1);
103 if let Err(err) = action.await {
104 tracing::error!(
105 "error executing batch action for {}:{} - {:?}",
106 key.0,
107 key.1,
108 err
109 );
110 sentry_anyhow::capture_anyhow(&err);
111 }
112 }
113 }
114
115 pub async fn version(&self) -> Result<compact_str::CompactString, sqlx::Error> {
116 let version: (compact_str::CompactString,) =
117 sqlx::query_as("SELECT split_part(version(), ' ', 2)")
118 .fetch_one(self.read())
119 .await?;
120
121 Ok(version.0)
122 }
123
124 pub async fn size(&self) -> Result<u64, sqlx::Error> {
125 let size: (i64,) = sqlx::query_as("SELECT pg_database_size(current_database())")
126 .fetch_one(self.read())
127 .await?;
128
129 Ok(size.0 as u64)
130 }
131
132 #[inline]
133 pub fn write(&self) -> &sqlx::PgPool {
134 &self.write
135 }
136
137 #[inline]
138 pub fn read(&self) -> &sqlx::PgPool {
139 self.read.as_ref().unwrap_or(&self.write)
140 }
141
142 pub async fn encrypt(
143 &self,
144 data: impl AsRef<[u8]> + Send + 'static,
145 ) -> Result<Vec<u8>, anyhow::Error> {
146 let encryption_key = self.encryption_key.clone();
147
148 tokio::task::spawn_blocking(move || {
149 simple_crypt::encrypt(data.as_ref(), encryption_key.as_bytes())
150 })
151 .await?
152 }
153
154 #[inline]
155 pub fn blocking_encrypt(&self, data: impl AsRef<[u8]>) -> Result<Vec<u8>, anyhow::Error> {
156 simple_crypt::encrypt(data.as_ref(), self.encryption_key.as_bytes())
157 }
158
159 pub async fn decrypt(
160 &self,
161 data: impl AsRef<[u8]> + Send + 'static,
162 ) -> Result<compact_str::CompactString, anyhow::Error> {
163 if self.use_decryption_cache {
164 self.cache
165 .cached(
166 &format!(
167 "decryption_cache::{}",
168 base32::encode(base32::Alphabet::Z, data.as_ref())
169 ),
170 30,
171 || async {
172 let encryption_key = self.encryption_key.clone();
173 let data = data.as_ref().to_vec();
174
175 tokio::task::spawn_blocking(move || {
176 simple_crypt::decrypt(&data, encryption_key.as_bytes())
177 .map(|s| compact_str::CompactString::from_utf8_lossy(&s))
178 })
179 .await?
180 },
181 )
182 .await
183 } else {
184 let encryption_key = self.encryption_key.clone();
185
186 tokio::task::spawn_blocking(move || {
187 simple_crypt::decrypt(data.as_ref(), encryption_key.as_bytes())
188 .map(|s| compact_str::CompactString::from_utf8_lossy(&s))
189 })
190 .await?
191 }
192 }
193
194 #[inline]
195 pub fn blocking_decrypt(
196 &self,
197 data: impl AsRef<[u8]>,
198 ) -> Result<compact_str::CompactString, anyhow::Error> {
199 simple_crypt::decrypt(data.as_ref(), self.encryption_key.as_bytes())
200 .map(|s| compact_str::CompactString::from_utf8_lossy(&s))
201 }
202
203 #[inline]
204 pub async fn batch_action(
205 &self,
206 key: &'static str,
207 uuid: uuid::Uuid,
208 action: impl Future<Output = Result<(), anyhow::Error>> + Send + 'static,
209 ) {
210 let mut actions = self.batch_actions.lock().await;
211 actions.insert((key, uuid), Box::pin(action));
212 }
213}
214
/// Errors produced by the database layer.
///
/// Each variant wraps one underlying error source: `sqlx`, `serde_json`, `garde` validation,
/// an invalid relation reference, or any other error carried as `anyhow::Error`.
#[derive(Debug)]
pub enum DatabaseError {
    /// An error returned by `sqlx`.
    Sqlx(sqlx::Error),
    /// A `serde_json` error.
    Serde(serde_json::Error),
    /// Catch-all for any other error.
    Any(anyhow::Error),
    /// A `garde` validation report.
    Validation(garde::Report),
    /// A relation name that was rejected as invalid.
    InvalidRelation(InvalidRelationError),
}
223
224impl Display for DatabaseError {
225 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226 match self {
227 Self::Sqlx(sqlx_value) => sqlx_value.fmt(f),
228 Self::Serde(serde_value) => serde_value.fmt(f),
229 Self::Any(any_value) => any_value.fmt(f),
230 Self::Validation(validation_value) => validation_value.fmt(f),
231 Self::InvalidRelation(relation_value) => relation_value.fmt(f),
232 }
233 }
234}
235
236impl From<wings_api::client::ApiHttpError> for DatabaseError {
237 #[inline]
238 fn from(value: wings_api::client::ApiHttpError) -> Self {
239 Self::Any(value.into())
240 }
241}
242
243impl From<anyhow::Error> for DatabaseError {
244 #[inline]
245 fn from(value: anyhow::Error) -> Self {
246 Self::Any(value)
247 }
248}
249
250impl From<serde_json::Error> for DatabaseError {
251 #[inline]
252 fn from(value: serde_json::Error) -> Self {
253 Self::Serde(value)
254 }
255}
256
257impl From<sqlx::Error> for DatabaseError {
258 #[inline]
259 fn from(value: sqlx::Error) -> Self {
260 Self::Sqlx(value)
261 }
262}
263
264impl From<garde::Report> for DatabaseError {
265 #[inline]
266 fn from(value: garde::Report) -> Self {
267 Self::Validation(value)
268 }
269}
270
271impl From<InvalidRelationError> for DatabaseError {
272 fn from(value: InvalidRelationError) -> Self {
273 Self::InvalidRelation(value)
274 }
275}
276
277impl DatabaseError {
278 #[inline]
279 pub fn is_unique_violation(&self) -> bool {
280 match self {
281 Self::Sqlx(sqlx_value) => sqlx_value
282 .as_database_error()
283 .is_some_and(|e| e.is_unique_violation()),
284 _ => false,
285 }
286 }
287
288 #[inline]
289 pub fn is_foreign_key_violation(&self) -> bool {
290 match self {
291 Self::Sqlx(sqlx_value) => sqlx_value
292 .as_database_error()
293 .is_some_and(|e| e.is_foreign_key_violation()),
294 _ => false,
295 }
296 }
297
298 #[inline]
299 pub fn is_check_violation(&self) -> bool {
300 match self {
301 Self::Sqlx(sqlx_value) => sqlx_value
302 .as_database_error()
303 .is_some_and(|e| e.is_check_violation()),
304 _ => false,
305 }
306 }
307
308 #[inline]
309 pub const fn is_validation_error(&self) -> bool {
310 matches!(self, Self::Validation(_))
311 }
312
313 #[inline]
314 pub const fn is_invalid_relation(&self) -> bool {
315 matches!(self, Self::InvalidRelation(_))
316 }
317}
318
// Relies entirely on the trait's default methods: `source()` returns `None`, and the message
// comes from the `Display` impl, which forwards to the wrapped error.
impl std::error::Error for DatabaseError {}
320
/// Error signalling that an invalid relation was provided.
///
/// The wrapped string names the offending relation and is echoed in the `Display` output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InvalidRelationError(pub &'static str);

impl Display for InvalidRelationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "invalid relation `{}` provided", self.0)
    }
}

// Lets the error be boxed as `dyn Error` and consumed by error-trait-based tooling.
impl std::error::Error for InvalidRelationError {}