shared/
cache.rs

1use crate::{env::RedisMode, response::ApiResponse};
2use axum::http::StatusCode;
3use compact_str::ToCompactString;
4use rustis::{
5    client::Client,
6    commands::{
7        GenericCommands, InfoSection, ServerCommands, SetCondition, SetExpiration, StringCommands,
8    },
9    resp::BulkString,
10};
11use serde::{Serialize, de::DeserializeOwned};
12use std::{
13    future::Future,
14    sync::{
15        Arc,
16        atomic::{AtomicU64, Ordering},
17    },
18    time::{Duration, Instant},
19};
20
21#[derive(Clone, Serialize)]
22pub struct BulkStringRef<'a>(
23    #[serde(
24        deserialize_with = "::rustis::resp::deserialize_byte_buf",
25        serialize_with = "::rustis::resp::serialize_byte_buf"
26    )]
27    pub &'a [u8],
28);
29
/// Value stored in the in-process data cache.
#[derive(Clone, Debug)]
struct DataEntry {
    /// Serialized payload bytes, shared so clones of the entry stay cheap.
    data: Arc<Vec<u8>>,
    /// Lifetime this entry asks the in-process cache to keep it for.
    intended_ttl: Duration,
}
35
36#[derive(Clone, Debug)]
37struct LockEntry {
38    semaphore: Arc<tokio::sync::Semaphore>,
39}
40
41struct DataExpiry;
42
43impl moka::Expiry<compact_str::CompactString, DataEntry> for DataExpiry {
44    fn expire_after_create(
45        &self,
46        _key: &compact_str::CompactString,
47        value: &DataEntry,
48        _created_at: Instant,
49    ) -> Option<Duration> {
50        Some(value.intended_ttl)
51    }
52}
53
54pub struct Cache {
55    client: Option<Arc<Client>>,
56    use_internal_cache: bool,
57    local: moka::future::Cache<compact_str::CompactString, DataEntry>,
58    local_task: tokio::task::JoinHandle<()>,
59    local_locks: moka::future::Cache<compact_str::CompactString, LockEntry>,
60    local_locks_task: tokio::task::JoinHandle<()>,
61    local_ratelimits: moka::future::Cache<compact_str::CompactString, (u64, u64)>,
62
63    cache_calls: AtomicU64,
64    cache_latency_ns_total: AtomicU64,
65    cache_latency_ns_max: AtomicU64,
66    cache_misses: AtomicU64,
67}
68
69impl Cache {
70    pub async fn new(env: &crate::env::Env) -> Arc<Self> {
71        let start = std::time::Instant::now();
72
73        let client = match &env.redis_mode {
74            RedisMode::Redis { redis_url } => {
75                if let Some(redis_url) = redis_url {
76                    Some(Arc::new(Client::connect(redis_url.clone()).await.unwrap()))
77                } else {
78                    None
79                }
80            }
81            RedisMode::Sentinel {
82                cluster_name,
83                redis_sentinels,
84            } => Some(Arc::new(
85                Client::connect(
86                    format!(
87                        "redis-sentinel://{}/{cluster_name}/0",
88                        redis_sentinels.join(",")
89                    )
90                    .as_str(),
91                )
92                .await
93                .unwrap(),
94            )),
95        };
96
97        let local = moka::future::Cache::builder()
98            .max_capacity(16384)
99            .expire_after(DataExpiry)
100            .build();
101
102        let local_task = tokio::spawn({
103            let local = local.clone();
104
105            async move {
106                loop {
107                    tokio::time::sleep(Duration::from_secs(10)).await;
108                    local.run_pending_tasks().await;
109                }
110            }
111        });
112
113        let local_locks = moka::future::Cache::builder().max_capacity(4096).build();
114
115        let local_locks_task = tokio::spawn({
116            let local_locks = local_locks.clone();
117
118            async move {
119                loop {
120                    tokio::time::sleep(Duration::from_secs(10)).await;
121                    local_locks.run_pending_tasks().await;
122                }
123            }
124        });
125
126        let local_ratelimits = moka::future::Cache::builder().max_capacity(16384).build();
127
128        let instance = Arc::new(Self {
129            client,
130            use_internal_cache: env.app_use_internal_cache,
131            local,
132            local_task,
133            local_locks,
134            local_locks_task,
135            local_ratelimits,
136            cache_calls: AtomicU64::new(0),
137            cache_latency_ns_total: AtomicU64::new(0),
138            cache_latency_ns_max: AtomicU64::new(0),
139            cache_misses: AtomicU64::new(0),
140        });
141
142        let version = instance
143            .version()
144            .await
145            .unwrap_or_else(|_| "unknown".into());
146
147        tracing::info!(
148            "cache connected (redis@{}, {}ms, moka_enabled={})",
149            version,
150            start.elapsed().as_millis(),
151            env.app_use_internal_cache
152        );
153
154        instance
155    }
156
157    pub async fn version(&self) -> Result<compact_str::CompactString, rustis::Error> {
158        let Some(client) = &self.client else {
159            return Ok("memory-only".into());
160        };
161
162        let version: String = client.info([InfoSection::Server]).await?;
163        let version = version
164            .lines()
165            .find(|line| line.starts_with("redis_version:"))
166            .unwrap_or("redis_version:unknown")
167            .split(':')
168            .nth(1)
169            .unwrap_or("unknown")
170            .into();
171
172        Ok(version)
173    }
174
175    pub async fn ratelimit(
176        &self,
177        limit_identifier: impl AsRef<str>,
178        limit: u64,
179        limit_window: u64,
180        client: impl AsRef<str>,
181    ) -> Result<(), ApiResponse> {
182        let key = compact_str::format_compact!(
183            "ratelimit::{}::{}",
184            limit_identifier.as_ref(),
185            client.as_ref()
186        );
187
188        let now = chrono::Utc::now().timestamp();
189
190        if let Some(redis_client) = &self.client {
191            let expiry = redis_client.expiretime(&key).await.unwrap_or_default();
192            let expire_unix: u64 = if expiry > now + 2 {
193                expiry as u64
194            } else {
195                now as u64 + limit_window
196            };
197
198            let limit_used = redis_client.get::<u64>(&key).await.unwrap_or_default() + 1;
199            redis_client
200                .set_with_options(key, limit_used, None, SetExpiration::Exat(expire_unix))
201                .await?;
202
203            if limit_used >= limit {
204                return Err(ApiResponse::error(format!(
205                    "you are ratelimited, retry in {}s",
206                    expiry - now
207                ))
208                .with_status(StatusCode::TOO_MANY_REQUESTS));
209            }
210        } else {
211            let mut current_count = 0;
212            let mut expire_unix = now as u64 + limit_window;
213
214            if let Some((count, exp)) = self.local_ratelimits.get(&key).await
215                && exp > now as u64 + 2
216            {
217                current_count = count;
218                expire_unix = exp;
219            }
220
221            let limit_used = current_count + 1;
222            self.local_ratelimits
223                .insert(key, (limit_used, expire_unix))
224                .await;
225
226            if limit_used >= limit {
227                return Err(ApiResponse::error(format!(
228                    "you are ratelimited, retry in {}s",
229                    expire_unix.saturating_sub(now as u64)
230                ))
231                .with_status(StatusCode::TOO_MANY_REQUESTS));
232            }
233        }
234
235        Ok(())
236    }
237
238    #[tracing::instrument(skip(self))]
239    pub async fn lock(
240        &self,
241        lock_id: impl Into<compact_str::CompactString> + std::fmt::Debug,
242        ttl: Option<u64>,
243        timeout: Option<u64>,
244    ) -> Result<CacheLock, anyhow::Error> {
245        let lock_id = lock_id.into();
246        let redis_key = compact_str::format_compact!("lock::{}", lock_id);
247        let ttl_secs = ttl.unwrap_or(30);
248        let deadline = timeout.map(|ms| Instant::now() + Duration::from_secs(ms));
249
250        tracing::debug!("acquiring cache lock");
251
252        let entry = self
253            .local_locks
254            .entry(lock_id.clone())
255            .or_insert_with(async {
256                LockEntry {
257                    semaphore: Arc::new(tokio::sync::Semaphore::new(1)),
258                }
259            })
260            .await
261            .into_value();
262
263        let permit = match deadline {
264            Some(dl) => {
265                let remaining = dl.saturating_duration_since(Instant::now());
266                tokio::time::timeout(remaining, entry.semaphore.acquire_owned())
267                    .await
268                    .map_err(|_| anyhow::anyhow!("timed out waiting for cache lock `{}`", lock_id))?
269                    .map_err(|_| anyhow::anyhow!("semaphore closed for lock `{}`", lock_id))?
270            }
271            None => entry
272                .semaphore
273                .acquire_owned()
274                .await
275                .map_err(|_| anyhow::anyhow!("semaphore closed for lock `{}`", lock_id))?,
276        };
277
278        if let Some(redis_client) = &self.client {
279            match Self::try_acquire_redis_lock(redis_client, &redis_key, ttl_secs, deadline).await?
280            {
281                true => {
282                    tracing::debug!("acquired redis cache lock");
283                    Ok(CacheLock::new(
284                        lock_id,
285                        Some(redis_client.clone()),
286                        permit,
287                        ttl,
288                    ))
289                }
290                false => anyhow::bail!("timed out acquiring redis lock `{}`", lock_id),
291            }
292        } else {
293            tracing::debug!("acquired memory cache lock");
294            Ok(CacheLock::new(lock_id, None, permit, ttl))
295        }
296    }
297
298    async fn try_acquire_redis_lock(
299        client: &Arc<Client>,
300        redis_key: &compact_str::CompactString,
301        ttl_secs: u64,
302        deadline: Option<Instant>,
303    ) -> Result<bool, anyhow::Error> {
304        loop {
305            let acquired = client
306                .set_with_options(
307                    redis_key.as_str(),
308                    "1",
309                    SetCondition::NX,
310                    SetExpiration::Ex(ttl_secs),
311                )
312                .await
313                .unwrap_or(false);
314
315            if acquired {
316                return Ok(true);
317            }
318
319            if let Some(dl) = deadline {
320                let remaining = dl.saturating_duration_since(Instant::now());
321                if remaining.is_zero() {
322                    return Ok(false);
323                }
324                tokio::time::sleep(remaining.min(Duration::from_millis(50))).await;
325            } else {
326                tokio::time::sleep(Duration::from_millis(50)).await;
327            }
328        }
329    }
330
331    #[tracing::instrument(skip(self, fn_compute))]
332    pub async fn cached<
333        T: Serialize + DeserializeOwned + Send,
334        F: FnOnce() -> Fut,
335        Fut: Future<Output = Result<T, FutErr>>,
336        FutErr: Into<anyhow::Error> + Send + Sync + 'static,
337    >(
338        &self,
339        key: &str,
340        ttl: u64,
341        fn_compute: F,
342    ) -> Result<T, anyhow::Error> {
343        let effective_moka_ttl = if self.use_internal_cache {
344            Duration::from_secs(ttl)
345        } else {
346            Duration::from_millis(50)
347        };
348
349        let client_opt = self.client.clone();
350
351        self.cache_calls.fetch_add(1, Ordering::Relaxed);
352        let start_time = Instant::now();
353
354        if let Some(entry) = self.local.get(key).await {
355            tracing::debug!("found in moka cache");
356            return Ok(rmp_serde::from_slice::<T>(&entry.data)?);
357        }
358
359        let entry = self
360            .local
361            .try_get_with(key.to_compact_string(), async move {
362                if let Some(client) = &client_opt {
363                    tracing::debug!("checking redis cache");
364                    let cached_value: Option<BulkString> = client
365                        .get(key)
366                        .await
367                        .map_err(|err| {
368                            tracing::error!("redis get error: {:?}", err);
369                            err
370                        })
371                        .ok()
372                        .flatten();
373
374                    if let Some(value) = cached_value {
375                        tracing::debug!("found in redis cache");
376                        return Ok(DataEntry {
377                            data: Arc::new(value.to_vec()),
378                            intended_ttl: effective_moka_ttl,
379                        });
380                    }
381                }
382
383                self.cache_misses.fetch_add(1, Ordering::Relaxed);
384
385                tracing::debug!("executing compute");
386                let result = fn_compute().await.map_err(|e| e.into())?;
387                tracing::debug!("executed compute");
388
389                let serialized = rmp_serde::to_vec(&result)?;
390                let serialized_arc = Arc::new(serialized);
391
392                if let Some(client) = &client_opt {
393                    let _ = client
394                        .set_with_options(
395                            key,
396                            BulkStringRef(&serialized_arc),
397                            None,
398                            SetExpiration::Ex(ttl),
399                        )
400                        .await;
401                }
402
403                Ok::<_, anyhow::Error>(DataEntry {
404                    data: serialized_arc,
405                    intended_ttl: effective_moka_ttl,
406                })
407            })
408            .await;
409
410        let elapsed_ns = start_time.elapsed().as_nanos() as u64;
411        self.cache_latency_ns_total
412            .fetch_add(elapsed_ns, Ordering::Relaxed);
413
414        let _ = self.cache_latency_ns_max.fetch_update(
415            Ordering::Relaxed,
416            Ordering::Relaxed,
417            |current_max| {
418                if elapsed_ns > current_max {
419                    Some(elapsed_ns)
420                } else {
421                    Some(current_max)
422                }
423            },
424        );
425
426        match entry {
427            Ok(internal_entry) => Ok(rmp_serde::from_slice::<T>(&internal_entry.data)?),
428            Err(arc_error) => Err(anyhow::anyhow!("cache computation failed: {:?}", arc_error)),
429        }
430    }
431
432    pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, anyhow::Error> {
433        if let Some(entry) = self.local.get(key).await {
434            tracing::debug!("get: found in moka cache");
435            return Ok(Some(rmp_serde::from_slice::<T>(&entry.data)?));
436        }
437
438        if let Some(client) = &self.client {
439            tracing::debug!("get: checking redis cache");
440            let cached_value: Option<BulkString> = client.get(key).await?;
441
442            if let Some(value) = cached_value {
443                tracing::debug!("get: found in redis cache");
444                let data = Arc::new(value.to_vec());
445                return Ok(Some(rmp_serde::from_slice::<T>(&data)?));
446            }
447        }
448
449        Ok(None)
450    }
451
452    pub async fn set<T: Serialize + Send + Sync>(
453        &self,
454        key: &str,
455        ttl: u64,
456        value: &T,
457    ) -> Result<(), anyhow::Error> {
458        let serialized = rmp_serde::to_vec(value)?;
459        let serialized_arc = Arc::new(serialized);
460
461        let effective_moka_ttl = if self.use_internal_cache {
462            Duration::from_secs(ttl)
463        } else {
464            Duration::from_millis(50)
465        };
466
467        self.local
468            .insert(
469                key.to_compact_string(),
470                DataEntry {
471                    data: serialized_arc.clone(),
472                    intended_ttl: effective_moka_ttl,
473                },
474            )
475            .await;
476
477        if let Some(client) = &self.client {
478            client
479                .set_with_options(
480                    key,
481                    BulkStringRef(&serialized_arc),
482                    None,
483                    SetExpiration::Ex(ttl),
484                )
485                .await?;
486        }
487
488        Ok(())
489    }
490
491    pub async fn invalidate(&self, key: &str) -> Result<(), anyhow::Error> {
492        self.local.invalidate(key).await;
493        if let Some(client) = &self.client {
494            client.del(key).await?;
495        }
496
497        Ok(())
498    }
499
500    #[inline]
501    pub fn cache_calls(&self) -> u64 {
502        self.cache_calls.load(Ordering::Relaxed)
503    }
504
505    #[inline]
506    pub fn cache_misses(&self) -> u64 {
507        self.cache_misses.load(Ordering::Relaxed)
508    }
509
510    #[inline]
511    pub fn cache_latency_ns_average(&self) -> u64 {
512        let calls = self.cache_calls();
513        if calls == 0 {
514            0
515        } else {
516            self.cache_latency_ns_total.load(Ordering::Relaxed) / calls
517        }
518    }
519}
520
521impl Drop for Cache {
522    fn drop(&mut self) {
523        self.local_task.abort();
524        self.local_locks_task.abort();
525    }
526}
527
528pub struct CacheLock {
529    lock_id: Option<compact_str::CompactString>,
530    redis_client: Option<Arc<Client>>,
531    permit: Option<tokio::sync::OwnedSemaphorePermit>,
532    ttl_guard: Option<tokio::task::JoinHandle<()>>,
533}
534
535impl CacheLock {
536    fn new(
537        lock_id: compact_str::CompactString,
538        redis_client: Option<Arc<Client>>,
539        permit: tokio::sync::OwnedSemaphorePermit,
540        ttl: Option<u64>,
541    ) -> Self {
542        let ttl_guard = ttl.and_then(|secs| {
543            let lock_id_clone = lock_id.clone();
544            redis_client.clone().map(|client| {
545                tokio::spawn(async move {
546                    tokio::time::sleep(Duration::from_secs(secs)).await;
547                    tracing::warn!(%lock_id_clone, "cache lock TTL expired; force-releasing");
548                    let redis_key = compact_str::format_compact!("lock::{}", lock_id_clone);
549                    let _ = client.del(&redis_key).await;
550                })
551            })
552        });
553
554        Self {
555            lock_id: Some(lock_id),
556            redis_client,
557            permit: Some(permit),
558            ttl_guard,
559        }
560    }
561
562    #[inline]
563    pub fn is_active(&self) -> bool {
564        self.lock_id.is_some() && self.ttl_guard.as_ref().is_none_or(|h| !h.is_finished())
565    }
566}
567
568impl Drop for CacheLock {
569    fn drop(&mut self) {
570        if let Some(ttl_guard) = self.ttl_guard.take() {
571            ttl_guard.abort();
572        }
573
574        self.permit.take();
575
576        if let Some(lock_id) = self.lock_id.take()
577            && let Some(client) = self.redis_client.take()
578        {
579            tokio::spawn(async move {
580                let redis_key = compact_str::format_compact!("lock::{}", lock_id);
581                let _ = client.del(&redis_key).await;
582            });
583        }
584    }
585}