shared/models/
mod.rs

1use crate::database::DatabaseError;
2use compact_str::CompactStringExt;
3use futures_util::{StreamExt, TryStreamExt};
4use garde::Validate;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use sqlx::{
7    Arguments, Postgres, QueryBuilder, Row,
8    postgres::{PgArguments, PgRow},
9};
10use std::{
11    collections::{BTreeMap, HashSet},
12    marker::PhantomData,
13    pin::Pin,
14    sync::{Arc, LazyLock},
15};
16use tokio::sync::RwLock;
17use utoipa::ToSchema;
18
19pub mod admin_activity;
20pub mod backup_configuration;
21pub mod database_host;
22pub mod egg_repository;
23pub mod egg_repository_egg;
24pub mod location;
25pub mod location_database_host;
26pub mod mount;
27pub mod nest;
28pub mod nest_egg;
29pub mod nest_egg_mount;
30pub mod nest_egg_variable;
31pub mod node;
32pub mod node_allocation;
33pub mod node_mount;
34pub mod oauth_provider;
35pub mod role;
36pub mod server;
37pub mod server_activity;
38pub mod server_allocation;
39pub mod server_backup;
40pub mod server_database;
41pub mod server_mount;
42pub mod server_schedule;
43pub mod server_schedule_step;
44pub mod server_subuser;
45pub mod server_variable;
46pub mod user;
47pub mod user_activity;
48pub mod user_api_key;
49pub mod user_command_snippet;
50pub mod user_oauth_link;
51pub mod user_password_reset;
52pub mod user_recovery_code;
53pub mod user_security_key;
54pub mod user_server_group;
55pub mod user_session;
56pub mod user_ssh_key;
57
58#[derive(ToSchema, Validate, Deserialize)]
59pub struct PaginationParams {
60    #[garde(range(min = 1))]
61    #[schema(minimum = 1)]
62    #[serde(default = "Pagination::default_page")]
63    pub page: i64,
64    #[garde(range(min = 1, max = 100))]
65    #[schema(minimum = 1, maximum = 100)]
66    #[serde(default = "Pagination::default_per_page")]
67    pub per_page: i64,
68}
69
70#[derive(ToSchema, Validate, Deserialize)]
71pub struct PaginationParamsWithSearch {
72    #[garde(range(min = 1))]
73    #[schema(minimum = 1)]
74    #[serde(default = "Pagination::default_page")]
75    pub page: i64,
76    #[garde(range(min = 1, max = 100))]
77    #[schema(minimum = 1, maximum = 100)]
78    #[serde(default = "Pagination::default_per_page")]
79    pub per_page: i64,
80    #[garde(length(chars, min = 1, max = 128))]
81    #[schema(min_length = 1, max_length = 128)]
82    #[serde(
83        default,
84        deserialize_with = "crate::deserialize::deserialize_string_option"
85    )]
86    pub search: Option<compact_str::CompactString>,
87}
88
// One page of results plus the paging metadata that produced it.
// `T` defaults to `serde_json::Value` so `Pagination` can be named without a
// type argument (as the `impl Pagination` defaults block does).
// Plain `//` comments on purpose: utoipa's `ToSchema` derive turns `///` doc
// comments into OpenAPI descriptions, which would change the generated schema.
#[derive(ToSchema, Serialize)]
pub struct Pagination<T: Serialize = serde_json::Value> {
    // Total number of items across all pages.
    pub total: i64,
    // Page size that was requested.
    pub per_page: i64,
    // Page number that was requested.
    pub page: i64,

