// shared/database.rs

1use sqlx::postgres::PgPoolOptions;
2use std::{collections::HashMap, fmt::Display, pin::Pin, sync::Arc};
3use tokio::sync::Mutex;
4
5type BatchFuture = Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send>>;
6
7pub struct Database {
8    pub cache: Arc<crate::cache::Cache>,
9
10    write: sqlx::PgPool,
11    read: Option<sqlx::PgPool>,
12
13    encryption_key: Arc<str>,
14    use_decryption_cache: bool,
15    batch_actions: Arc<Mutex<HashMap<(&'static str, uuid::Uuid), BatchFuture>>>,
16}
17
18impl Database {
19    pub async fn new(env: &crate::env::Env, cache: Arc<crate::cache::Cache>) -> Self {
20        let start = std::time::Instant::now();
21
22        let instance = Self {
23            cache,
24
25            write: match &env.database_url_primary {
26                Some(url) => PgPoolOptions::new()
27                    .min_connections(10)
28                    .max_connections(20)
29                    .test_before_acquire(false)
30                    .connect(url)
31                    .await
32                    .unwrap(),
33
34                None => PgPoolOptions::new()
35                    .min_connections(10)
36                    .max_connections(50)
37                    .test_before_acquire(false)
38                    .connect(&env.database_url)
39                    .await
40                    .unwrap(),
41            },
42            read: if env.database_url_primary.is_some() {
43                Some(
44                    PgPoolOptions::new()
45                        .min_connections(10)
46                        .max_connections(50)
47                        .test_before_acquire(false)
48                        .connect(&env.database_url)
49                        .await
50                        .unwrap(),
51                )
52            } else {
53                None
54            },
55
56            encryption_key: env.app_encryption_key.clone().into(),
57            use_decryption_cache: env.app_use_decryption_cache,
58            batch_actions: Arc::new(Mutex::new(HashMap::new())),
59        };
60
61        let version = instance
62            .version()
63            .await
64            .unwrap_or_else(|_| "unknown".into());
65
66        tracing::info!(
67            "database connected (postgres@{}, {}ms)",
68            version,
69            start.elapsed().as_millis()
70        );
71
72        tokio::spawn({
73            let batch_actions = instance.batch_actions.clone();
74
75            async move {
76                loop {
77                    tokio::time::sleep(std::time::Duration::from_secs(5)).await;
78
79                    let mut actions = batch_actions.lock().await;
80                    for (key, action) in actions.drain() {
81                        tracing::debug!("executing batch action for {}:{}", key.0, key.1);
82                        if let Err(err) = action.await {
83                            tracing::error!(
84                                "error executing batch action for {}:{} - {:?}",
85                                key.0,
86                                key.1,
87                                err
88                            );
89                            sentry_anyhow::capture_anyhow(&err);
90                        }
91                    }
92                }
93            }
94        });
95
96        instance
97    }
98
99    pub async fn flush_batch_actions(&self) {
100        let mut actions = self.batch_actions.lock().await;
101        for (key, action) in actions.drain() {
102            tracing::debug!("executing batch action for {}:{}", key.0, key.1);
103            if let Err(err) = action.await {
104                tracing::error!(
105                    "error executing batch action for {}:{} - {:?}",
106                    key.0,
107                    key.1,
108                    err
109                );
110                sentry_anyhow::capture_anyhow(&err);
111            }
112        }
113    }
114
115    pub async fn version(&self) -> Result<compact_str::CompactString, sqlx::Error> {
116        let version: (compact_str::CompactString,) =
117            sqlx::query_as("SELECT split_part(version(), ' ', 2)")
118                .fetch_one(self.read())
119                .await?;
120
121        Ok(version.0)
122    }
123
124    pub async fn size(&self) -> Result<u64, sqlx::Error> {
125        let size: (i64,) = sqlx::query_as("SELECT pg_database_size(current_database())")
126            .fetch_one(self.read())
127            .await?;
128
129        Ok(size.0 as u64)
130    }
131
132    #[inline]
133    pub fn write(&self) -> &sqlx::PgPool {
134        &self.write
135    }
136
137    #[inline]
138    pub fn read(&self) -> &sqlx::PgPool {
139        self.read.as_ref().unwrap_or(&self.write)
140    }
141
142    pub async fn encrypt(
143        &self,
144        data: impl AsRef<[u8]> + Send + 'static,
145    ) -> Result<Vec<u8>, anyhow::Error> {
146        let encryption_key = self.encryption_key.clone();
147
148        tokio::task::spawn_blocking(move || {
149            simple_crypt::encrypt(data.as_ref(), encryption_key.as_bytes())
150        })
151        .await?
152    }
153
154    #[inline]
155    pub fn blocking_encrypt(&self, data: impl AsRef<[u8]>) -> Result<Vec<u8>, anyhow::Error> {
156        simple_crypt::encrypt(data.as_ref(), self.encryption_key.as_bytes())
157    }
158
159    pub async fn decrypt(
160        &self,
161        data: impl AsRef<[u8]> + Send + 'static,
162    ) -> Result<compact_str::CompactString, anyhow::Error> {
163        if self.use_decryption_cache {
164            self.cache
165                .cached(
166                    &format!(
167                        "decryption_cache::{}",
168                        base32::encode(base32::Alphabet::Z, data.as_ref())
169                    ),
170                    30,
171                    || async {
172                        let encryption_key = self.encryption_key.clone();
173                        let data = data.as_ref().to_vec();
174
175                        tokio::task::spawn_blocking(move || {
176                            simple_crypt::decrypt(&data, encryption_key.as_bytes())
177                                .map(|s| compact_str::CompactString::from_utf8_lossy(&s))
178                        })
179                        .await?
180                    },
181                )
182                .await
183        } else {
184            let encryption_key = self.encryption_key.clone();
185
186            tokio::task::spawn_blocking(move || {
187                simple_crypt::decrypt(data.as_ref(), encryption_key.as_bytes())
188                    .map(|s| compact_str::CompactString::from_utf8_lossy(&s))
189            })
190            .await?
191        }
192    }
193
194    #[inline]
195    pub fn blocking_decrypt(
196        &self,
197        data: impl AsRef<[u8]>,
198    ) -> Result<compact_str::CompactString, anyhow::Error> {
199        simple_crypt::decrypt(data.as_ref(), self.encryption_key.as_bytes())
200            .map(|s| compact_str::CompactString::from_utf8_lossy(&s))
201    }
202
203    #[inline]
204    pub async fn batch_action(
205        &self,
206        key: &'static str,
207        uuid: uuid::Uuid,
208        action: impl Future<Output = Result<(), anyhow::Error>> + Send + 'static,
209    ) {
210        let mut actions = self.batch_actions.lock().await;
211        actions.insert((key, uuid), Box::pin(action));
212    }
213}
214
215#[derive(Debug)]
216pub enum DatabaseError {
217    Sqlx(sqlx::Error),
218    Serde(serde_json::Error),
219    Any(anyhow::Error),
220    Validation(garde::Report),
221    InvalidRelation(InvalidRelationError),
222}
223
224impl Display for DatabaseError {
225    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226        match self {
227            Self::Sqlx(sqlx_value) => sqlx_value.fmt(f),
228            Self::Serde(serde_value) => serde_value.fmt(f),
229            Self::Any(any_value) => any_value.fmt(f),
230            Self::Validation(validation_value) => validation_value.fmt(f),
231            Self::InvalidRelation(relation_value) => relation_value.fmt(f),
232        }
233    }
234}
235
236impl From<wings_api::client::ApiHttpError> for DatabaseError {
237    #[inline]
238    fn from(value: wings_api::client::ApiHttpError) -> Self {
239        Self::Any(value.into())
240    }
241}
242
243impl From<anyhow::Error> for DatabaseError {
244    #[inline]
245    fn from(value: anyhow::Error) -> Self {
246        Self::Any(value)
247    }
248}
249
250impl From<serde_json::Error> for DatabaseError {
251    #[inline]
252    fn from(value: serde_json::Error) -> Self {
253        Self::Serde(value)
254    }
255}
256
257impl From<sqlx::Error> for DatabaseError {
258    #[inline]
259    fn from(value: sqlx::Error) -> Self {
260        Self::Sqlx(value)
261    }
262}
263
264impl From<garde::Report> for DatabaseError {
265    #[inline]
266    fn from(value: garde::Report) -> Self {
267        Self::Validation(value)
268    }
269}
270
271impl From<InvalidRelationError> for DatabaseError {
272    fn from(value: InvalidRelationError) -> Self {
273        Self::InvalidRelation(value)
274    }
275}
276
277impl DatabaseError {
278    #[inline]
279    pub fn is_unique_violation(&self) -> bool {
280        match self {
281            Self::Sqlx(sqlx_value) => sqlx_value
282                .as_database_error()
283                .is_some_and(|e| e.is_unique_violation()),
284            _ => false,
285        }
286    }
287
288    #[inline]
289    pub fn is_foreign_key_violation(&self) -> bool {
290        match self {
291            Self::Sqlx(sqlx_value) => sqlx_value
292                .as_database_error()
293                .is_some_and(|e| e.is_foreign_key_violation()),
294            _ => false,
295        }
296    }
297
298    #[inline]
299    pub fn is_check_violation(&self) -> bool {
300        match self {
301            Self::Sqlx(sqlx_value) => sqlx_value
302                .as_database_error()
303                .is_some_and(|e| e.is_check_violation()),
304            _ => false,
305        }
306    }
307
308    #[inline]
309    pub const fn is_validation_error(&self) -> bool {
310        matches!(self, Self::Validation(_))
311    }
312
313    #[inline]
314    pub const fn is_invalid_relation(&self) -> bool {
315        matches!(self, Self::InvalidRelation(_))
316    }
317}
318
319impl std::error::Error for DatabaseError {}
320
/// Error for an invalid relation; carries the relation name that was provided.
///
/// The payload is a `&'static str`, so the type is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct InvalidRelationError(pub &'static str);
323
324impl Display for InvalidRelationError {
325    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326        write!(f, "invalid relation `{}` provided", self.0)
327    }
328}