1use crate::{env::RedisMode, response::ApiResponse};
2use axum::http::StatusCode;
3use compact_str::ToCompactString;
4use rustis::{
5 client::Client,
6 commands::{
7 GenericCommands, InfoSection, ServerCommands, SetCondition, SetExpiration, StringCommands,
8 },
9 resp::BulkString,
10};
11use serde::{Serialize, de::DeserializeOwned};
12use std::{
13 future::Future,
14 sync::{
15 Arc,
16 atomic::{AtomicU64, Ordering},
17 },
18 time::{Duration, Instant},
19};
20
21#[derive(Clone, Serialize)]
22pub struct BulkStringRef<'a>(
23 #[serde(
24 deserialize_with = "::rustis::resp::deserialize_byte_buf",
25 serialize_with = "::rustis::resp::serialize_byte_buf"
26 )]
27 pub &'a [u8],
28);
29
/// Value held by the in-process cache: a serialized payload together with
/// the time-to-live it was inserted with.
#[derive(Debug, Clone)]
struct DataEntry {
    /// Serialized bytes, reference-counted so cache hits do not copy them.
    data: Arc<Vec<u8>>,
    /// Lifetime handed to the cache by `DataExpiry` when the entry is created.
    intended_ttl: Duration,
}
35
36#[derive(Clone, Debug)]
37struct LockEntry {
38 semaphore: Arc<tokio::sync::Semaphore>,
39}
40
/// Expiry policy that gives every `DataEntry` its own time-to-live.
struct DataExpiry;
42
43impl moka::Expiry<compact_str::CompactString, DataEntry> for DataExpiry {
44 fn expire_after_create(
45 &self,
46 _key: &compact_str::CompactString,
47 value: &DataEntry,
48 _created_at: Instant,
49 ) -> Option<Duration> {
50 Some(value.intended_ttl)
51 }
52}
53
54pub struct Cache {
55 client: Option<Arc<Client>>,
56 use_internal_cache: bool,
57 local: moka::future::Cache<compact_str::CompactString, DataEntry>,
58 local_task: tokio::task::JoinHandle<()>,
59 local_locks: moka::future::Cache<compact_str::CompactString, LockEntry>,
60 local_locks_task: tokio::task::JoinHandle<()>,
61 local_ratelimits: moka::future::Cache<compact_str::CompactString, (u64, u64)>,
62
63 cache_calls: AtomicU64,
64 cache_latency_ns_total: AtomicU64,
65 cache_latency_ns_max: AtomicU64,
66 cache_misses: AtomicU64,
67}
68
69impl Cache {
70 pub async fn new(env: &crate::env::Env) -> Arc<Self> {
71 let start = std::time::Instant::now();
72
73 let client = match &env.redis_mode {
74 RedisMode::Redis { redis_url } => {
75 if let Some(redis_url) = redis_url {
76 Some(Arc::new(Client::connect(redis_url.clone()).await.unwrap()))
77 } else {
78 None
79 }
80 }
81 RedisMode::Sentinel {
82 cluster_name,
83 redis_sentinels,
84 } => Some(Arc::new(
85 Client::connect(
86 format!(
87 "redis-sentinel://{}/{cluster_name}/0",
88 redis_sentinels.join(",")
89 )
90 .as_str(),
91 )
92 .await
93 .unwrap(),
94 )),
95 };
96
97 let local = moka::future::Cache::builder()
98 .max_capacity(16384)
99 .expire_after(DataExpiry)
100 .build();
101
102 let local_task = tokio::spawn({
103 let local = local.clone();
104
105 async move {
106 loop {
107 tokio::time::sleep(Duration::from_secs(10)).await;
108 local.run_pending_tasks().await;
109 }
110 }
111 });
112
113 let local_locks = moka::future::Cache::builder().max_capacity(4096).build();
114
115 let local_locks_task = tokio::spawn({
116 let local_locks = local_locks.clone();
117
118 async move {
119 loop {
120 tokio::time::sleep(Duration::from_secs(10)).await;
121 local_locks.run_pending_tasks().await;
122 }
123 }
124 });
125
126 let local_ratelimits = moka::future::Cache::builder().max_capacity(16384).build();
127
128 let instance = Arc::new(Self {
129 client,
130 use_internal_cache: env.app_use_internal_cache,
131 local,
132 local_task,
133 local_locks,
134 local_locks_task,
135 local_ratelimits,
136 cache_calls: AtomicU64::new(0),
137 cache_latency_ns_total: AtomicU64::new(0),
138 cache_latency_ns_max: AtomicU64::new(0),
139 cache_misses: AtomicU64::new(0),
140 });
141
142 let version = instance
143 .version()
144 .await
145 .unwrap_or_else(|_| "unknown".into());
146
147 tracing::info!(
148 "cache connected (redis@{}, {}ms, moka_enabled={})",
149 version,
150 start.elapsed().as_millis(),
151 env.app_use_internal_cache
152 );
153
154 instance
155 }
156
157 pub async fn version(&self) -> Result<compact_str::CompactString, rustis::Error> {
158 let Some(client) = &self.client else {
159 return Ok("memory-only".into());
160 };
161
162 let version: String = client.info([InfoSection::Server]).await?;
163 let version = version
164 .lines()
165 .find(|line| line.starts_with("redis_version:"))
166 .unwrap_or("redis_version:unknown")
167 .split(':')
168 .nth(1)
169 .unwrap_or("unknown")
170 .into();
171
172 Ok(version)
173 }
174
175 pub async fn ratelimit(
176 &self,
177 limit_identifier: impl AsRef<str>,
178 limit: u64,
179 limit_window: u64,
180 client: impl AsRef<str>,
181 ) -> Result<(), ApiResponse> {
182 let key = compact_str::format_compact!(
183 "ratelimit::{}::{}",
184 limit_identifier.as_ref(),
185 client.as_ref()
186 );
187
188 let now = chrono::Utc::now().timestamp();
189
190 if let Some(redis_client) = &self.client {
191 let expiry = redis_client.expiretime(&key).await.unwrap_or_default();
192 let expire_unix: u64 = if expiry > now + 2 {
193 expiry as u64
194 } else {
195 now as u64 + limit_window
196 };
197
198 let limit_used = redis_client.get::<u64>(&key).await.unwrap_or_default() + 1;
199 redis_client
200 .set_with_options(key, limit_used, None, SetExpiration::Exat(expire_unix))
201 .await?;
202
203 if limit_used >= limit {
204 return Err(ApiResponse::error(format!(
205 "you are ratelimited, retry in {}s",
206 expiry - now
207 ))
208 .with_status(StatusCode::TOO_MANY_REQUESTS)
209 .with_header("X-RateLimit-Limit", limit.to_string())
210 .with_header(
211 "X-RateLimit-Remaining",
212 limit.saturating_sub(limit_used).to_string(),
213 )
214 .with_header("X-RateLimit-Reset", expire_unix.to_string())
215 .with_header("Retry-After", (expiry - now).to_compact_string()));
216 }
217 } else {
218 let mut current_count = 0;
219 let mut expire_unix = now as u64 + limit_window;
220
221 if let Some((count, exp)) = self.local_ratelimits.get(&key).await
222 && exp > now as u64 + 2
223 {
224 current_count = count;
225 expire_unix = exp;
226 }
227
228 let limit_used = current_count + 1;
229 self.local_ratelimits
230 .insert(key, (limit_used, expire_unix))
231 .await;
232
233 if limit_used >= limit {
234 return Err(ApiResponse::error(format!(
235 "you are ratelimited, retry in {}s",
236 expire_unix.saturating_sub(now as u64)
237 ))
238 .with_status(StatusCode::TOO_MANY_REQUESTS)
239 .with_header("X-RateLimit-Limit", limit.to_string())
240 .with_header(
241 "X-RateLimit-Remaining",
242 limit.saturating_sub(limit_used).to_string(),
243 )
244 .with_header("X-RateLimit-Reset", expire_unix.to_string())
245 .with_header(
246 "Retry-After",
247 (expire_unix.saturating_sub(now as u64)).to_compact_string(),
248 ));
249 }
250 }
251
252 Ok(())
253 }
254
255 #[tracing::instrument(skip(self))]
256 pub async fn lock(
257 &self,
258 lock_id: impl Into<compact_str::CompactString> + std::fmt::Debug,
259 ttl: Option<u64>,
260 timeout: Option<u64>,
261 ) -> Result<CacheLock, anyhow::Error> {
262 let lock_id = lock_id.into();
263 let redis_key = compact_str::format_compact!("lock::{}", lock_id);
264 let ttl_secs = ttl.unwrap_or(30);
265 let deadline = timeout.map(|ms| Instant::now() + Duration::from_millis(ms));
266
267 tracing::debug!("acquiring cache lock");
268
269 let entry = self
270 .local_locks
271 .entry(lock_id.clone())
272 .or_insert_with(async {
273 LockEntry {
274 semaphore: Arc::new(tokio::sync::Semaphore::new(1)),
275 }
276 })
277 .await
278 .into_value();
279
280 let permit = match deadline {
281 Some(dl) => {
282 let remaining = dl.saturating_duration_since(Instant::now());
283 tokio::time::timeout(remaining, entry.semaphore.acquire_owned())
284 .await
285 .map_err(|_| anyhow::anyhow!("timed out waiting for cache lock `{}`", lock_id))?
286 .map_err(|_| anyhow::anyhow!("semaphore closed for lock `{}`", lock_id))?
287 }
288 None => entry
289 .semaphore
290 .acquire_owned()
291 .await
292 .map_err(|_| anyhow::anyhow!("semaphore closed for lock `{}`", lock_id))?,
293 };
294
295 if let Some(redis_client) = &self.client {
296 match Self::try_acquire_redis_lock(redis_client, &redis_key, ttl_secs, deadline).await?
297 {
298 true => {
299 tracing::debug!("acquired redis cache lock");
300 Ok(CacheLock::new(
301 lock_id,
302 Some(redis_client.clone()),
303 permit,
304 ttl,
305 ))
306 }
307 false => anyhow::bail!("timed out acquiring redis lock `{}`", lock_id),
308 }
309 } else {
310 tracing::debug!("acquired memory cache lock");
311 Ok(CacheLock::new(lock_id, None, permit, ttl))
312 }
313 }
314
315 async fn try_acquire_redis_lock(
316 client: &Arc<Client>,
317 redis_key: &compact_str::CompactString,
318 ttl_secs: u64,
319 deadline: Option<Instant>,
320 ) -> Result<bool, anyhow::Error> {
321 loop {
322 let acquired = client
323 .set_with_options(
324 redis_key.as_str(),
325 "1",
326 SetCondition::NX,
327 SetExpiration::Ex(ttl_secs),
328 )
329 .await
330 .unwrap_or(false);
331
332 if acquired {
333 return Ok(true);
334 }
335
336 if let Some(dl) = deadline {
337 let remaining = dl.saturating_duration_since(Instant::now());
338 if remaining.is_zero() {
339 return Ok(false);
340 }
341 tokio::time::sleep(remaining.min(Duration::from_millis(50))).await;
342 } else {
343 tokio::time::sleep(Duration::from_millis(50)).await;
344 }
345 }
346 }
347
348 #[tracing::instrument(skip(self, fn_compute))]
349 pub async fn cached<
350 T: Serialize + DeserializeOwned + Send,
351 F: FnOnce() -> Fut,
352 Fut: Future<Output = Result<T, FutErr>>,
353 FutErr: Into<anyhow::Error> + Send + Sync + 'static,
354 >(
355 &self,
356 key: &str,
357 ttl: u64,
358 fn_compute: F,
359 ) -> Result<T, anyhow::Error> {
360 let effective_moka_ttl = if self.use_internal_cache {
361 Duration::from_secs(ttl)
362 } else {
363 Duration::from_millis(50)
364 };
365
366 let client_opt = self.client.clone();
367
368 self.cache_calls.fetch_add(1, Ordering::Relaxed);
369 let start_time = Instant::now();
370
371 let entry = self
372 .local
373 .try_get_with(key.to_compact_string(), async move {
374 if let Some(client) = &client_opt {
375 tracing::debug!("checking redis cache");
376 let cached_value: Option<BulkString> = client
377 .get(key)
378 .await
379 .map_err(|err| {
380 tracing::error!("redis get error: {:?}", err);
381 err
382 })
383 .ok()
384 .flatten();
385
386 if let Some(value) = cached_value {
387 tracing::debug!("found in redis cache");
388 return Ok(DataEntry {
389 data: Arc::new(value.to_vec()),
390 intended_ttl: effective_moka_ttl,
391 });
392 }
393 }
394
395 self.cache_misses.fetch_add(1, Ordering::Relaxed);
396
397 tracing::debug!("executing compute");
398 let result = fn_compute().await.map_err(|e| e.into())?;
399 tracing::debug!("executed compute");
400
401 let serialized = rmp_serde::to_vec(&result)?;
402 let serialized_arc = Arc::new(serialized);
403
404 if let Some(client) = &client_opt {
405 let _ = client
406 .set_with_options(
407 key,
408 BulkStringRef(&serialized_arc),
409 None,
410 SetExpiration::Ex(ttl),
411 )
412 .await;
413 }
414
415 Ok::<_, anyhow::Error>(DataEntry {
416 data: serialized_arc,
417 intended_ttl: effective_moka_ttl,
418 })
419 })
420 .await;
421
422 let elapsed_ns = start_time.elapsed().as_nanos() as u64;
423 self.cache_latency_ns_total
424 .fetch_add(elapsed_ns, Ordering::Relaxed);
425
426 let _ = self.cache_latency_ns_max.fetch_update(
427 Ordering::Relaxed,
428 Ordering::Relaxed,
429 |current_max| {
430 if elapsed_ns > current_max {
431 Some(elapsed_ns)
432 } else {
433 Some(current_max)
434 }
435 },
436 );
437
438 match entry {
439 Ok(internal_entry) => Ok(rmp_serde::from_slice::<T>(&internal_entry.data)?),
440 Err(arc_error) => Err(anyhow::anyhow!("cache computation failed: {:?}", arc_error)),
441 }
442 }
443
444 pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, anyhow::Error> {
445 if let Some(entry) = self.local.get(key).await {
446 tracing::debug!("get: found in moka cache");
447 return Ok(Some(rmp_serde::from_slice::<T>(&entry.data)?));
448 }
449
450 if let Some(client) = &self.client {
451 tracing::debug!("get: checking redis cache");
452 let cached_value: Option<BulkString> = client.get(key).await?;
453
454 if let Some(value) = cached_value {
455 tracing::debug!("get: found in redis cache");
456 let data = Arc::new(value.to_vec());
457 return Ok(Some(rmp_serde::from_slice::<T>(&data)?));
458 }
459 }
460
461 Ok(None)
462 }
463
464 pub async fn get_raw(&self, key: &str) -> Result<Option<Arc<Vec<u8>>>, anyhow::Error> {
465 if let Some(entry) = self.local.get(key).await {
466 tracing::debug!("get_raw: found in moka cache");
467 return Ok(Some(entry.data.clone()));
468 }
469
470 if let Some(client) = &self.client {
471 tracing::debug!("get_raw: checking redis cache");
472 let cached_value: Option<BulkString> = client.get(key).await?;
473
474 if let Some(value) = cached_value {
475 tracing::debug!("get_raw: found in redis cache");
476 return Ok(Some(Arc::new(value.to_vec())));
477 }
478 }
479
480 Ok(None)
481 }
482
483 pub async fn set<T: Serialize + Send + Sync>(
484 &self,
485 key: &str,
486 ttl: u64,
487 value: &T,
488 ) -> Result<(), anyhow::Error> {
489 let serialized = rmp_serde::to_vec(value)?;
490 let serialized_arc = Arc::new(serialized);
491
492 let effective_moka_ttl = if self.use_internal_cache {
493 Duration::from_secs(ttl)
494 } else {
495 Duration::from_millis(50)
496 };
497
498 self.local
499 .insert(
500 key.to_compact_string(),
501 DataEntry {
502 data: serialized_arc.clone(),
503 intended_ttl: effective_moka_ttl,
504 },
505 )
506 .await;
507
508 if let Some(client) = &self.client {
509 client
510 .set_with_options(
511 key,
512 BulkStringRef(&serialized_arc),
513 None,
514 SetExpiration::Ex(ttl),
515 )
516 .await?;
517 }
518
519 Ok(())
520 }
521
522 pub async fn set_raw(
523 &self,
524 key: &str,
525 ttl: u64,
526 value: impl Into<Arc<Vec<u8>>>,
527 ) -> Result<(), anyhow::Error> {
528 let serialized_arc = value.into();
529
530 let effective_moka_ttl = if self.use_internal_cache {
531 Duration::from_secs(ttl)
532 } else {
533 Duration::from_millis(50)
534 };
535
536 self.local
537 .insert(
538 key.to_compact_string(),
539 DataEntry {
540 data: serialized_arc.clone(),
541 intended_ttl: effective_moka_ttl,
542 },
543 )
544 .await;
545
546 if let Some(client) = &self.client {
547 client
548 .set_with_options(
549 key,
550 BulkStringRef(&serialized_arc),
551 None,
552 SetExpiration::Ex(ttl),
553 )
554 .await?;
555 }
556
557 Ok(())
558 }
559
560 pub async fn exists(&self, key: &str) -> Result<bool, anyhow::Error> {
561 if self.local.contains_key(key) {
562 return Ok(true);
563 }
564
565 if let Some(client) = &self.client {
566 Ok(client.exists(key).await? > 0)
567 } else {
568 Ok(false)
569 }
570 }
571
572 pub async fn list(
573 &self,
574 prefix: &str,
575 ) -> Result<Vec<compact_str::CompactString>, anyhow::Error> {
576 if let Some(client) = &self.client {
577 let keys = client.keys(format!("{}*", prefix)).await?;
578 Ok(keys)
579 } else {
580 let mut keys = Vec::new();
581 for (key, _) in self.local.iter() {
582 if key.starts_with(prefix) {
583 keys.push(key.to_compact_string());
584 }
585 }
586 Ok(keys)
587 }
588 }
589
590 pub async fn invalidate(&self, key: &str) -> Result<(), anyhow::Error> {
591 self.local.invalidate(key).await;
592 if let Some(client) = &self.client {
593 client.del(key).await?;
594 }
595
596 Ok(())
597 }
598
599 #[inline]
600 pub fn cache_calls(&self) -> u64 {
601 self.cache_calls.load(Ordering::Relaxed)
602 }
603
604 #[inline]
605 pub fn cache_misses(&self) -> u64 {
606 self.cache_misses.load(Ordering::Relaxed)
607 }
608
609 #[inline]
610 pub fn cache_latency_ns_average(&self) -> u64 {
611 let calls = self.cache_calls();
612 self.cache_latency_ns_total
613 .load(Ordering::Relaxed)
614 .checked_div(calls)
615 .unwrap_or(0)
616 }
617
618 #[inline]
619 pub fn cache_latency_ns_max(&self) -> u64 {
620 self.cache_latency_ns_max.load(Ordering::Relaxed)
621 }
622}
623
624impl Drop for Cache {
625 fn drop(&mut self) {
626 self.local_task.abort();
627 self.local_locks_task.abort();
628 }
629}
630
631pub struct CacheLock {
632 lock_id: Option<compact_str::CompactString>,
633 redis_client: Option<Arc<Client>>,
634 permit: Option<tokio::sync::OwnedSemaphorePermit>,
635 ttl_guard: Option<tokio::task::JoinHandle<()>>,
636}
637
638impl CacheLock {
639 fn new(
640 lock_id: compact_str::CompactString,
641 redis_client: Option<Arc<Client>>,
642 permit: tokio::sync::OwnedSemaphorePermit,
643 ttl: Option<u64>,
644 ) -> Self {
645 let ttl_guard = ttl.and_then(|secs| {
646 let lock_id_clone = lock_id.clone();
647 redis_client.clone().map(|client| {
648 tokio::spawn(async move {
649 tokio::time::sleep(Duration::from_secs(secs)).await;
650 tracing::warn!(%lock_id_clone, "cache lock TTL expired; force-releasing");
651 let redis_key = compact_str::format_compact!("lock::{}", lock_id_clone);
652 let _ = client.del(&redis_key).await;
653 })
654 })
655 });
656
657 Self {
658 lock_id: Some(lock_id),
659 redis_client,
660 permit: Some(permit),
661 ttl_guard,
662 }
663 }
664
665 #[inline]
666 pub fn is_active(&self) -> bool {
667 self.lock_id.is_some() && self.ttl_guard.as_ref().is_none_or(|h| !h.is_finished())
668 }
669}
670
671impl Drop for CacheLock {
672 fn drop(&mut self) {
673 if let Some(ttl_guard) = self.ttl_guard.take() {
674 ttl_guard.abort();
675 }
676
677 self.permit.take();
678
679 if let Some(lock_id) = self.lock_id.take()
680 && let Some(client) = self.redis_client.take()
681 {
682 tokio::spawn(async move {
683 let redis_key = compact_str::format_compact!("lock::{}", lock_id);
684 let _ = client.del(&redis_key).await;
685 });
686 }
687 }
688}