    // Items on this page.
    pub data: Vec<T>,
}
97
98impl Pagination {
99    #[inline]
100    pub const fn default_page() -> i64 {
101        1
102    }
103
104    #[inline]
105    pub const fn default_per_page() -> i64 {
106        25
107    }
108}
109
110impl<T: Serialize> Pagination<T> {
111    pub async fn async_map<R: serde::Serialize, Fut: Future<Output = R>>(
112        self,
113        mapper: impl Fn(T) -> Fut,
114    ) -> Pagination<R> {
115        let mut results = Vec::new();
116        results.reserve_exact(self.data.len());
117        let mut result_stream =
118            futures_util::stream::iter(self.data.into_iter().map(mapper)).buffered(25);
119
120        while let Some(result) = result_stream.next().await {
121            results.push(result);
122        }
123
124        Pagination {
125            total: self.total,
126            per_page: self.per_page,
127            page: self.page,
128            data: results,
129        }
130    }
131
132    pub async fn try_async_map<R: serde::Serialize, E, Fut: Future<Output = Result<R, E>>>(
133        self,
134        mapper: impl Fn(T) -> Fut,
135    ) -> Result<Pagination<R>, E> {
136        let mut results = Vec::new();
137        results.reserve_exact(self.data.len());
138        let mut result_stream =
139            futures_util::stream::iter(self.data.into_iter().map(mapper)).buffered(25);
140
141        while let Some(result) = result_stream.try_next().await? {
142            results.push(result);
143        }
144
145        Ok(Pagination {
146            total: self.total,
147            per_page: self.per_page,
148            page: self.page,
149            data: results,
150        })
151    }
152}
153
154pub trait BaseModel: Serialize + DeserializeOwned {
155    const NAME: &'static str;
156
157    fn columns(prefix: Option<&str>) -> BTreeMap<&'static str, compact_str::CompactString>;
158
159    #[inline]
160    fn columns_sql(prefix: Option<&str>) -> compact_str::CompactString {
161        Self::columns(prefix)
162            .iter()
163            .map(|(key, value)| compact_str::format_compact!("{key} as {value}"))
164            .join_compact(", ")
165    }
166
167    fn map(prefix: Option<&str>, row: &PgRow) -> Result<Self, crate::database::DatabaseError>;
168}
169
170#[async_trait::async_trait]
171pub trait EventEmittingModel: BaseModel {
172    type Event: Send + Sync + 'static;
173
174    fn get_event_emitter() -> &'static crate::events::EventEmitter<Self::Event>;
175
176    async fn register_event_handler<
177        F: Fn(crate::State, Arc<Self::Event>) -> Fut + Send + Sync + 'static,
178        Fut: Future<Output = Result<(), anyhow::Error>> + Send + 'static,
179    >(
180        listener: F,
181    ) -> crate::events::EventHandlerHandle {
182        Self::get_event_emitter()
183            .register_event_handler(listener)
184            .await
185    }
186
187    /// # Warning
188    /// This method will block the current thread if the lock is not available
189    fn blocking_register_event_handler<
190        F: Fn(crate::State, Arc<Self::Event>) -> Fut + Send + Sync + 'static,
191        Fut: Future<Output = Result<(), anyhow::Error>> + Send + 'static,
192    >(
193        listener: F,
194    ) -> crate::events::EventHandlerHandle {
195        Self::get_event_emitter().blocking_register_event_handler(listener)
196    }
197}
198
/// Boxed, `Send` future returned by a create handler; borrows its inputs for `'a`.
type CreateListenerResult<'a> =
    Pin<Box<dyn Future<Output = Result<(), crate::database::DatabaseError>> + Send + 'a>>;
/// Type-erased create handler. It receives mutable access to the create
/// options and the pending `InsertQueryBuilder`, the shared `crate::State`,
/// and the open Postgres transaction.
type CreateListener<M> = dyn for<'a> Fn(
        &'a mut <M as CreatableModel>::CreateOptions<'_>,
        &'a mut InsertQueryBuilder,
        &'a crate::State,
        &'a mut sqlx::Transaction<'_, sqlx::Postgres>,
    ) -> CreateListenerResult<'a>
    + Send
    + Sync;
