1use crate::{env::RedisMode, response::ApiResponse};
2use axum::http::StatusCode;
3use compact_str::ToCompactString;
4use rustis::{
5 client::Client,
6 commands::{
7 GenericCommands, InfoSection, ServerCommands, SetCondition, SetExpiration, StringCommands,
8 },
9 resp::BulkString,
10};
11use serde::{Serialize, de::DeserializeOwned};
12use std::{
13 future::Future,
14 sync::{
15 Arc,
16 atomic::{AtomicU64, Ordering},
17 },
18 time::{Duration, Instant},
19};
20
21#[derive(Clone, Serialize)]
22pub struct BulkStringRef<'a>(
23 #[serde(
24 deserialize_with = "::rustis::resp::deserialize_byte_buf",
25 serialize_with = "::rustis::resp::serialize_byte_buf"
26 )]
27 pub &'a [u8],
28);
29
/// One value held by the in-process (moka) data cache.
#[derive(Debug, Clone)]
struct DataEntry {
    /// MessagePack-encoded payload; behind an `Arc` so cloning an entry out of moka is cheap.
    data: Arc<Vec<u8>>,
    /// Lifetime moka should give this entry (consumed by `DataExpiry`).
    intended_ttl: Duration,
}
35
36#[derive(Clone, Debug)]
37struct LockEntry {
38 semaphore: Arc<tokio::sync::Semaphore>,
39}
40
41struct DataExpiry;
42
43impl moka::Expiry<compact_str::CompactString, DataEntry> for DataExpiry {
44 fn expire_after_create(
45 &self,
46 _key: &compact_str::CompactString,
47 value: &DataEntry,
48 _created_at: Instant,
49 ) -> Option<Duration> {
50 Some(value.intended_ttl)
51 }
52}
53
54pub struct Cache {
55 client: Option<Arc<Client>>,
56 use_internal_cache: bool,
57 local: moka::future::Cache<compact_str::CompactString, DataEntry>,
58 local_task: tokio::task::JoinHandle<()>,
59 local_locks: moka::future::Cache<compact_str::CompactString, LockEntry>,
60 local_locks_task: tokio::task::JoinHandle<()>,
61 local_ratelimits: moka::future::Cache<compact_str::CompactString, (u64, u64)>,
62
63 cache_calls: AtomicU64,
64 cache_latency_ns_total: AtomicU64,
65 cache_latency_ns_max: AtomicU64,
66 cache_misses: AtomicU64,
67}
68
69impl Cache {
70 pub async fn new(env: &crate::env::Env) -> Arc<Self> {
71 let start = std::time::Instant::now();
72
73 let client = match &env.redis_mode {
74 RedisMode::Redis { redis_url } => {
75 if let Some(redis_url) = redis_url {
76 Some(Arc::new(Client::connect(redis_url.clone()).await.unwrap()))
77 } else {
78 None
79 }
80 }
81 RedisMode::Sentinel {
82 cluster_name,
83 redis_sentinels,
84 } => Some(Arc::new(
85 Client::connect(
86 format!(
87 "redis-sentinel://{}/{cluster_name}/0",
88 redis_sentinels.join(",")
89 )
90 .as_str(),
91 )
92 .await
93 .unwrap(),
94 )),
95 };
96
97 let local = moka::future::Cache::builder()
98 .max_capacity(16384)
99 .expire_after(DataExpiry)
100 .build();
101
102 let local_task = tokio::spawn({
103 let local = local.clone();
104
105 async move {
106 loop {
107 tokio::time::sleep(Duration::from_secs(10)).await;
108 local.run_pending_tasks().await;
109 }
110 }
111 });
112
113 let local_locks = moka::future::Cache::builder().max_capacity(4096).build();
114
115 let local_locks_task = tokio::spawn({
116 let local_locks = local_locks.clone();
117
118 async move {
119 loop {
120 tokio::time::sleep(Duration::from_secs(10)).await;
121 local_locks.run_pending_tasks().await;
122 }
123 }
124 });
125
126 let local_ratelimits = moka::future::Cache::builder().max_capacity(16384).build();
127
128 let instance = Arc::new(Self {
129 client,
130 use_internal_cache: env.app_use_internal_cache,
131 local,
132 local_task,
133 local_locks,
134 local_locks_task,
135 local_ratelimits,
136 cache_calls: AtomicU64::new(0),
137 cache_latency_ns_total: AtomicU64::new(0),
138 cache_latency_ns_max: AtomicU64::new(0),
139 cache_misses: AtomicU64::new(0),
140 });
141
142 let version = instance
143 .version()
144 .await
145 .unwrap_or_else(|_| "unknown".into());
146
147 tracing::info!(
148 "cache connected (redis@{}, {}ms, moka_enabled={})",
149 version,
150 start.elapsed().as_millis(),
151 env.app_use_internal_cache
152 );
153
154 instance
155 }
156
157 pub async fn version(&self) -> Result<compact_str::CompactString, rustis::Error> {
158 let Some(client) = &self.client else {
159 return Ok("memory-only".into());
160 };
161
162 let version: String = client.info([InfoSection::Server]).await?;
163 let version = version
164 .lines()
165 .find(|line| line.starts_with("redis_version:"))
166 .unwrap_or("redis_version:unknown")
167 .split(':')
168 .nth(1)
169 .unwrap_or("unknown")
170 .into();
171
172 Ok(version)
173 }
174
175 pub async fn ratelimit(
176 &self,
177 limit_identifier: impl AsRef<str>,
178 limit: u64,
179 limit_window: u64,
180 client: impl AsRef<str>,
181 ) -> Result<(), ApiResponse> {
182 let key = compact_str::format_compact!(
183 "ratelimit::{}::{}",
184 limit_identifier.as_ref(),
185 client.as_ref()
186 );
187
188 let now = chrono::Utc::now().timestamp();
189
190 if let Some(redis_client) = &self.client {
191 let expiry = redis_client.expiretime(&key).await.unwrap_or_default();
192 let expire_unix: u64 = if expiry > now + 2 {
193 expiry as u64
194 } else {
195 now as u64 + limit_window
196 };
197
198 let limit_used = redis_client.get::<u64>(&key).await.unwrap_or_default() + 1;
199 redis_client
200 .set_with_options(key, limit_used, None, SetExpiration::Exat(expire_unix))
201 .await?;
202
203 if limit_used >= limit {
204 return Err(ApiResponse::error(format!(
205 "you are ratelimited, retry in {}s",
206 expiry - now
207 ))
208 .with_status(StatusCode::TOO_MANY_REQUESTS));
209 }
210 } else {
211 let mut current_count = 0;
212 let mut expire_unix = now as u64 + limit_window;
213
214 if let Some((count, exp)) = self.local_ratelimits.get(&key).await
215 && exp > now as u64 + 2
216 {
217 current_count = count;
218 expire_unix = exp;
219 }
220
221 let limit_used = current_count + 1;
222 self.local_ratelimits
223 .insert(key, (limit_used, expire_unix))
224 .await;
225
226 if limit_used >= limit {
227 return Err(ApiResponse::error(format!(
228 "you are ratelimited, retry in {}s",
229 expire_unix.saturating_sub(now as u64)
230 ))
231 .with_status(StatusCode::TOO_MANY_REQUESTS));
232 }
233 }
234
235 Ok(())
236 }
237
238 #[tracing::instrument(skip(self))]
239 pub async fn lock(
240 &self,
241 lock_id: impl Into<compact_str::CompactString> + std::fmt::Debug,
242 ttl: Option<u64>,
243 timeout: Option<u64>,
244 ) -> Result<CacheLock, anyhow::Error> {
245 let lock_id = lock_id.into();
246 let redis_key = compact_str::format_compact!("lock::{}", lock_id);
247 let ttl_secs = ttl.unwrap_or(30);
248 let deadline = timeout.map(|ms| Instant::now() + Duration::from_secs(ms));
249
250 tracing::debug!("acquiring cache lock");
251
252 let entry = self
253 .local_locks
254 .entry(lock_id.clone())
255 .or_insert_with(async {
256 LockEntry {
257 semaphore: Arc::new(tokio::sync::Semaphore::new(1)),
258 }
259 })
260 .await
261 .into_value();
262
263 let permit = match deadline {
264 Some(dl) => {
265 let remaining = dl.saturating_duration_since(Instant::now());
266 tokio::time::timeout(remaining, entry.semaphore.acquire_owned())
267 .await
268 .map_err(|_| anyhow::anyhow!("timed out waiting for cache lock `{}`", lock_id))?
269 .map_err(|_| anyhow::anyhow!("semaphore closed for lock `{}`", lock_id))?
270 }
271 None => entry
272 .semaphore
273 .acquire_owned()
274 .await
275 .map_err(|_| anyhow::anyhow!("semaphore closed for lock `{}`", lock_id))?,
276 };
277
278 if let Some(redis_client) = &self.client {
279 match Self::try_acquire_redis_lock(redis_client, &redis_key, ttl_secs, deadline).await?
280 {
281 true => {
282 tracing::debug!("acquired redis cache lock");
283 Ok(CacheLock::new(
284 lock_id,
285 Some(redis_client.clone()),
286 permit,
287 ttl,
288 ))
289 }
290 false => anyhow::bail!("timed out acquiring redis lock `{}`", lock_id),
291 }
292 } else {
293 tracing::debug!("acquired memory cache lock");
294 Ok(CacheLock::new(lock_id, None, permit, ttl))
295 }
296 }
297
298 async fn try_acquire_redis_lock(
299 client: &Arc<Client>,
300 redis_key: &compact_str::CompactString,
301 ttl_secs: u64,
302 deadline: Option<Instant>,
303 ) -> Result<bool, anyhow::Error> {
304 loop {
305 let acquired = client
306 .set_with_options(
307 redis_key.as_str(),
308 "1",
309 SetCondition::NX,
310 SetExpiration::Ex(ttl_secs),
311 )
312 .await
313 .unwrap_or(false);
314
315 if acquired {
316 return Ok(true);
317 }
318
319 if let Some(dl) = deadline {
320 let remaining = dl.saturating_duration_since(Instant::now());
321 if remaining.is_zero() {
322 return Ok(false);
323 }
324 tokio::time::sleep(remaining.min(Duration::from_millis(50))).await;
325 } else {
326 tokio::time::sleep(Duration::from_millis(50)).await;
327 }
328 }
329 }
330
331 #[tracing::instrument(skip(self, fn_compute))]
332 pub async fn cached<
333 T: Serialize + DeserializeOwned + Send,
334 F: FnOnce() -> Fut,
335 Fut: Future<Output = Result<T, FutErr>>,
336 FutErr: Into<anyhow::Error> + Send + Sync + 'static,
337 >(
338 &self,
339 key: &str,
340 ttl: u64,
341 fn_compute: F,
342 ) -> Result<T, anyhow::Error> {
343 let effective_moka_ttl = if self.use_internal_cache {
344 Duration::from_secs(ttl)
345 } else {
346 Duration::from_millis(50)
347 };
348
349 let client_opt = self.client.clone();
350
351 self.cache_calls.fetch_add(1, Ordering::Relaxed);
352 let start_time = Instant::now();
353
354 if let Some(entry) = self.local.get(key).await {
355 tracing::debug!("found in moka cache");
356 return Ok(rmp_serde::from_slice::<T>(&entry.data)?);
357 }
358
359 let entry = self
360 .local
361 .try_get_with(key.to_compact_string(), async move {
362 if let Some(client) = &client_opt {
363 tracing::debug!("checking redis cache");
364 let cached_value: Option<BulkString> = client
365 .get(key)
366 .await
367 .map_err(|err| {
368 tracing::error!("redis get error: {:?}", err);
369 err
370 })
371 .ok()
372 .flatten();
373
374 if let Some(value) = cached_value {
375 tracing::debug!("found in redis cache");
376 return Ok(DataEntry {
377 data: Arc::new(value.to_vec()),
378 intended_ttl: effective_moka_ttl,
379 });
380 }
381 }
382
383 self.cache_misses.fetch_add(1, Ordering::Relaxed);
384
385 tracing::debug!("executing compute");
386 let result = fn_compute().await.map_err(|e| e.into())?;
387 tracing::debug!("executed compute");
388
389 let serialized = rmp_serde::to_vec(&result)?;
390 let serialized_arc = Arc::new(serialized);
391
392 if let Some(client) = &client_opt {
393 let _ = client
394 .set_with_options(
395 key,
396 BulkStringRef(&serialized_arc),
397 None,
398 SetExpiration::Ex(ttl),
399 )
400 .await;
401 }
402
403 Ok::<_, anyhow::Error>(DataEntry {
404 data: serialized_arc,
405 intended_ttl: effective_moka_ttl,
406 })
407 })
408 .await;
409
410 let elapsed_ns = start_time.elapsed().as_nanos() as u64;
411 self.cache_latency_ns_total
412 .fetch_add(elapsed_ns, Ordering::Relaxed);
413
414 let _ = self.cache_latency_ns_max.fetch_update(
415 Ordering::Relaxed,
416 Ordering::Relaxed,
417 |current_max| {
418 if elapsed_ns > current_max {
419 Some(elapsed_ns)
420 } else {
421 Some(current_max)
422 }
423 },
424 );
425
426 match entry {
427 Ok(internal_entry) => Ok(rmp_serde::from_slice::<T>(&internal_entry.data)?),
428 Err(arc_error) => Err(anyhow::anyhow!("cache computation failed: {:?}", arc_error)),
429 }
430 }
431
432 pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, anyhow::Error> {
433 if let Some(entry) = self.local.get(key).await {
434 tracing::debug!("get: found in moka cache");
435 return Ok(Some(rmp_serde::from_slice::<T>(&entry.data)?));
436 }
437
438 if let Some(client) = &self.client {
439 tracing::debug!("get: checking redis cache");
440 let cached_value: Option<BulkString> = client.get(key).await?;
441
442 if let Some(value) = cached_value {
443 tracing::debug!("get: found in redis cache");
444 let data = Arc::new(value.to_vec());
445 return Ok(Some(rmp_serde::from_slice::<T>(&data)?));
446 }
447 }
448
449 Ok(None)
450 }
451
452 pub async fn set<T: Serialize + Send + Sync>(
453 &self,
454 key: &str,
455 ttl: u64,
456 value: &T,
457 ) -> Result<(), anyhow::Error> {
458 let serialized = rmp_serde::to_vec(value)?;
459 let serialized_arc = Arc::new(serialized);
460
461 let effective_moka_ttl = if self.use_internal_cache {
462 Duration::from_secs(ttl)
463 } else {
464 Duration::from_millis(50)
465 };
466
467 self.local
468 .insert(
469 key.to_compact_string(),
470 DataEntry {
471 data: serialized_arc.clone(),
472 intended_ttl: effective_moka_ttl,
473 },
474 )
475 .await;
476
477 if let Some(client) = &self.client {
478 client
479 .set_with_options(
480 key,
481 BulkStringRef(&serialized_arc),
482 None,
483 SetExpiration::Ex(ttl),
484 )
485 .await?;
486 }
487
488 Ok(())
489 }
490
491 pub async fn invalidate(&self, key: &str) -> Result<(), anyhow::Error> {
492 self.local.invalidate(key).await;
493 if let Some(client) = &self.client {
494 client.del(key).await?;
495 }
496
497 Ok(())
498 }
499
500 #[inline]
501 pub fn cache_calls(&self) -> u64 {
502 self.cache_calls.load(Ordering::Relaxed)
503 }
504
505 #[inline]
506 pub fn cache_misses(&self) -> u64 {
507 self.cache_misses.load(Ordering::Relaxed)
508 }
509
510 #[inline]
511 pub fn cache_latency_ns_average(&self) -> u64 {
512 let calls = self.cache_calls();
513 if calls == 0 {
514 0
515 } else {
516 self.cache_latency_ns_total.load(Ordering::Relaxed) / calls
517 }
518 }
519}
520
521impl Drop for Cache {
522 fn drop(&mut self) {
523 self.local_task.abort();
524 self.local_locks_task.abort();
525 }
526}
527
528pub struct CacheLock {
529 lock_id: Option<compact_str::CompactString>,
530 redis_client: Option<Arc<Client>>,
531 permit: Option<tokio::sync::OwnedSemaphorePermit>,
532 ttl_guard: Option<tokio::task::JoinHandle<()>>,
533}
534
535impl CacheLock {
536 fn new(
537 lock_id: compact_str::CompactString,
538 redis_client: Option<Arc<Client>>,
539 permit: tokio::sync::OwnedSemaphorePermit,
540 ttl: Option<u64>,
541 ) -> Self {
542 let ttl_guard = ttl.and_then(|secs| {
543 let lock_id_clone = lock_id.clone();
544 redis_client.clone().map(|client| {
545 tokio::spawn(async move {
546 tokio::time::sleep(Duration::from_secs(secs)).await;
547 tracing::warn!(%lock_id_clone, "cache lock TTL expired; force-releasing");
548 let redis_key = compact_str::format_compact!("lock::{}", lock_id_clone);
549 let _ = client.del(&redis_key).await;
550 })
551 })
552 });
553
554 Self {
555 lock_id: Some(lock_id),
556 redis_client,
557 permit: Some(permit),
558 ttl_guard,
559 }
560 }
561
562 #[inline]
563 pub fn is_active(&self) -> bool {
564 self.lock_id.is_some() && self.ttl_guard.as_ref().is_none_or(|h| !h.is_finished())
565 }
566}
567
568impl Drop for CacheLock {
569 fn drop(&mut self) {
570 if let Some(ttl_guard) = self.ttl_guard.take() {
571 ttl_guard.abort();
572 }
573
574 self.permit.take();
575
576 if let Some(lock_id) = self.lock_id.take()
577 && let Some(client) = self.redis_client.take()
578 {
579 tokio::spawn(async move {
580 let redis_key = compact_str::format_compact!("lock::{}", lock_id);
581 let _ = client.del(&redis_key).await;
582 });
583 }
584 }
585}