Add Redis caching for AI generation rate limiting

- Add cache::ai module with Redis rate limiting for AI generations
- Add functions: check_ai_rate_limit, get_ai_usage, cache_ai_response,
  get_cached_ai_response, invalidate_ai_cache, reset_daily_usage
- Update check_and_increment_usage to use Redis fast-path before DB
- Redis key pattern: ai:rate:{user_id} for 24hr sliding window counter
This commit is contained in:
Tracewebstudio Dev 2026-05-01 03:02:46 +02:00
parent aa71ccdf36
commit 42a9a17133
3 changed files with 103 additions and 4 deletions

View file

@ -6,6 +6,7 @@ use axum::{
routing::{get, post},
Json, Router,
};
use cache::ai as ai_cache;
use contracts::auth_middleware::AuthUser;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
@ -494,10 +495,23 @@ async fn has_active_ai_pack(
async fn check_and_increment_usage(
pool: &sqlx::PgPool,
redis: &mut cache::RedisPool,
profile_id: Uuid,
is_company: bool,
daily_limit: i32,
) -> Result<(i32, i32), String> {
let user_id_str = profile_id.to_string();
// Fast path: check Redis first for rate limiting
let redis_allowed = ai_cache::check_ai_rate_limit(redis, &user_id_str, daily_limit as i64)
.await
.map_err(|e| e.to_string())?;
if !redis_allowed {
return Err("Daily AI generation limit reached".to_string());
}
// DB is source of truth - check and increment
let today = chrono::Utc::now().date_naive();
let table = if is_company { "company_ai_usage" } else { "job_seeker_ai_usage" };
let id_col = if is_company { "company_id" } else { "job_seeker_id" };
@ -586,7 +600,8 @@ async fn ai_generate_job_field(
}
};
let (used, limit) = match check_and_increment_usage(&state.pool, company_id, true, daily_limit).await {
let mut redis = state.redis.clone();
let (used, limit) = match check_and_increment_usage(&state.pool, &mut redis, company_id, true, daily_limit).await {
Ok((u, l)) => (u, l),
Err(msg) => {
return (StatusCode::TOO_MANY_REQUESTS, Json(serde_json::json!({ "error": msg }))).into_response();
@ -696,7 +711,8 @@ async fn ai_generate_cover_letter(
}
};
let (used, limit) = match check_and_increment_usage(&state.pool, seeker_id, false, daily_limit).await {
let mut redis = state.redis.clone();
let (used, limit) = match check_and_increment_usage(&state.pool, &mut redis, seeker_id, false, daily_limit).await {
Ok((u, l)) => (u, l),
Err(msg) => {
return (StatusCode::TOO_MANY_REQUESTS, Json(serde_json::json!({ "error": msg }))).into_response();
@ -804,7 +820,8 @@ async fn ai_tailor_resume(
}
};
let (used, limit) = match check_and_increment_usage(&state.pool, seeker_id, false, daily_limit).await {
let mut redis = state.redis.clone();
let (used, limit) = match check_and_increment_usage(&state.pool, &mut redis, seeker_id, false, daily_limit).await {
Ok((u, l)) => (u, l),
Err(msg) => {
return (StatusCode::TOO_MANY_REQUESTS, Json(serde_json::json!({ "error": msg }))).into_response();
@ -938,6 +955,7 @@ async fn ai_auto_apply(
let mut created = 0;
let mut already = vec![];
let mut failed = vec![];
let mut redis = state.redis.clone();
for job_id in &body.job_ids {
let existing: Option<Uuid> = sqlx::query_scalar(
@ -1001,7 +1019,7 @@ async fn ai_auto_apply(
Ok(r) => {
if r.rows_affected() > 0 {
created += 1;
let _ = check_and_increment_usage(&state.pool, seeker_id, false, daily_limit).await;
let _ = check_and_increment_usage(&state.pool, &mut redis, seeker_id, false, daily_limit).await;
} else {
already.push(*job_id);
}

80
crates/cache/src/ai.rs vendored Normal file
View file

@ -0,0 +1,80 @@
//! Redis caching for AI generation rate limiting and response caching.
//!
//! Key patterns:
//! - `ai:rate:{user_id}` - sliding window counter for rate limiting
//! - `ai:resp:{hash}` - cached AI response (by prompt hash)
use redis::AsyncCommands;
use crate::RedisPool;
const AI_RATE_WINDOW_SECS: i64 = 86_400; // 24 hours
const AI_CACHE_TTL_SECS: i64 = 3_600; // 1 hour
/// Check + increment AI generation rate limit counter.
/// Uses a simple counter with TTL reset on first write.
///
/// Returns `Ok(true)` if allowed, `Ok(false)` if rate limited.
pub async fn check_ai_rate_limit(
redis: &mut RedisPool,
user_id: &str,
max_generations: i64,
) -> Result<bool, redis::RedisError> {
let key = format!("ai:rate:{}", user_id);
let count: i64 = redis.incr(&key, 1i64).await?;
if count == 1 {
redis.expire::<_, ()>(&key, AI_RATE_WINDOW_SECS).await?;
}
Ok(count <= max_generations)
}
/// Get current AI generation count for a user.
pub async fn get_ai_usage(
redis: &mut RedisPool,
user_id: &str,
) -> Result<i64, redis::RedisError> {
let key = format!("ai:rate:{}", user_id);
let count: Option<i64> = redis.get(&key).await?;
Ok(count.unwrap_or(0))
}
/// Store AI-generated response in cache.
pub async fn cache_ai_response(
redis: &mut RedisPool,
prompt_hash: &str,
response: &str,
) -> Result<(), redis::RedisError> {
let key = format!("ai:resp:{}", prompt_hash);
let ttl: u64 = AI_CACHE_TTL_SECS.try_into().unwrap();
let _: () = redis.set_ex(&key, response, ttl).await?;
Ok(())
}
/// Get cached AI response if available.
pub async fn get_cached_ai_response(
redis: &mut RedisPool,
prompt_hash: &str,
) -> Result<Option<String>, redis::RedisError> {
let key = format!("ai:resp:{}", prompt_hash);
let result: Option<String> = redis.get(&key).await?;
Ok(result)
}
/// Invalidate cached AI response.
pub async fn invalidate_ai_cache(
redis: &mut RedisPool,
prompt_hash: &str,
) -> Result<(), redis::RedisError> {
let key = format!("ai:resp:{}", prompt_hash);
let _: () = redis.del(&key).await?;
Ok(())
}
/// Reset daily AI usage counter (called at start of new day or when daily limit changes).
pub async fn reset_daily_usage(
redis: &mut RedisPool,
user_id: &str,
) -> Result<(), redis::RedisError> {
let key = format!("ai:rate:{}", user_id);
let _: () = redis.del(&key).await?;
Ok(())
}

View file

@ -4,5 +4,6 @@ pub mod rate_limit;
pub mod token;
pub mod lead;
pub mod jobs;
pub mod ai;
pub use client::{RedisPool, connect};