/// Shared, priority-ordered list of a model's create handlers.
pub type CreateListenerList<M> = Arc<ModelHandlerList<Box<CreateListener<M>>>>;
210
/// A model that can be inserted via [`CreatableModel::create`], with a
/// priority-ordered list of create handlers. Each handler gets mutable access
/// to the create options and the pending `InsertQueryBuilder`, inside the
/// caller's transaction.
#[async_trait::async_trait]
pub trait CreatableModel: BaseModel + Send + Sync + 'static {
    /// Input accepted by `create`; must implement `garde::Validate`.
    type CreateOptions<'a>: Send + Sync + Validate;
    /// Value returned by a successful `create`.
    type CreateResult: Send;

    /// Returns this model's static list of create handlers.
    fn get_create_handlers() -> &'static LazyLock<CreateListenerList<Self>>;

    /// Registers `callback` as a create handler with the given `priority`.
    ///
    /// NOTE(review): the `ModelHandlerHandle` returned by `register_handler`
    /// is discarded here, so a handler registered through this method cannot
    /// be disconnected afterwards.
    async fn register_create_handler<
        F: for<'a> Fn(
                &'a mut Self::CreateOptions<'_>,
                &'a mut InsertQueryBuilder,
                &'a crate::State,
                &'a mut sqlx::Transaction<'_, sqlx::Postgres>,
            ) -> Pin<
                Box<dyn Future<Output = Result<(), crate::database::DatabaseError>> + Send + 'a>,
            > + Send
            + Sync
            + 'static,
    >(
        priority: ListenerPriority,
        callback: F,
    ) {
        // Erase the concrete closure type so all handlers fit in one list.
        let erased = Box::new(callback) as Box<CreateListener<Self>>;

        Self::get_create_handlers()
            .register_handler(priority, erased)
            .await;
    }

    /// # Warning
    /// This method will block the current thread if the lock is not available
    ///
    /// Synchronous counterpart of `register_create_handler`; its handle is
    /// discarded in the same way.
    /// NOTE(review): tokio documents `RwLock::blocking_write` as panicking when
    /// called from within an async execution context — call this only from
    /// synchronous setup code.
    fn blocking_register_create_handler<
        F: for<'a> Fn(
                &'a mut Self::CreateOptions<'_>,
                &'a mut InsertQueryBuilder,
                &'a crate::State,
                &'a mut sqlx::Transaction<'_, sqlx::Postgres>,
            ) -> Pin<
                Box<dyn Future<Output = Result<(), crate::database::DatabaseError>> + Send + 'a>,
            > + Send
            + Sync
            + 'static,
    >(
        priority: ListenerPriority,
        callback: F,
    ) {
        let erased = Box::new(callback) as Box<CreateListener<Self>>;

        Self::get_create_handlers().blocking_register_handler(priority, erased);
    }

    /// Runs every registered create handler in list order (most urgent
    /// priority first) and stops at the first error, which is returned.
    ///
    /// The list's read lock is held for the entire loop. A handler that awaits
    /// registering or disconnecting a create handler (both take the write lock)
    /// would therefore deadlock.
    async fn run_create_handlers(
        options: &mut Self::CreateOptions<'_>,
        query_builder: &mut InsertQueryBuilder,
        state: &crate::State,
        transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
    ) -> Result<(), crate::database::DatabaseError> {
        let listeners = Self::get_create_handlers().listeners.read().await;

        for listener in listeners.iter() {
            (*listener.callback)(options, query_builder, state, transaction).await?;
        }

        Ok(())
    }

    /// Inserts a new model built from `options`.
    /// NOTE(review): implementations presumably call `run_create_handlers`
    /// before executing the insert — confirm in each impl.
    async fn create(
        state: &crate::State,
        options: Self::CreateOptions<'_>,
    ) -> Result<Self::CreateResult, crate::database::DatabaseError>;
}
282
/// Boxed, `Send` future returned by an update handler; borrows its inputs for `'a`.
type UpdateListenerResult<'a> =
    Pin<Box<dyn Future<Output = Result<(), crate::database::DatabaseError>> + Send + 'a>>;
/// Type-erased update handler. It receives mutable access to the model, the
/// update options and the pending `UpdateQueryBuilder`, plus the shared
/// `crate::State` and the open Postgres transaction.
type UpdateListener<M> = dyn for<'a> Fn(
        &'a mut M,
        &'a mut <M as UpdatableModel>::UpdateOptions,
        &'a mut UpdateQueryBuilder,
        &'a crate::State,
        &'a mut sqlx::Transaction<'_, sqlx::Postgres>,
    ) -> UpdateListenerResult<'a>
    + Send
    + Sync;
