Skip to main content

shared/
database.rs

1use base64::Engine;
2use sha2::Digest;
3use sqlx::postgres::PgPoolOptions;
4use std::{collections::HashMap, fmt::Display, pin::Pin, sync::Arc};
5use tokio::sync::Mutex;
6
7pub static BASE64_ENGINE: base64::engine::GeneralPurpose = base64::engine::GeneralPurpose::new(
8    &base64::alphabet::STANDARD,
9    base64::engine::GeneralPurposeConfig::new()
10        .with_decode_allow_trailing_bits(true)
11        .with_decode_padding_mode(base64::engine::DecodePaddingMode::Indifferent),
12);
13
14type BatchFuture = Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send>>;
15
16pub struct Database {
17    pub cache: Arc<crate::cache::Cache>,
18
19    write: sqlx::PgPool,
20    read: Option<sqlx::PgPool>,
21
22    encryption_key: Arc<str>,
23    use_decryption_cache: bool,
24    batch_actions: Arc<Mutex<HashMap<(&'static str, uuid::Uuid), BatchFuture>>>,
25}
26
27impl Database {
28    pub async fn new(env: &crate::env::Env, cache: Arc<crate::cache::Cache>) -> Self {
29        let start = std::time::Instant::now();
30
31        let instance = Self {
32            cache,
33
34            write: match &env.database_url_primary {
35                Some(url) => PgPoolOptions::new()
36                    .min_connections(10)
37                    .max_connections(20)
38                    .test_before_acquire(false)
39                    .connect(url)
40                    .await
41                    .unwrap(),
42
43                None => PgPoolOptions::new()
44                    .min_connections(10)
45                    .max_connections(50)
46                    .test_before_acquire(false)
47                    .connect(&env.database_url)
48                    .await
49                    .unwrap(),
50            },
51            read: if env.database_url_primary.is_some() {
52                Some(
53                    PgPoolOptions::new()
54                        .min_connections(10)
55                        .max_connections(50)
56                        .test_before_acquire(false)
57                        .connect(&env.database_url)
58                        .await
59                        .unwrap(),
60                )
61            } else {
62                None
63            },
64
65            encryption_key: env.app_encryption_key.clone().into(),
66            use_decryption_cache: env.app_use_decryption_cache,
67            batch_actions: Arc::new(Mutex::new(HashMap::new())),
68        };
69
70        let version = instance
71            .version()
72            .await
73            .unwrap_or_else(|_| "unknown".into());
74
75        tracing::info!(
76            "database connected (postgres@{}, {}ms)",
77            version,
78            start.elapsed().as_millis()
79        );
80
81        tokio::spawn({
82            let batch_actions = instance.batch_actions.clone();
83
84            async move {
85                loop {
86                    tokio::time::sleep(std::time::Duration::from_secs(5)).await;
87
88                    let mut actions = batch_actions.lock().await;
89                    for (key, action) in actions.drain() {
90                        tracing::debug!("executing batch action for {}:{}", key.0, key.1);
91                        if let Err(err) = action.await {
92                            tracing::error!(
93                                "error executing batch action for {}:{} - {:?}",
94                                key.0,
95                                key.1,
96                                err
97                            );
98                            sentry_anyhow::capture_anyhow(&err);
99                        }
100                    }
101                }
102            }
103        });
104
105        instance
106    }
107
108    pub async fn flush_batch_actions(&self) {
109        let actions = self.batch_actions.lock().await.drain().collect::<Vec<_>>();
110
111        for (key, action) in actions {
112            tracing::debug!("executing batch action for {}:{}", key.0, key.1);
113            if let Err(err) = action.await {
114                tracing::error!(
115                    "error executing batch action for {}:{} - {:?}",
116                    key.0,
117                    key.1,
118                    err
119                );
120                sentry_anyhow::capture_anyhow(&err);
121            }
122        }
123    }
124
125    pub async fn version(&self) -> Result<compact_str::CompactString, sqlx::Error> {
126        let version: (compact_str::CompactString,) =
127            sqlx::query_as("SELECT split_part(version(), ' ', 2)")
128                .fetch_one(self.read())
129                .await?;
130
131        Ok(version.0)
132    }
133
134    pub async fn size(&self) -> Result<u64, sqlx::Error> {
135        let size: (i64,) = sqlx::query_as("SELECT pg_database_size(current_database())")
136            .fetch_one(self.read())
137            .await?;
138
139        Ok(size.0 as u64)
140    }
141
142    #[inline]
143    pub fn write(&self) -> &sqlx::PgPool {
144        &self.write
145    }
146
147    #[inline]
148    pub fn read(&self) -> &sqlx::PgPool {
149        self.read.as_ref().unwrap_or(&self.write)
150    }
151
152    pub async fn encrypt(
153        &self,
154        data: impl AsRef<[u8]> + Send + 'static,
155    ) -> Result<Vec<u8>, anyhow::Error> {
156        let encryption_key = self.encryption_key.clone();
157
158        tokio::task::spawn_blocking(move || {
159            simple_crypt::encrypt(data.as_ref(), encryption_key.as_bytes())
160        })
161        .await?
162    }
163
164    pub async fn encrypt_base64(
165        &self,
166        data: impl AsRef<[u8]> + Send + 'static,
167    ) -> Result<compact_str::CompactString, anyhow::Error> {
168        let encrypted = self.encrypt(data).await?;
169        Ok(BASE64_ENGINE.encode(&encrypted).into())
170    }
171
172    #[inline]
173    pub fn blocking_encrypt(&self, data: impl AsRef<[u8]>) -> Result<Vec<u8>, anyhow::Error> {
174        simple_crypt::encrypt(data.as_ref(), self.encryption_key.as_bytes())
175    }
176
177    #[inline]
178    pub fn blocking_encrypt_base64(
179        &self,
180        data: impl AsRef<[u8]>,
181    ) -> Result<compact_str::CompactString, anyhow::Error> {
182        let encrypted = self.blocking_encrypt(data)?;
183        Ok(BASE64_ENGINE.encode(&encrypted).into())
184    }
185
186    pub async fn decrypt(
187        &self,
188        data: impl AsRef<[u8]> + Send + 'static,
189    ) -> Result<compact_str::CompactString, anyhow::Error> {
190        if self.use_decryption_cache {
191            self.cache
192                .cached(
193                    &format!(
194                        "decryption_cache::{}",
195                        hex::encode(sha2::Sha256::digest(data.as_ref()))
196                    ),
197                    30,
198                    || async {
199                        let encryption_key = self.encryption_key.clone();
200                        let data = data.as_ref().to_vec();
201
202                        tokio::task::spawn_blocking(move || {
203                            simple_crypt::decrypt(&data, encryption_key.as_bytes())
204                                .map(|s| compact_str::CompactString::from_utf8_lossy(&s))
205                        })
206                        .await?
207                    },
208                )
209                .await
210        } else {
211            let encryption_key = self.encryption_key.clone();
212
213            tokio::task::spawn_blocking(move || {
214                simple_crypt::decrypt(data.as_ref(), encryption_key.as_bytes())
215                    .map(|s| compact_str::CompactString::from_utf8_lossy(&s))
216            })
217            .await?
218        }
219    }
220
221    pub async fn decrypt_raw(
222        &self,
223        data: impl AsRef<[u8]> + Send + 'static,
224    ) -> Result<Vec<u8>, anyhow::Error> {
225        if self.use_decryption_cache {
226            self.cache
227                .cached(
228                    &format!(
229                        "decryption_cache::{}::raw",
230                        hex::encode(sha2::Sha256::digest(data.as_ref()))
231                    ),
232                    30,
233                    || async {
234                        let encryption_key = self.encryption_key.clone();
235                        let data = data.as_ref().to_vec();
236
237                        tokio::task::spawn_blocking(move || {
238                            simple_crypt::decrypt(&data, encryption_key.as_bytes())
239                        })
240                        .await?
241                    },
242                )
243                .await
244        } else {
245            let encryption_key = self.encryption_key.clone();
246
247            tokio::task::spawn_blocking(move || {
248                simple_crypt::decrypt(data.as_ref(), encryption_key.as_bytes())
249            })
250            .await?
251        }
252    }
253
254    pub async fn decrypt_base64(
255        &self,
256        data: impl AsRef<str>,
257    ) -> Result<compact_str::CompactString, anyhow::Error> {
258        let decoded = BASE64_ENGINE.decode(data.as_ref())?;
259        self.decrypt(decoded).await
260    }
261
262    pub async fn decrypt_base64_raw(
263        &self,
264        data: impl AsRef<str>,
265    ) -> Result<Vec<u8>, anyhow::Error> {
266        let decoded = BASE64_ENGINE.decode(data.as_ref())?;
267        self.decrypt_raw(decoded).await
268    }
269
270    pub async fn decrypt_base64_optional(
271        &self,
272        data: impl AsRef<str>,
273    ) -> Result<Option<compact_str::CompactString>, anyhow::Error> {
274        match BASE64_ENGINE.decode(data.as_ref()) {
275            Ok(decoded) => Ok(Some(self.decrypt(decoded).await?)),
276            Err(_) => Ok(None),
277        }
278    }
279
280    pub async fn decrypt_base64_raw_optional(
281        &self,
282        data: impl AsRef<str>,
283    ) -> Result<Option<Vec<u8>>, anyhow::Error> {
284        match BASE64_ENGINE.decode(data.as_ref()) {
285            Ok(decoded) => Ok(Some(self.decrypt_raw(decoded).await?)),
286            Err(_) => Ok(None),
287        }
288    }
289
290    #[inline]
291    pub fn blocking_decrypt(
292        &self,
293        data: impl AsRef<[u8]>,
294    ) -> Result<compact_str::CompactString, anyhow::Error> {
295        simple_crypt::decrypt(data.as_ref(), self.encryption_key.as_bytes())
296            .map(|s| compact_str::CompactString::from_utf8_lossy(&s))
297    }
298
299    #[inline]
300    pub fn blocking_decrypt_raw(&self, data: impl AsRef<[u8]>) -> Result<Vec<u8>, anyhow::Error> {
301        simple_crypt::decrypt(data.as_ref(), self.encryption_key.as_bytes())
302    }
303
304    #[inline]
305    pub fn blocking_decrypt_base64(
306        &self,
307        data: impl AsRef<str>,
308    ) -> Result<compact_str::CompactString, anyhow::Error> {
309        let decoded = BASE64_ENGINE.decode(data.as_ref())?;
310        self.blocking_decrypt(decoded)
311    }
312
313    #[inline]
314    pub fn blocking_decrypt_base64_raw(
315        &self,
316        data: impl AsRef<str>,
317    ) -> Result<Vec<u8>, anyhow::Error> {
318        let decoded = BASE64_ENGINE.decode(data.as_ref())?;
319        self.blocking_decrypt_raw(decoded)
320    }
321
322    #[inline]
323    pub fn blocking_decrypt_base64_optional(
324        &self,
325        data: impl AsRef<str>,
326    ) -> Result<Option<compact_str::CompactString>, anyhow::Error> {
327        match BASE64_ENGINE.decode(data.as_ref()) {
328            Ok(decoded) => Ok(Some(self.blocking_decrypt(decoded)?)),
329            Err(_) => Ok(None),
330        }
331    }
332
333    #[inline]
334    pub fn blocking_decrypt_base64_raw_optional(
335        &self,
336        data: impl AsRef<str>,
337    ) -> Result<Option<Vec<u8>>, anyhow::Error> {
338        match BASE64_ENGINE.decode(data.as_ref()) {
339            Ok(decoded) => Ok(Some(self.blocking_decrypt_raw(decoded)?)),
340            Err(_) => Ok(None),
341        }
342    }
343
344    #[inline]
345    pub async fn batch_action(
346        &self,
347        key: &'static str,
348        uuid: uuid::Uuid,
349        action: impl Future<Output = Result<(), anyhow::Error>> + Send + 'static,
350    ) {
351        let mut actions = self.batch_actions.lock().await;
352        actions.insert((key, uuid), Box::pin(action));
353    }
354}
355
356#[derive(Debug)]
357pub enum DatabaseError {
358    Sqlx(sqlx::Error),
359    Mongodb(mongodb::error::Error),
360    Serde(serde_json::Error),
361    Any(anyhow::Error),
362    Validation(garde::Report),
363    InvalidRelation(InvalidRelationError),
364}
365
366impl Display for DatabaseError {
367    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368        match self {
369            Self::Sqlx(sqlx_value) => sqlx_value.fmt(f),
370            Self::Mongodb(mongodb_value) => mongodb_value.fmt(f),
371            Self::Serde(serde_value) => serde_value.fmt(f),
372            Self::Any(any_value) => any_value.fmt(f),
373            Self::Validation(validation_value) => validation_value.fmt(f),
374            Self::InvalidRelation(relation_value) => relation_value.fmt(f),
375        }
376    }
377}
378
379impl From<wings_api::client::ApiHttpError> for DatabaseError {
380    #[inline]
381    fn from(value: wings_api::client::ApiHttpError) -> Self {
382        Self::Any(value.into())
383    }
384}
385
386impl From<anyhow::Error> for DatabaseError {
387    #[inline]
388    fn from(value: anyhow::Error) -> Self {
389        Self::Any(value)
390    }
391}
392
393impl From<serde_json::Error> for DatabaseError {
394    #[inline]
395    fn from(value: serde_json::Error) -> Self {
396        Self::Serde(value)
397    }
398}
399
400impl From<sqlx::Error> for DatabaseError {
401    #[inline]
402    fn from(value: sqlx::Error) -> Self {
403        Self::Sqlx(value)
404    }
405}
406
407impl From<mongodb::error::Error> for DatabaseError {
408    #[inline]
409    fn from(value: mongodb::error::Error) -> Self {
410        Self::Mongodb(value)
411    }
412}
413
414impl From<garde::Report> for DatabaseError {
415    #[inline]
416    fn from(value: garde::Report) -> Self {
417        Self::Validation(value)
418    }
419}
420
421impl From<InvalidRelationError> for DatabaseError {
422    fn from(value: InvalidRelationError) -> Self {
423        Self::InvalidRelation(value)
424    }
425}
426
427impl DatabaseError {
428    #[inline]
429    pub fn is_unique_violation(&self) -> bool {
430        match self {
431            Self::Sqlx(sqlx_value) => sqlx_value
432                .as_database_error()
433                .is_some_and(|e| e.is_unique_violation()),
434            _ => false,
435        }
436    }
437
438    #[inline]
439    pub fn is_foreign_key_violation(&self) -> bool {
440        match self {
441            Self::Sqlx(sqlx_value) => sqlx_value
442                .as_database_error()
443                .is_some_and(|e| e.is_foreign_key_violation()),
444            _ => false,
445        }
446    }
447
448    #[inline]
449    pub fn is_check_violation(&self) -> bool {
450        match self {
451            Self::Sqlx(sqlx_value) => sqlx_value
452                .as_database_error()
453                .is_some_and(|e| e.is_check_violation()),
454            _ => false,
455        }
456    }
457
458    #[inline]
459    pub const fn is_validation_error(&self) -> bool {
460        matches!(self, Self::Validation(_))
461    }
462
463    #[inline]
464    pub const fn is_invalid_relation(&self) -> bool {
465        matches!(self, Self::InvalidRelation(_))
466    }
467}
468
469impl std::error::Error for DatabaseError {}
470
/// Error raised when a relation name is not accepted; the field holds the rejected name.
#[derive(Debug)]
pub struct InvalidRelationError(pub &'static str);
473
474impl Display for InvalidRelationError {
475    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
476        write!(f, "invalid relation `{}` provided", self.0)
477    }
478}