Skip to main content

shared/
cache.rs

1use crate::{env::RedisMode, response::ApiResponse};
2use axum::http::StatusCode;
3use compact_str::ToCompactString;
4use rustis::{
5    client::Client,
6    commands::{
7        GenericCommands, InfoSection, ServerCommands, SetCondition, SetExpiration, StringCommands,
8    },
9    resp::BulkString,
10};
11use serde::{Serialize, de::DeserializeOwned};
12use std::{
13    future::Future,
14    sync::{
15        Arc,
16        atomic::{AtomicU64, Ordering},
17    },
18    time::{Duration, Instant},
19};
20
21#[derive(Clone, Serialize)]
22pub struct BulkStringRef<'a>(
23    #[serde(
24        deserialize_with = "::rustis::resp::deserialize_byte_buf",
25        serialize_with = "::rustis::resp::serialize_byte_buf"
26    )]
27    pub &'a [u8],
28);
29
30#[derive(Clone, Debug)]
31struct DataEntry {
32    data: Arc<Vec<u8>>,
33    intended_ttl: Duration,
34}
35
36#[derive(Clone, Debug)]
37struct LockEntry {
38    semaphore: Arc<tokio::sync::Semaphore>,
39}
40
41struct DataExpiry;
42
43impl moka::Expiry<compact_str::CompactString, DataEntry> for DataExpiry {
44    fn expire_after_create(
45        &self,
46        _key: &compact_str::CompactString,
47        value: &DataEntry,
48        _created_at: Instant,
49    ) -> Option<Duration> {
50        Some(value.intended_ttl)
51    }
52}
53
54pub struct Cache {
55    client: Option<Arc<Client>>,
56    use_internal_cache: bool,
57    local: moka::future::Cache<compact_str::CompactString, DataEntry>,
58    local_task: tokio::task::JoinHandle<()>,
59    local_locks: moka::future::Cache<compact_str::CompactString, LockEntry>,
60    local_locks_task: tokio::task::JoinHandle<()>,
61    local_ratelimits: moka::future::Cache<compact_str::CompactString, (u64, u64)>,
62
63    cache_calls: AtomicU64,
64    cache_latency_ns_total: AtomicU64,
65    cache_latency_ns_max: AtomicU64,
66    cache_misses: AtomicU64,
67}
68
69impl Cache {
70    pub async fn new(env: &crate::env::Env) -> Arc<Self> {
71        let start = std::time::Instant::now();
72
73        let client = match &env.redis_mode {
74            RedisMode::Redis { redis_url } => {
75                if let Some(redis_url) = redis_url {
76                    Some(Arc::new(Client::connect(redis_url.clone()).await.unwrap()))
77                } else {
78                    None
79                }
80            }
81            RedisMode::Sentinel {
82                cluster_name,
83                redis_sentinels,
84            } => Some(Arc::new(
85                Client::connect(
86                    format!(
87                        "redis-sentinel://{}/{cluster_name}/0",
88                        redis_sentinels.join(",")
89                    )
90                    .as_str(),
91                )
92                .await
93                .unwrap(),
94            )),
95        };
96
97        let local = moka::future::Cache::builder()
98            .max_capacity(16384)
99            .expire_after(DataExpiry)
100            .build();
101
102        let local_task = tokio::spawn({
103            let local = local.clone();
104
105            async move {
106                loop {
107                    tokio::time::sleep(Duration::from_secs(10)).await;
108                    local.run_pending_tasks().await;
109                }
110            }
111        });
112
113        let local_locks = moka::future::Cache::builder().max_capacity(4096).build();
114
115        let local_locks_task = tokio::spawn({
116            let local_locks = local_locks.clone();
117
118            async move {
119                loop {
120                    tokio::time::sleep(Duration::from_secs(10)).await;
121                    local_locks.run_pending_tasks().await;
122                }
123            }
124        });
125
126        let local_ratelimits = moka::future::Cache::builder().max_capacity(16384).build();
127
128        let instance = Arc::new(Self {
129            client,
130            use_internal_cache: env.app_use_internal_cache,
131            local,
132            local_task,
133            local_locks,
134            local_locks_task,
135            local_ratelimits,
136            cache_calls: AtomicU64::new(0),
137            cache_latency_ns_total: AtomicU64::new(0),
138            cache_latency_ns_max: AtomicU64::new(0),
139            cache_misses: AtomicU64::new(0),
140        });
141
142        let version = instance
143            .version()
144            .await
145            .unwrap_or_else(|_| "unknown".into());
146
147        tracing::info!(
148            "cache connected (redis@{}, {}ms, moka_enabled={})",
149            version,
150            start.elapsed().as_millis(),
151            env.app_use_internal_cache
152        );
153
154        instance
155    }
156
157    pub async fn version(&self) -> Result<compact_str::CompactString, rustis::Error> {
158        let Some(client) = &self.client else {
159            return Ok("memory-only".into());
160        };
161
162        let version: String = client.info([InfoSection::Server]).await?;
163        let version = version
164            .lines()
165            .find(|line| line.starts_with("redis_version:"))
166            .unwrap_or("redis_version:unknown")
167            .split(':')
168            .nth(1)
169            .unwrap_or("unknown")
170            .into();
171
172        Ok(version)
173    }
174
175    pub async fn ratelimit(
176        &self,
177        limit_identifier: impl AsRef<str>,
178        limit: u64,
179        limit_window: u64,
180        client: impl AsRef<str>,
181    ) -> Result<(), ApiResponse> {
182        let key = compact_str::format_compact!(
183            "ratelimit::{}::{}",
184            limit_identifier.as_ref(),
185            client.as_ref()
186        );
187
188        let now = chrono::Utc::now().timestamp();
189
190        if let Some(redis_client) = &self.client {
191            let expiry = redis_client.expiretime(&key).await.unwrap_or_default();
192            let expire_unix: u64 = if expiry > now + 2 {
193                expiry as u64
194            } else {
195                now as u64 + limit_window
196            };
197
198            let limit_used = redis_client.get::<u64>(&key).await.unwrap_or_default() + 1;
199            redis_client
200                .set_with_options(key, limit_used, None, SetExpiration::Exat(expire_unix))
201                .await?;
202
203            if limit_used >= limit {
204                return Err(ApiResponse::error(format!(
205                    "you are ratelimited, retry in {}s",
206                    expiry - now
207                ))
208                .with_status(StatusCode::TOO_MANY_REQUESTS)
209                .with_header("X-RateLimit-Limit", limit.to_string())
210                .with_header(
211                    "X-RateLimit-Remaining",
212                    limit.saturating_sub(limit_used).to_string(),
213                )
214                .with_header("X-RateLimit-Reset", expire_unix.to_string())
215                .with_header("Retry-After", (expiry - now).to_compact_string()));
216            }
217        } else {
218            let mut current_count = 0;
219            let mut expire_unix = now as u64 + limit_window;
220
221            if let Some((count, exp)) = self.local_ratelimits.get(&key).await
222                && exp > now as u64 + 2
223            {
224                current_count = count;
225                expire_unix = exp;
226            }
227
228            let limit_used = current_count + 1;
229            self.local_ratelimits
230                .insert(key, (limit_used, expire_unix))
231                .await;
232
233            if limit_used >= limit {
234                return Err(ApiResponse::error(format!(
235                    "you are ratelimited, retry in {}s",
236                    expire_unix.saturating_sub(now as u64)
237                ))
238                .with_status(StatusCode::TOO_MANY_REQUESTS)
239                .with_header("X-RateLimit-Limit", limit.to_string())
240                .with_header(
241                    "X-RateLimit-Remaining",
242                    limit.saturating_sub(limit_used).to_string(),
243                )
244                .with_header("X-RateLimit-Reset", expire_unix.to_string())
245                .with_header(
246                    "Retry-After",
247                    (expire_unix.saturating_sub(now as u64)).to_compact_string(),
248                ));
249            }
250        }
251
252        Ok(())
253    }
254
255    #[tracing::instrument(skip(self))]
256    pub async fn lock(
257        &self,
258        lock_id: impl Into<compact_str::CompactString> + std::fmt::Debug,
259        ttl: Option<u64>,
260        timeout: Option<u64>,
261    ) -> Result<CacheLock, anyhow::Error> {
262        let lock_id = lock_id.into();
263        let redis_key = compact_str::format_compact!("lock::{}", lock_id);
264        let ttl_secs = ttl.unwrap_or(30);
265        let deadline = timeout.map(|ms| Instant::now() + Duration::from_millis(ms));
266
267        tracing::debug!("acquiring cache lock");
268
269        let entry = self
270            .local_locks
271            .entry(lock_id.clone())
272            .or_insert_with(async {
273                LockEntry {
274                    semaphore: Arc::new(tokio::sync::Semaphore::new(1)),
275                }
276            })
277            .await
278            .into_value();
279
280        let permit = match deadline {
281            Some(dl) => {
282                let remaining = dl.saturating_duration_since(Instant::now());
283                tokio::time::timeout(remaining, entry.semaphore.acquire_owned())
284                    .await
285                    .map_err(|_| anyhow::anyhow!("timed out waiting for cache lock `{}`", lock_id))?
286                    .map_err(|_| anyhow::anyhow!("semaphore closed for lock `{}`", lock_id))?
287            }
288            None => entry
289                .semaphore
290                .acquire_owned()
291                .await
292                .map_err(|_| anyhow::anyhow!("semaphore closed for lock `{}`", lock_id))?,
293        };
294
295        if let Some(redis_client) = &self.client {
296            match Self::try_acquire_redis_lock(redis_client, &redis_key, ttl_secs, deadline).await?
297            {
298                true => {
299                    tracing::debug!("acquired redis cache lock");
300                    Ok(CacheLock::new(
301                        lock_id,
302                        Some(redis_client.clone()),
303                        permit,
304                        ttl,
305                    ))
306                }
307                false => anyhow::bail!("timed out acquiring redis lock `{}`", lock_id),
308            }
309        } else {
310            tracing::debug!("acquired memory cache lock");
311            Ok(CacheLock::new(lock_id, None, permit, ttl))
312        }
313    }
314
315    async fn try_acquire_redis_lock(
316        client: &Arc<Client>,
317        redis_key: &compact_str::CompactString,
318        ttl_secs: u64,
319        deadline: Option<Instant>,
320    ) -> Result<bool, anyhow::Error> {
321        loop {
322            let acquired = client
323                .set_with_options(
324                    redis_key.as_str(),
325                    "1",
326                    SetCondition::NX,
327                    SetExpiration::Ex(ttl_secs),
328                )
329                .await
330                .unwrap_or(false);
331
332            if acquired {
333                return Ok(true);
334            }
335
336            if let Some(dl) = deadline {
337                let remaining = dl.saturating_duration_since(Instant::now());
338                if remaining.is_zero() {
339                    return Ok(false);
340                }
341                tokio::time::sleep(remaining.min(Duration::from_millis(50))).await;
342            } else {
343                tokio::time::sleep(Duration::from_millis(50)).await;
344            }
345        }
346    }
347
348    #[tracing::instrument(skip(self, fn_compute))]
349    pub async fn cached<
350        T: Serialize + DeserializeOwned + Send,
351        F: FnOnce() -> Fut,
352        Fut: Future<Output = Result<T, FutErr>>,
353        FutErr: Into<anyhow::Error> + Send + Sync + 'static,
354    >(
355        &self,
356        key: &str,
357        ttl: u64,
358        fn_compute: F,
359    ) -> Result<T, anyhow::Error> {
360        let effective_moka_ttl = if self.use_internal_cache {
361            Duration::from_secs(ttl)
362        } else {
363            Duration::from_millis(50)
364        };
365
366        let client_opt = self.client.clone();
367
368        self.cache_calls.fetch_add(1, Ordering::Relaxed);
369        let start_time = Instant::now();
370
371        let entry = self
372            .local
373            .try_get_with(key.to_compact_string(), async move {
374                if let Some(client) = &client_opt {
375                    tracing::debug!("checking redis cache");
376                    let cached_value: Option<BulkString> = client
377                        .get(key)
378                        .await
379                        .map_err(|err| {
380                            tracing::error!("redis get error: {:?}", err);
381                            err
382                        })
383                        .ok()
384                        .flatten();
385
386                    if let Some(value) = cached_value {
387                        tracing::debug!("found in redis cache");
388                        return Ok(DataEntry {
389                            data: Arc::new(value.to_vec()),
390                            intended_ttl: effective_moka_ttl,
391                        });
392                    }
393                }
394
395                self.cache_misses.fetch_add(1, Ordering::Relaxed);
396
397                tracing::debug!("executing compute");
398                let result = fn_compute().await.map_err(|e| e.into())?;
399                tracing::debug!("executed compute");
400
401                let serialized = rmp_serde::to_vec(&result)?;
402                let serialized_arc = Arc::new(serialized);
403
404                if let Some(client) = &client_opt {
405                    let _ = client
406                        .set_with_options(
407                            key,
408                            BulkStringRef(&serialized_arc),
409                            None,
410                            SetExpiration::Ex(ttl),
411                        )
412                        .await;
413                }
414
415                Ok::<_, anyhow::Error>(DataEntry {
416                    data: serialized_arc,
417                    intended_ttl: effective_moka_ttl,
418                })
419            })
420            .await;
421
422        let elapsed_ns = start_time.elapsed().as_nanos() as u64;
423        self.cache_latency_ns_total
424            .fetch_add(elapsed_ns, Ordering::Relaxed);
425
426        let _ = self.cache_latency_ns_max.fetch_update(
427            Ordering::Relaxed,
428            Ordering::Relaxed,
429            |current_max| {
430                if elapsed_ns > current_max {
431                    Some(elapsed_ns)
432                } else {
433                    Some(current_max)
434                }
435            },
436        );
437
438        match entry {
439            Ok(internal_entry) => Ok(rmp_serde::from_slice::<T>(&internal_entry.data)?),
440            Err(arc_error) => Err(anyhow::anyhow!("cache computation failed: {:?}", arc_error)),
441        }
442    }
443
444    pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, anyhow::Error> {
445        if let Some(entry) = self.local.get(key).await {
446            tracing::debug!("get: found in moka cache");
447            return Ok(Some(rmp_serde::from_slice::<T>(&entry.data)?));
448        }
449
450        if let Some(client) = &self.client {
451            tracing::debug!("get: checking redis cache");
452            let cached_value: Option<BulkString> = client.get(key).await?;
453
454            if let Some(value) = cached_value {
455                tracing::debug!("get: found in redis cache");
456                let data = Arc::new(value.to_vec());
457                return Ok(Some(rmp_serde::from_slice::<T>(&data)?));
458            }
459        }
460
461        Ok(None)
462    }
463
464    pub async fn get_raw(&self, key: &str) -> Result<Option<Arc<Vec<u8>>>, anyhow::Error> {
465        if let Some(entry) = self.local.get(key).await {
466            tracing::debug!("get_raw: found in moka cache");
467            return Ok(Some(entry.data.clone()));
468        }
469
470        if let Some(client) = &self.client {
471            tracing::debug!("get_raw: checking redis cache");
472            let cached_value: Option<BulkString> = client.get(key).await?;
473
474            if let Some(value) = cached_value {
475                tracing::debug!("get_raw: found in redis cache");
476                return Ok(Some(Arc::new(value.to_vec())));
477            }
478        }
479
480        Ok(None)
481    }
482
483    pub async fn set<T: Serialize + Send + Sync>(
484        &self,
485        key: &str,
486        ttl: u64,
487        value: &T,
488    ) -> Result<(), anyhow::Error> {
489        let serialized = rmp_serde::to_vec(value)?;
490        let serialized_arc = Arc::new(serialized);
491
492        let effective_moka_ttl = if self.use_internal_cache {
493            Duration::from_secs(ttl)
494        } else {
495            Duration::from_millis(50)
496        };
497
498        self.local
499            .insert(
500                key.to_compact_string(),
501                DataEntry {
502                    data: serialized_arc.clone(),
503                    intended_ttl: effective_moka_ttl,
504                },
505            )
506            .await;
507
508        if let Some(client) = &self.client {
509            client
510                .set_with_options(
511                    key,
512                    BulkStringRef(&serialized_arc),
513                    None,
514                    SetExpiration::Ex(ttl),
515                )
516                .await?;
517        }
518
519        Ok(())
520    }
521
522    pub async fn set_raw(
523        &self,
524        key: &str,
525        ttl: u64,
526        value: impl Into<Arc<Vec<u8>>>,
527    ) -> Result<(), anyhow::Error> {
528        let serialized_arc = value.into();
529
530        let effective_moka_ttl = if self.use_internal_cache {
531            Duration::from_secs(ttl)
532        } else {
533            Duration::from_millis(50)
534        };
535
536        self.local
537            .insert(
538                key.to_compact_string(),
539                DataEntry {
540                    data: serialized_arc.clone(),
541                    intended_ttl: effective_moka_ttl,
542                },
543            )
544            .await;
545
546        if let Some(client) = &self.client {
547            client
548                .set_with_options(
549                    key,
550                    BulkStringRef(&serialized_arc),
551                    None,
552                    SetExpiration::Ex(ttl),
553                )
554                .await?;
555        }
556
557        Ok(())
558    }
559
560    pub async fn exists(&self, key: &str) -> Result<bool, anyhow::Error> {
561        if self.local.contains_key(key) {
562            return Ok(true);
563        }
564
565        if let Some(client) = &self.client {
566            Ok(client.exists(key).await? > 0)
567        } else {
568            Ok(false)
569        }
570    }
571
572    pub async fn list(
573        &self,
574        prefix: &str,
575    ) -> Result<Vec<compact_str::CompactString>, anyhow::Error> {
576        if let Some(client) = &self.client {
577            let keys = client.keys(format!("{}*", prefix)).await?;
578            Ok(keys)
579        } else {
580            let mut keys = Vec::new();
581            for (key, _) in self.local.iter() {
582                if key.starts_with(prefix) {
583                    keys.push(key.to_compact_string());
584                }
585            }
586            Ok(keys)
587        }
588    }
589
590    pub async fn invalidate(&self, key: &str) -> Result<(), anyhow::Error> {
591        self.local.invalidate(key).await;
592        if let Some(client) = &self.client {
593            client.del(key).await?;
594        }
595
596        Ok(())
597    }
598
599    #[inline]
600    pub fn cache_calls(&self) -> u64 {
601        self.cache_calls.load(Ordering::Relaxed)
602    }
603
604    #[inline]
605    pub fn cache_misses(&self) -> u64 {
606        self.cache_misses.load(Ordering::Relaxed)
607    }
608
609    #[inline]
610    pub fn cache_latency_ns_average(&self) -> u64 {
611        let calls = self.cache_calls();
612        self.cache_latency_ns_total
613            .load(Ordering::Relaxed)
614            .checked_div(calls)
615            .unwrap_or(0)
616    }
617
618    #[inline]
619    pub fn cache_latency_ns_max(&self) -> u64 {
620        self.cache_latency_ns_max.load(Ordering::Relaxed)
621    }
622}
623
624impl Drop for Cache {
625    fn drop(&mut self) {
626        self.local_task.abort();
627        self.local_locks_task.abort();
628    }
629}
630
631pub struct CacheLock {
632    lock_id: Option<compact_str::CompactString>,
633    redis_client: Option<Arc<Client>>,
634    permit: Option<tokio::sync::OwnedSemaphorePermit>,
635    ttl_guard: Option<tokio::task::JoinHandle<()>>,
636}
637
638impl CacheLock {
639    fn new(
640        lock_id: compact_str::CompactString,
641        redis_client: Option<Arc<Client>>,
642        permit: tokio::sync::OwnedSemaphorePermit,
643        ttl: Option<u64>,
644    ) -> Self {
645        let ttl_guard = ttl.and_then(|secs| {
646            let lock_id_clone = lock_id.clone();
647            redis_client.clone().map(|client| {
648                tokio::spawn(async move {
649                    tokio::time::sleep(Duration::from_secs(secs)).await;
650                    tracing::warn!(%lock_id_clone, "cache lock TTL expired; force-releasing");
651                    let redis_key = compact_str::format_compact!("lock::{}", lock_id_clone);
652                    let _ = client.del(&redis_key).await;
653                })
654            })
655        });
656
657        Self {
658            lock_id: Some(lock_id),
659            redis_client,
660            permit: Some(permit),
661            ttl_guard,
662        }
663    }
664
665    #[inline]
666    pub fn is_active(&self) -> bool {
667        self.lock_id.is_some() && self.ttl_guard.as_ref().is_none_or(|h| !h.is_finished())
668    }
669}
670
671impl Drop for CacheLock {
672    fn drop(&mut self) {
673        if let Some(ttl_guard) = self.ttl_guard.take() {
674            ttl_guard.abort();
675        }
676
677        self.permit.take();
678
679        if let Some(lock_id) = self.lock_id.take()
680            && let Some(client) = self.redis_client.take()
681        {
682            tokio::spawn(async move {
683                let redis_key = compact_str::format_compact!("lock::{}", lock_id);
684                let _ = client.del(&redis_key).await;
685            });
686        }
687    }
688}