/// Shared, priority-ordered list of a model's update handlers.
pub type UpdateListenerList<M> = Arc<ModelHandlerList<Box<UpdateListener<M>>>>;
295
/// A model that can be updated in place via [`UpdatableModel::update`], with a
/// priority-ordered list of update handlers. Each handler gets mutable access
/// to the model, the options and the pending `UpdateQueryBuilder`, inside the
/// caller's transaction.
#[async_trait::async_trait]
pub trait UpdatableModel: BaseModel + Send + Sync + 'static {
    /// Partial-update payload: deserializable, documented in the OpenAPI schema
    /// (`ToSchema`) and validated with `garde`.
    type UpdateOptions: Send + Sync + Default + ToSchema + DeserializeOwned + Serialize + Validate;

    /// Returns this model's static list of update handlers.
    fn get_update_handlers() -> &'static LazyLock<UpdateListenerList<Self>>;

    /// Registers `callback` as an update handler with the given `priority`.
    ///
    /// NOTE(review): the `ModelHandlerHandle` returned by `register_handler`
    /// is discarded here, so a handler registered through this method cannot
    /// be disconnected afterwards.
    async fn register_update_handler<
        F: for<'a> Fn(
                &'a mut Self,
                &'a mut Self::UpdateOptions,
                &'a mut UpdateQueryBuilder,
                &'a crate::State,
                &'a mut sqlx::Transaction<'_, sqlx::Postgres>,
            ) -> Pin<
                Box<dyn Future<Output = Result<(), crate::database::DatabaseError>> + Send + 'a>,
            > + Send
            + Sync
            + 'static,
    >(
        priority: ListenerPriority,
        callback: F,
    ) {
        // Erase the concrete closure type so all handlers fit in one list.
        let erased = Box::new(callback) as Box<UpdateListener<Self>>;

        Self::get_update_handlers()
            .register_handler(priority, erased)
            .await;
    }

    /// # Warning
    /// This method will block the current thread if the lock is not available
    ///
    /// Synchronous counterpart of `register_update_handler`; its handle is
    /// discarded in the same way.
    /// NOTE(review): tokio documents `RwLock::blocking_write` as panicking when
    /// called from within an async execution context — call this only from
    /// synchronous setup code.
    fn blocking_register_update_handler<
        F: for<'a> Fn(
                &'a mut Self,
                &'a mut Self::UpdateOptions,
                &'a mut UpdateQueryBuilder,
                &'a crate::State,
                &'a mut sqlx::Transaction<'_, sqlx::Postgres>,
            ) -> Pin<
                Box<dyn Future<Output = Result<(), crate::database::DatabaseError>> + Send + 'a>,
            > + Send
            + Sync
            + 'static,
    >(
        priority: ListenerPriority,
        callback: F,
    ) {
        let erased = Box::new(callback) as Box<UpdateListener<Self>>;

        Self::get_update_handlers().blocking_register_handler(priority, erased);
    }

    /// Runs every registered update handler in list order (most urgent
    /// priority first) and stops at the first error, which is returned.
    ///
    /// The list's read lock is held for the entire loop. A handler that awaits
    /// registering or disconnecting an update handler (both take the write
    /// lock) would therefore deadlock.
    async fn run_update_handlers(
        &mut self,
        options: &mut Self::UpdateOptions,
        query_builder: &mut UpdateQueryBuilder,
        state: &crate::State,
        transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
    ) -> Result<(), crate::database::DatabaseError> {
        let listeners = Self::get_update_handlers().listeners.read().await;

        for listener in listeners.iter() {
            (*listener.callback)(self, options, query_builder, state, transaction).await?;
        }

        Ok(())
    }

    /// Applies `options` to this model.
    /// NOTE(review): implementations presumably call `run_update_handlers`
    /// before executing the update — confirm in each impl.
    async fn update(
        &mut self,
        state: &crate::State,
        options: Self::UpdateOptions,
    ) -> Result<(), crate::database::DatabaseError>;
}
370
/// Boxed, `Send` future returned by a delete handler; borrows its inputs for `'a`.
/// Unlike the create/update handlers, these report failures as `anyhow::Error`.
type DeleteListenerResult<'a> =
    Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send + 'a>>;
/// Type-erased delete handler. It receives shared access to the model and the
/// delete options, plus the shared `crate::State` and the open Postgres
/// transaction.
type DeleteListener<M> = dyn for<'a> Fn(
        &'a M,
        &'a <M as DeletableModel>::DeleteOptions,
        &'a crate::State,
        &'a mut sqlx::Transaction<'_, sqlx::Postgres>,
    ) -> DeleteListenerResult<'a>
    + Send
    + Sync;
