1use base64::Engine;
2use sha2::Digest;
3use sqlx::postgres::PgPoolOptions;
4use std::{collections::HashMap, fmt::Display, pin::Pin, sync::Arc};
5use tokio::sync::Mutex;
6
7pub static BASE64_ENGINE: base64::engine::GeneralPurpose = base64::engine::GeneralPurpose::new(
8 &base64::alphabet::STANDARD,
9 base64::engine::GeneralPurposeConfig::new()
10 .with_decode_allow_trailing_bits(true)
11 .with_decode_padding_mode(base64::engine::DecodePaddingMode::Indifferent),
12);
13
14type BatchFuture = Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send>>;
15
16pub struct Database {
17 pub cache: Arc<crate::cache::Cache>,
18
19 write: sqlx::PgPool,
20 read: Option<sqlx::PgPool>,
21
22 encryption_key: Arc<str>,
23 use_decryption_cache: bool,
24 batch_actions: Arc<Mutex<HashMap<(&'static str, uuid::Uuid), BatchFuture>>>,
25}
26
27impl Database {
28 pub async fn new(env: &crate::env::Env, cache: Arc<crate::cache::Cache>) -> Self {
29 let start = std::time::Instant::now();
30
31 let instance = Self {
32 cache,
33
34 write: match &env.database_url_primary {
35 Some(url) => PgPoolOptions::new()
36 .min_connections(10)
37 .max_connections(20)
38 .test_before_acquire(false)
39 .connect(url)
40 .await
41 .unwrap(),
42
43 None => PgPoolOptions::new()
44 .min_connections(10)
45 .max_connections(50)
46 .test_before_acquire(false)
47 .connect(&env.database_url)
48 .await
49 .unwrap(),
50 },
51 read: if env.database_url_primary.is_some() {
52 Some(
53 PgPoolOptions::new()
54 .min_connections(10)
55 .max_connections(50)
56 .test_before_acquire(false)
57 .connect(&env.database_url)
58 .await
59 .unwrap(),
60 )
61 } else {
62 None
63 },
64
65 encryption_key: env.app_encryption_key.clone().into(),
66 use_decryption_cache: env.app_use_decryption_cache,
67 batch_actions: Arc::new(Mutex::new(HashMap::new())),
68 };
69
70 let version = instance
71 .version()
72 .await
73 .unwrap_or_else(|_| "unknown".into());
74
75 tracing::info!(
76 "database connected (postgres@{}, {}ms)",
77 version,
78 start.elapsed().as_millis()
79 );
80
81 tokio::spawn({
82 let batch_actions = instance.batch_actions.clone();
83
84 async move {
85 loop {
86 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
87
88 let mut actions = batch_actions.lock().await;
89 for (key, action) in actions.drain() {
90 tracing::debug!("executing batch action for {}:{}", key.0, key.1);
91 if let Err(err) = action.await {
92 tracing::error!(
93 "error executing batch action for {}:{} - {:?}",
94 key.0,
95 key.1,
96 err
97 );
98 sentry_anyhow::capture_anyhow(&err);
99 }
100 }
101 }
102 }
103 });
104
105 instance
106 }
107
108 pub async fn flush_batch_actions(&self) {
109 let actions = self.batch_actions.lock().await.drain().collect::<Vec<_>>();
110
111 for (key, action) in actions {
112 tracing::debug!("executing batch action for {}:{}", key.0, key.1);
113 if let Err(err) = action.await {
114 tracing::error!(
115 "error executing batch action for {}:{} - {:?}",
116 key.0,
117 key.1,
118 err
119 );
120 sentry_anyhow::capture_anyhow(&err);
121 }
122 }
123 }
124
125 pub async fn version(&self) -> Result<compact_str::CompactString, sqlx::Error> {
126 let version: (compact_str::CompactString,) =
127 sqlx::query_as("SELECT split_part(version(), ' ', 2)")
128 .fetch_one(self.read())
129 .await?;
130
131 Ok(version.0)
132 }
133
134 pub async fn size(&self) -> Result<u64, sqlx::Error> {
135 let size: (i64,) = sqlx::query_as("SELECT pg_database_size(current_database())")
136 .fetch_one(self.read())
137 .await?;
138
139 Ok(size.0 as u64)
140 }
141
142 #[inline]
143 pub fn write(&self) -> &sqlx::PgPool {
144 &self.write
145 }
146
147 #[inline]
148 pub fn read(&self) -> &sqlx::PgPool {
149 self.read.as_ref().unwrap_or(&self.write)
150 }
151
152 pub async fn encrypt(
153 &self,
154 data: impl AsRef<[u8]> + Send + 'static,
155 ) -> Result<Vec<u8>, anyhow::Error> {
156 let encryption_key = self.encryption_key.clone();
157
158 tokio::task::spawn_blocking(move || {
159 simple_crypt::encrypt(data.as_ref(), encryption_key.as_bytes())
160 })
161 .await?
162 }
163
164 pub async fn encrypt_base64(
165 &self,
166 data: impl AsRef<[u8]> + Send + 'static,
167 ) -> Result<compact_str::CompactString, anyhow::Error> {
168 let encrypted = self.encrypt(data).await?;
169 Ok(BASE64_ENGINE.encode(&encrypted).into())
170 }
171
172 #[inline]
173 pub fn blocking_encrypt(&self, data: impl AsRef<[u8]>) -> Result<Vec<u8>, anyhow::Error> {
174 simple_crypt::encrypt(data.as_ref(), self.encryption_key.as_bytes())
175 }
176
177 #[inline]
178 pub fn blocking_encrypt_base64(
179 &self,
180 data: impl AsRef<[u8]>,
181 ) -> Result<compact_str::CompactString, anyhow::Error> {
182 let encrypted = self.blocking_encrypt(data)?;
183 Ok(BASE64_ENGINE.encode(&encrypted).into())
184 }
185
186 pub async fn decrypt(
187 &self,
188 data: impl AsRef<[u8]> + Send + 'static,
189 ) -> Result<compact_str::CompactString, anyhow::Error> {
190 if self.use_decryption_cache {
191 self.cache
192 .cached(
193 &format!(
194 "decryption_cache::{}",
195 hex::encode(sha2::Sha256::digest(data.as_ref()))
196 ),
197 30,
198 || async {
199 let encryption_key = self.encryption_key.clone();
200 let data = data.as_ref().to_vec();
201
202 tokio::task::spawn_blocking(move || {
203 simple_crypt::decrypt(&data, encryption_key.as_bytes())
204 .map(|s| compact_str::CompactString::from_utf8_lossy(&s))
205 })
206 .await?
207 },
208 )
209 .await
210 } else {
211 let encryption_key = self.encryption_key.clone();
212
213 tokio::task::spawn_blocking(move || {
214 simple_crypt::decrypt(data.as_ref(), encryption_key.as_bytes())
215 .map(|s| compact_str::CompactString::from_utf8_lossy(&s))
216 })
217 .await?
218 }
219 }
220
221 pub async fn decrypt_raw(
222 &self,
223 data: impl AsRef<[u8]> + Send + 'static,
224 ) -> Result<Vec<u8>, anyhow::Error> {
225 if self.use_decryption_cache {
226 self.cache
227 .cached(
228 &format!(
229 "decryption_cache::{}::raw",
230 hex::encode(sha2::Sha256::digest(data.as_ref()))
231 ),
232 30,
233 || async {
234 let encryption_key = self.encryption_key.clone();
235 let data = data.as_ref().to_vec();
236
237 tokio::task::spawn_blocking(move || {
238 simple_crypt::decrypt(&data, encryption_key.as_bytes())
239 })
240 .await?
241 },
242 )
243 .await
244 } else {
245 let encryption_key = self.encryption_key.clone();
246
247 tokio::task::spawn_blocking(move || {
248 simple_crypt::decrypt(data.as_ref(), encryption_key.as_bytes())
249 })
250 .await?
251 }
252 }
253
254 pub async fn decrypt_base64(
255 &self,
256 data: impl AsRef<str>,
257 ) -> Result<compact_str::CompactString, anyhow::Error> {
258 let decoded = BASE64_ENGINE.decode(data.as_ref())?;
259 self.decrypt(decoded).await
260 }
261
262 pub async fn decrypt_base64_raw(
263 &self,
264 data: impl AsRef<str>,
265 ) -> Result<Vec<u8>, anyhow::Error> {
266 let decoded = BASE64_ENGINE.decode(data.as_ref())?;
267 self.decrypt_raw(decoded).await
268 }
269
270 pub async fn decrypt_base64_optional(
271 &self,
272 data: impl AsRef<str>,
273 ) -> Result<Option<compact_str::CompactString>, anyhow::Error> {
274 match BASE64_ENGINE.decode(data.as_ref()) {
275 Ok(decoded) => Ok(Some(self.decrypt(decoded).await?)),
276 Err(_) => Ok(None),
277 }
278 }
279
280 pub async fn decrypt_base64_raw_optional(
281 &self,
282 data: impl AsRef<str>,
283 ) -> Result<Option<Vec<u8>>, anyhow::Error> {
284 match BASE64_ENGINE.decode(data.as_ref()) {
285 Ok(decoded) => Ok(Some(self.decrypt_raw(decoded).await?)),
286 Err(_) => Ok(None),
287 }
288 }
289
290 #[inline]
291 pub fn blocking_decrypt(
292 &self,
293 data: impl AsRef<[u8]>,
294 ) -> Result<compact_str::CompactString, anyhow::Error> {
295 simple_crypt::decrypt(data.as_ref(), self.encryption_key.as_bytes())
296 .map(|s| compact_str::CompactString::from_utf8_lossy(&s))
297 }
298
299 #[inline]
300 pub fn blocking_decrypt_raw(&self, data: impl AsRef<[u8]>) -> Result<Vec<u8>, anyhow::Error> {
301 simple_crypt::decrypt(data.as_ref(), self.encryption_key.as_bytes())
302 }
303
304 #[inline]
305 pub fn blocking_decrypt_base64(
306 &self,
307 data: impl AsRef<str>,
308 ) -> Result<compact_str::CompactString, anyhow::Error> {
309 let decoded = BASE64_ENGINE.decode(data.as_ref())?;
310 self.blocking_decrypt(decoded)
311 }
312
313 #[inline]
314 pub fn blocking_decrypt_base64_raw(
315 &self,
316 data: impl AsRef<str>,
317 ) -> Result<Vec<u8>, anyhow::Error> {
318 let decoded = BASE64_ENGINE.decode(data.as_ref())?;
319 self.blocking_decrypt_raw(decoded)
320 }
321
322 #[inline]
323 pub fn blocking_decrypt_base64_optional(
324 &self,
325 data: impl AsRef<str>,
326 ) -> Result<Option<compact_str::CompactString>, anyhow::Error> {
327 match BASE64_ENGINE.decode(data.as_ref()) {
328 Ok(decoded) => Ok(Some(self.blocking_decrypt(decoded)?)),
329 Err(_) => Ok(None),
330 }
331 }
332
333 #[inline]
334 pub fn blocking_decrypt_base64_raw_optional(
335 &self,
336 data: impl AsRef<str>,
337 ) -> Result<Option<Vec<u8>>, anyhow::Error> {
338 match BASE64_ENGINE.decode(data.as_ref()) {
339 Ok(decoded) => Ok(Some(self.blocking_decrypt_raw(decoded)?)),
340 Err(_) => Ok(None),
341 }
342 }
343
344 #[inline]
345 pub async fn batch_action(
346 &self,
347 key: &'static str,
348 uuid: uuid::Uuid,
349 action: impl Future<Output = Result<(), anyhow::Error>> + Send + 'static,
350 ) {
351 let mut actions = self.batch_actions.lock().await;
352 actions.insert((key, uuid), Box::pin(action));
353 }
354}
355
356#[derive(Debug)]
357pub enum DatabaseError {
358 Sqlx(sqlx::Error),
359 Mongodb(mongodb::error::Error),
360 Serde(serde_json::Error),
361 Any(anyhow::Error),
362 Validation(garde::Report),
363 InvalidRelation(InvalidRelationError),
364}
365
366impl Display for DatabaseError {
367 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368 match self {
369 Self::Sqlx(sqlx_value) => sqlx_value.fmt(f),
370 Self::Mongodb(mongodb_value) => mongodb_value.fmt(f),
371 Self::Serde(serde_value) => serde_value.fmt(f),
372 Self::Any(any_value) => any_value.fmt(f),
373 Self::Validation(validation_value) => validation_value.fmt(f),
374 Self::InvalidRelation(relation_value) => relation_value.fmt(f),
375 }
376 }
377}
378
379impl From<wings_api::client::ApiHttpError> for DatabaseError {
380 #[inline]
381 fn from(value: wings_api::client::ApiHttpError) -> Self {
382 Self::Any(value.into())
383 }
384}
385
386impl From<anyhow::Error> for DatabaseError {
387 #[inline]
388 fn from(value: anyhow::Error) -> Self {
389 Self::Any(value)
390 }
391}
392
393impl From<serde_json::Error> for DatabaseError {
394 #[inline]
395 fn from(value: serde_json::Error) -> Self {
396 Self::Serde(value)
397 }
398}
399
400impl From<sqlx::Error> for DatabaseError {
401 #[inline]
402 fn from(value: sqlx::Error) -> Self {
403 Self::Sqlx(value)
404 }
405}
406
407impl From<mongodb::error::Error> for DatabaseError {
408 #[inline]
409 fn from(value: mongodb::error::Error) -> Self {
410 Self::Mongodb(value)
411 }
412}
413
414impl From<garde::Report> for DatabaseError {
415 #[inline]
416 fn from(value: garde::Report) -> Self {
417 Self::Validation(value)
418 }
419}
420
421impl From<InvalidRelationError> for DatabaseError {
422 fn from(value: InvalidRelationError) -> Self {
423 Self::InvalidRelation(value)
424 }
425}
426
427impl DatabaseError {
428 #[inline]
429 pub fn is_unique_violation(&self) -> bool {
430 match self {
431 Self::Sqlx(sqlx_value) => sqlx_value
432 .as_database_error()
433 .is_some_and(|e| e.is_unique_violation()),
434 _ => false,
435 }
436 }
437
438 #[inline]
439 pub fn is_foreign_key_violation(&self) -> bool {
440 match self {
441 Self::Sqlx(sqlx_value) => sqlx_value
442 .as_database_error()
443 .is_some_and(|e| e.is_foreign_key_violation()),
444 _ => false,
445 }
446 }
447
448 #[inline]
449 pub fn is_check_violation(&self) -> bool {
450 match self {
451 Self::Sqlx(sqlx_value) => sqlx_value
452 .as_database_error()
453 .is_some_and(|e| e.is_check_violation()),
454 _ => false,
455 }
456 }
457
458 #[inline]
459 pub const fn is_validation_error(&self) -> bool {
460 matches!(self, Self::Validation(_))
461 }
462
463 #[inline]
464 pub const fn is_invalid_relation(&self) -> bool {
465 matches!(self, Self::InvalidRelation(_))
466 }
467}
468
469impl std::error::Error for DatabaseError {}
470
/// Error raised when a caller names a relation that is not accepted in the current context.
#[derive(Debug)]
pub struct InvalidRelationError(
    /// Name of the rejected relation.
    pub &'static str,
);
473
474impl Display for InvalidRelationError {
475 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
476 write!(f, "invalid relation `{}` provided", self.0)
477 }
478}