/// Shared, priority-ordered list of a model's delete handlers.
pub type DeleteListenerList<M> = Arc<ModelHandlerList<Box<DeleteListener<M>>>>;
382
/// A model that can be deleted via [`DeletableModel::delete`], with a
/// priority-ordered list of delete handlers that run inside the caller's
/// transaction.
#[async_trait::async_trait]
pub trait DeletableModel: BaseModel + Send + Sync + 'static {
    /// Extra input accepted by `delete`.
    type DeleteOptions: Send + Sync + Default;

    /// Returns this model's static list of delete handlers.
    fn get_delete_handlers() -> &'static LazyLock<DeleteListenerList<Self>>;

    /// Registers `callback` as a delete handler with the given `priority`.
    ///
    /// NOTE(review): the `ModelHandlerHandle` returned by `register_handler`
    /// is discarded here, so a handler registered through this method cannot
    /// be disconnected afterwards.
    async fn register_delete_handler<
        F: for<'a> Fn(
                &'a Self,
                &'a Self::DeleteOptions,
                &'a crate::State,
                &'a mut sqlx::Transaction<'_, sqlx::Postgres>,
            )
                -> Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send + 'a>>
            + Send
            + Sync
            + 'static,
    >(
        priority: ListenerPriority,
        callback: F,
    ) {
        // Erase the concrete closure type so all handlers fit in one list.
        let erased = Box::new(callback) as Box<DeleteListener<Self>>;

        Self::get_delete_handlers()
            .register_handler(priority, erased)
            .await;
    }

    /// # Warning
    /// This method will block the current thread if the lock is not available
    ///
    /// Synchronous counterpart of `register_delete_handler`; its handle is
    /// discarded in the same way.
    /// NOTE(review): tokio documents `RwLock::blocking_write` as panicking when
    /// called from within an async execution context — call this only from
    /// synchronous setup code.
    fn blocking_register_delete_handler<
        F: for<'a> Fn(
                &'a Self,
                &'a Self::DeleteOptions,
                &'a crate::State,
                &'a mut sqlx::Transaction<'_, sqlx::Postgres>,
            )
                -> Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send + 'a>>
            + Send
            + Sync
            + 'static,
    >(
        priority: ListenerPriority,
        callback: F,
    ) {
        let erased = Box::new(callback) as Box<DeleteListener<Self>>;

        Self::get_delete_handlers().blocking_register_handler(priority, erased);
    }

    /// Runs every registered delete handler in list order (most urgent
    /// priority first) and stops at the first error, which is returned.
    ///
    /// The list's read lock is held for the entire loop. A handler that awaits
    /// registering or disconnecting a delete handler (both take the write lock)
    /// would therefore deadlock.
    async fn run_delete_handlers(
        &self,
        options: &Self::DeleteOptions,
        state: &crate::State,
        transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
    ) -> Result<(), anyhow::Error> {
        let listeners = Self::get_delete_handlers().listeners.read().await;

        for listener in listeners.iter() {
            (*listener.callback)(self, options, state, transaction).await?;
        }

        Ok(())
    }

    /// Deletes this model.
    /// NOTE(review): implementations presumably call `run_delete_handlers`
    /// before executing the delete — confirm in each impl.
    async fn delete(
        &self,
        state: &crate::State,
        options: Self::DeleteOptions,
    ) -> Result<(), anyhow::Error>;
}
454
455#[async_trait::async_trait]
456pub trait ByUuid: BaseModel {
457    async fn by_uuid(
458        database: &crate::database::Database,
459        uuid: uuid::Uuid,
460    ) -> Result<Self, DatabaseError>;
461
462    async fn by_uuid_cached(
463        database: &crate::database::Database,
464        uuid: uuid::Uuid,
465    ) -> Result<Self, anyhow::Error> {
466        database
467            .cache
468            .cached(&format!("{}::{uuid}", Self::NAME), 10, || {
469                Self::by_uuid(database, uuid)
470            })
471            .await
472    }
473
474    async fn by_uuid_optional(
475        database: &crate::database::Database,
476        uuid: uuid::Uuid,
477    ) -> Result<Option<Self>, DatabaseError> {
478        match Self::by_uuid(database, uuid).await {
479            Ok(res) => Ok(Some(res)),
480            Err(DatabaseError::Sqlx(sqlx::Error::RowNotFound)) => Ok(None),
481            Err(err) => Err(err),
482        }
483    }
484
485    async fn by_uuid_optional_cached(
486        database: &crate::database::Database,
487        uuid: uuid::Uuid,
488    ) -> Result<Option<Self>, anyhow::Error> {
489        match Self::by_uuid_cached(database, uuid).await {
490            Ok(res) => Ok(Some(res)),
491            Err(err) => {
492                if let Some(sqlx::Error::RowNotFound) = err.downcast_ref::<sqlx::Error>() {
493                    Ok(None)
494                } else {
495                    Err(err)
496                }
497            }
498        }
499    }
500
501    #[inline]
502    fn get_fetchable(uuid: uuid::Uuid) -> Fetchable<Self> {
503        Fetchable {
504            uuid,
505            _model: PhantomData,
506        }
507    }
508
509    #[inline]
510    fn get_fetchable_from_row(row: &PgRow, column: impl AsRef<str>) -> Option<Fetchable<Self>> {
511        match row.try_get(column.as_ref()) {
512            Ok(uuid) => Some(Fetchable {
513                uuid,
514                _model: PhantomData,
515            }),
516            Err(_) => None,
517        }
518    }
519}
520
/// Execution order of model handlers: handlers with a more urgent priority run
/// before handlers with a less urgent one. `Normal` is the default.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ListenerPriority {
    Highest,
    High,
    #[default]
    Normal,
    Low,
    Lowest,
}

impl ListenerPriority {
    /// Numeric urgency of the priority: larger means more urgent.
    #[inline]
    const fn rank(self) -> u8 {
        match self {
            Self::Lowest => 1,
            Self::Low => 2,
            Self::Normal => 3,
            Self::High => 4,
            Self::Highest => 5,
        }
    }
}

impl PartialOrd for ListenerPriority {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Deliberately inverted: a more urgent priority compares as `Less`, so an
/// ascending sort puts `Highest` first. As a consequence `Highest < Lowest`.
impl Ord for ListenerPriority {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.rank().cmp(&other.rank()).reverse()
    }
}
558
559#[async_trait::async_trait]
560impl<F: Send + Sync> crate::events::DisconnectEventHandler for ModelHandlerList<F> {
561    #[inline]
562    async fn disconnect(&self, id: uuid::Uuid) {
563        self.listeners.write().await.retain(|l| l.uuid != id);
564    }
565
566    #[inline]
567    fn blocking_disconnect(&self, id: uuid::Uuid) {
568        self.listeners.blocking_write().retain(|l| l.uuid != id);
569    }
570}
571
572pub struct ModelHandlerList<F: Send + Sync + 'static> {
573    listeners: RwLock<Vec<ModelHandler<F>>>,
574}
575
576impl<F: Send + Sync + 'static> Default for ModelHandlerList<F> {
577    fn default() -> Self {
578        Self {
579            listeners: RwLock::new(Vec::new()),
580        }
581    }
582}
583
584impl<F: Send + Sync + 'static> ModelHandlerList<F> {
585    pub async fn register_handler(
586        self: &Arc<Self>,
587        priority: ListenerPriority,
588        callback: F,
589    ) -> ModelHandlerHandle {
590        let listener = ModelHandler::new(callback, priority, self.clone());
591        let aborter = listener.handle();
592
593        let mut self_listeners = self.listeners.write().await;
594        self_listeners.push(listener);
595        self_listeners.sort_by(|a, b| a.priority.cmp(&b.priority));
596
597        aborter
598    }
599
600    /// # Warning
601    /// This method will block the current thread if the lock is not available
602    pub fn blocking_register_handler(
603        self: &Arc<Self>,
604        priority: ListenerPriority,
605        callback: F,
606    ) -> ModelHandlerHandle {
607        let listener = ModelHandler::new(callback, priority, self.clone());
608        let aborter = listener.handle();
609
610        let mut self_listeners = self.listeners.blocking_write();
611        self_listeners.push(listener);
612        self_listeners.sort_by(|a, b| a.priority.cmp(&b.priority));
613
614        aborter
615    }
616}
617
618pub struct ModelHandler<F: Send + Sync + 'static> {
619    uuid: uuid::Uuid,
620    priority: ListenerPriority,
621    list: Arc<ModelHandlerList<F>>,
622
623    pub callback: F,
624}
625
626impl<F: Send + Sync + 'static> ModelHandler<F> {
627    pub fn new(callback: F, priority: ListenerPriority, list: Arc<ModelHandlerList<F>>) -> Self {
628        Self {
629            uuid: uuid::Uuid::new_v4(),
630            priority,
631            list,
632            callback,
633        }
634    }
635
636    pub fn handle(&self) -> ModelHandlerHandle {
637        ModelHandlerHandle {
638            list_ref: self.list.clone(),
639            id: self.uuid,
640        }
641    }
642}
643
644pub struct ModelHandlerHandle {
645    list_ref: Arc<dyn crate::events::DisconnectEventHandler + Send + Sync>,
646    id: uuid::Uuid,
647}
648
649impl ModelHandlerHandle {
650    pub async fn disconnect(&self) {
651        self.list_ref.disconnect(self.id).await;
652    }
653
654    /// # Warning
655    /// This method will block the current thread if the lists' lock is not available
656    pub fn blocking_disconnect(&self) {
657        self.list_ref.blocking_disconnect(self.id);
658    }
659}
660
661#[derive(Serialize, Deserialize, Clone, Copy)]
662pub struct Fetchable<M: ByUuid> {
663    pub uuid: uuid::Uuid,
664    #[serde(skip)]
665    _model: PhantomData<M>,
666}
667
668impl<M: ByUuid + Send> Fetchable<M> {
669    #[inline]
670    pub async fn fetch(&self, database: &crate::database::Database) -> Result<M, DatabaseError> {
671        M::by_uuid(database, self.uuid).await
672    }
673
674    #[inline]
675    pub async fn fetch_cached(
676        &self,
677        database: &crate::database::Database,
678    ) -> Result<M, anyhow::Error> {
679        M::by_uuid_cached(database, self.uuid).await
680    }
681
682    #[inline]
683    pub async fn fetch_optional(
684        &self,
685        database: &crate::database::Database,
686    ) -> Result<Option<M>, DatabaseError> {
687        M::by_uuid_optional(database, self.uuid).await
688    }
689
690    #[inline]
691    pub async fn fetch_optional_cached(
692        &self,
693        database: &crate::database::Database,
694    ) -> Result<Option<M>, anyhow::Error> {
695        M::by_uuid_optional_cached(database, self.uuid).await
696    }
697}
698
/// Incrementally builds a parameterised
/// `INSERT INTO <table> (<columns>) VALUES (<expressions>) [RETURNING ...]`
/// statement.
pub struct InsertQueryBuilder<'a> {
    // Target table name, interpolated into the SQL verbatim.
    table: &'a str,
    // Column names, parallel to `expressions`.
    columns: Vec<&'a str>,
    // Value expression for each column (e.g. `$3`), parallel to `columns`.
    expressions: Vec<String>,
    // Bind values referenced by the `$n` placeholders in `expressions`.
    arguments: PgArguments,
    // Text appended after ` RETURNING `, if any.
    returning_clause: Option<&'a str>,
}
706
707impl<'a> InsertQueryBuilder<'a> {
708    pub fn new(table: &'a str) -> Self {
709        Self {
710            table,
711            columns: Vec::new(),
712            expressions: Vec::new(),
713            arguments: PgArguments::default(),
714            returning_clause: None,
715        }
716    }
717
718    pub fn set<T: 'a + sqlx::Encode<'a, Postgres> + sqlx::Type<Postgres> + Send>(
719        &mut self,
720        column: &'a str,
721        value: T,
722    ) -> &mut Self {
723        if self.columns.contains(&column) {
724            return self;
725        }
726
727        if self.arguments.add(value).is_ok() {
728            self.columns.push(column);
729            let idx = self.arguments.len();
730            self.expressions.push(format!("${}", idx));
731        }
732
733        self
734    }
735
736    pub fn set_expr<T: 'a + sqlx::Encode<'a, Postgres> + sqlx::Type<Postgres> + Send>(
737        &mut self,
738        column: &'a str,
739        expression: &str,
740        values: Vec<T>,
741    ) -> &mut Self {
742        if self.columns.contains(&column) {
743            return self;
744        }
745
746        let start_len = self.arguments.len();
747
748        for value in values {
749            if self.arguments.add(value).is_err() {
750                return self;
751            }
752        }
753
754        let mut expr = expression.to_string();
755        let added_count = self.arguments.len() - start_len;
756
757        for i in (1..=added_count).rev() {
758            let global_idx = start_len + i;
759            expr = expr.replace(&format!("${}", i), &format!("${}", global_idx));
760        }
761
762        self.columns.push(column);
763        self.expressions.push(expr);
764
765        self
766    }
767
768    pub fn returning(mut self, clause: &'a str) -> Self {
769        self.returning_clause = Some(clause);
770        self
771    }
772
773    fn build_sql(&self) -> String {
774        let columns_sql = self.columns.join(", ");
775        let values_sql = self.expressions.join(", ");
776
777        let mut sql = format!(
778            "INSERT INTO {} ({}) VALUES ({})",
779            self.table, columns_sql, values_sql
780        );
781
782        if let Some(clause) = self.returning_clause {
783            sql.push_str(" RETURNING ");
784            sql.push_str(clause);
785        }
786
787        sql
788    }
789
790    pub async fn execute(
791        self,
792        executor: impl sqlx::Executor<'a, Database = Postgres>,
793    ) -> Result<sqlx::postgres::PgQueryResult, sqlx::Error> {
794        let sql = self.build_sql();
795        sqlx::query_with(&sql, self.arguments)
796            .execute(executor)
797            .await
798    }
799
800    pub async fn fetch_one(
801        self,
802        executor: impl sqlx::Executor<'a, Database = Postgres>,
803    ) -> Result<sqlx::postgres::PgRow, sqlx::Error> {
804        let sql = self.build_sql();
805        sqlx::query_with(&sql, self.arguments)
806            .fetch_one(executor)
807            .await
808    }
809}
810
811pub struct UpdateQueryBuilder<'a> {
812    builder: QueryBuilder<'a, Postgres>,
813    updated_fields: HashSet<&'a str>,
814    has_set_fields: bool,
815}
816
817impl<'a> UpdateQueryBuilder<'a> {
818    pub fn new(table: &'a str) -> Self {
819        let mut builder = QueryBuilder::new("UPDATE ");
820        builder.push(table);
821        builder.push(" SET ");
822
823        Self {
824            builder,
825            updated_fields: HashSet::new(),
826            has_set_fields: false,
827        }
828    }
829
830    /// Adds a field to be updated, if `None`, will not add the field
831    /// To set a field to null (`None`), you need a `Some(None)`
832    pub fn set<T: 'a + sqlx::Encode<'a, Postgres> + sqlx::Type<Postgres> + Send>(
833        &mut self,
834        column: &'a str,
835        value: Option<T>,
836    ) -> &mut Self {
837        let Some(value) = value else {
838            return self;
839        };
840
841        if !self.updated_fields.insert(column) {
842            return self;
843        }
844
845        if self.has_set_fields {
846            self.builder.push(", ");
847        }
848
849        self.builder.push(column);
850        self.builder.push(" = ");
851        self.builder.push_bind(value);
852
853        self.has_set_fields = true;
854        self
855    }
856
857    pub fn where_eq<T: 'a + sqlx::Encode<'a, Postgres> + sqlx::Type<Postgres> + Send>(
858        &mut self,
859        column: &'a str,
860        value: T,
861    ) -> &mut Self {
862        self.builder.push(" WHERE ");
863        self.builder.push(column);
864        self.builder.push(" = ");
865        self.builder.push_bind(value);
866        self
867    }
868
869    pub async fn execute(
870        mut self,
871        executor: impl sqlx::Executor<'a, Database = Postgres>,
872    ) -> Result<sqlx::any::AnyQueryResult, sqlx::Error> {
873        if !self.has_set_fields {
874            return Ok(sqlx::any::AnyQueryResult::default());
875        }
876
877        let query = self.builder.build();
878        query.execute(executor).await.map(|r| r.into())
879    }
880}