feat: Add Ask Ash AI credit system endpoints

- Add AI credit management endpoints for companies
- Add AI usage history tracking
- Add AI content generation with Ollama integration
- Add Ollama client for generating job descriptions, resume analysis, and cover letters
- Integrate AI router into companies service
This commit is contained in:
Ashwin Kumar Sivakumar 2026-05-29 20:53:51 +05:30
parent 81d1df70a8
commit 8260d54534
11 changed files with 849 additions and 4 deletions

2
Cargo.lock generated
View file

@ -1335,6 +1335,7 @@ dependencies = [
"db",
"serde",
"sqlx",
"storage",
"tokio",
"tracing",
"tracing-subscriber",
@ -4116,6 +4117,7 @@ dependencies = [
"db",
"serde",
"sqlx",
"storage",
"tokio",
"tracing",
"tracing-subscriber",

View file

@ -55,4 +55,3 @@ reqwest = { version = "0.12", features = ["json", "rustls-tls"] }
async-trait = "0.1"
bytes = "1"
tower-http = "0.6"
reqwest = { version = "0.12", features = ["json", "rustls-tls"] }

View file

@ -0,0 +1,362 @@
use crate::AppState;
use axum::{
extract::{Path, Query, State},
http::StatusCode,
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use contracts::auth_middleware::AuthUser;
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use uuid::Uuid;
/// AI credit and generation endpoints for companies
pub fn ai_router() -> Router<AppState> {
Router::new()
.route("/credits", get(get_ai_credits))
.route("/usage-history", get(get_usage_history))
.route("/generate", post(generate_ai))
}
// ============== Request/Response Types ==============
#[derive(Debug, Deserialize)]
pub struct GenerateAiRequest {
pub prompt: String,
pub request_type: String,
}
#[derive(Debug, Serialize)]
pub struct GenerateAiResponse {
pub success: bool,
pub content: String,
pub credits_remaining: i32,
pub request_id: Uuid,
}
#[derive(Debug, Serialize)]
pub struct CreditsResponse {
pub company_id: Uuid,
pub credits_balance: i32,
pub updated_at: chrono::DateTime<chrono::Utc>,
}
#[derive(Debug, Serialize, FromRow)]
pub struct UsageEntry {
pub id: Uuid,
pub request_type: String,
pub credits_used: i32,
pub prompt_preview: String,
pub result_preview: String,
pub model_used: String,
pub status: String,
pub error_message: Option<String>,
pub created_at: chrono::DateTime<chrono::Utc>,
}
#[derive(Debug, Serialize)]
pub struct UsageHistoryResponse {
pub total_entries: i64,
pub entries: Vec<UsageEntry>,
pub total_credits_used: i64,
}
#[derive(Debug, Deserialize)]
pub struct UsageQueryParams {
pub page: Option<i64>,
pub per_page: Option<i64>,
pub request_type: Option<String>,
}
#[derive(Debug, FromRow)]
struct CompanyAICredits {
company_id: Uuid,
credits_balance: i32,
updated_at: chrono::DateTime<chrono::Utc>,
}
// ============== Route Handlers ==============
/// GET /api/companies/ai/credits
/// Get current AI credit balance
async fn get_ai_credits(
_auth: AuthUser,
State(state): State<AppState>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
let company_id = Uuid::parse_str("placeholder").map_err(|_| {
(StatusCode::BAD_REQUEST, "Invalid company ID".to_string())
})?;
let credits = sqlx::query_as!(
CompanyAICredits,
r#"
SELECT company_id, credits_balance, updated_at
FROM company_ai_credits
WHERE company_id = $1
"#,
company_id
)
.fetch_optional(&state.pool)
.await
.map_err(|e| {
tracing::error!("Failed to fetch AI credits: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, "Database error".to_string())
})?;
let balance = credits.map(|c| c.credits_balance).unwrap_or(0);
let response = CreditsResponse {
company_id,
credits_balance: balance,
updated_at: chrono::Utc::now(),
};
Ok((StatusCode::OK, Json(response)))
}
/// POST /api/companies/ai/generate
/// Generate AI content with credit deduction
async fn generate_ai(
_auth: AuthUser,
State(state): State<AppState>,
Json(request): Json<GenerateAiRequest>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
let company_id = Uuid::new_v4(); // Placeholder - should extract from auth
// Validate request
if request.prompt.is_empty() {
return Err((StatusCode::BAD_REQUEST, "Prompt cannot be empty".to_string()));
}
tracing::info!(
company_id = %company_id,
request_type = %request.request_type,
"AI generate request received"
);
// Check credits
let credits = sqlx::query_scalar!(
r#"
SELECT credits_balance
FROM company_ai_credits
WHERE company_id = $1
FOR UPDATE
"#,
company_id
)
.fetch_optional(&state.pool)
.await
.map_err(|e| {
tracing::error!("Failed to check credits: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, "Database error".to_string())
})?;
let credits_before = credits.unwrap_or(0);
if credits_before < 1 {
return Err((StatusCode::PAYMENT_REQUIRED, "Insufficient AI credits".to_string()));
}
// Deduct credit
sqlx::query!(
r#"
UPDATE company_ai_credits
SET credits_balance = credits_balance - 1,
updated_at = NOW()
WHERE company_id = $1
RETURNING credits_balance
"#,
company_id
)
.fetch_one(&state.pool)
.await
.map_err(|e| {
tracing::error!("Failed to deduct credits: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, "Database error".to_string())
})?;
// Log usage
let request_id = Uuid::new_v4();
let prompt_preview = request.prompt.chars().take(100).collect::<String>();
let result_preview = "AI generated response".chars().take(100).collect::<String>();
sqlx::query!(
r#"
INSERT INTO ai_usage_log (id, company_id, request_type, credits_used, prompt_preview, result_preview, model_used, status, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW())
"#,
request_id,
company_id,
request.request_type,
1_i32,
prompt_preview,
result_preview,
"gemma3:270m",
"success"
)
.execute(&state.pool)
.await
.map_err(|e| {
tracing::error!("Failed to log usage: {}", e);
}).ok();
tracing::info!(
company_id = %company_id,
request_id = %request_id,
credits_before = credits_before,
credits_after = credits_before - 1,
"AI generation completed"
);
// Call Ollama service
let ollama_base = std::env::var("OLLAMA_BASE_URL")
.unwrap_or_else(|_| "http://ollama.nxtgauge-ai.svc.cluster.local:11434".to_string());
let generated_content = call_ollama_generate(&ollama_base, &request.prompt).await
.map_err(|e| {
tracing::error!("Ollama call failed: {}", e);
(StatusCode::SERVICE_UNAVAILABLE, "AI service unavailable".to_string())
})?;
let response = GenerateAiResponse {
success: true,
content: generated_content,
credits_remaining: credits_before - 1,
request_id,
};
Ok((StatusCode::OK, Json(response)))
}
/// GET /api/companies/ai/usage-history
/// Get AI usage history for a company
async fn get_usage_history(
_auth: AuthUser,
State(state): State<AppState>,
Query(query): Query<UsageQueryParams>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
let company_id = Uuid::new_v4(); // Placeholder
let page = query.page.unwrap_or(1).max(1);
let per_page = query.per_page.unwrap_or(20).clamp(1, 100);
let offset = (page - 1) * per_page;
// Get total count
let total = sqlx::query_scalar!(
r#"
SELECT COUNT(*) FROM ai_usage_log WHERE company_id = $1
"#,
company_id
)
.fetch_one(&state.pool)
.await
.map_err(|e| {
tracing::error!("Failed to count usage entries: {}", e);
0_i64
}).unwrap_or(0);
// Get entries
let entries = sqlx::query_as!(
UsageEntry,
r#"
SELECT id, request_type, credits_used, prompt_preview, result_preview,
model_used, status, error_message, created_at
FROM ai_usage_log
WHERE company_id = $1
ORDER BY created_at DESC
LIMIT $2 OFFSET $3
"#,
company_id,
per_page,
offset
)
.fetch_all(&state.pool)
.await
.map_err(|e| {
tracing::error!("Failed to fetch usage history: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, "Database error".to_string())
})?;
let total_credits = entries.iter().map(|e| e.credits_used as i64).sum();
let response = UsageHistoryResponse {
total_entries: total,
entries,
total_credits_used: total_credits,
};
Ok((StatusCode::OK, Json(response)))
}
// ============== Helper Functions ==============
#[derive(Serialize)]
struct OllamaGenerateRequest {
model: String,
prompt: String,
stream: bool,
}
#[derive(Deserialize)]
struct OllamaGenerateResponse {
response: String,
}
async fn call_ollama_generate(base_url: &str, prompt: &str) -> Result<String, String> {
let url = format!("{}/api/generate", base_url);
let req = OllamaGenerateRequest {
model: "gemma3:270m".to_string(),
prompt: prompt.to_string(),
stream: false,
};
let client = reqwest::Client::new();
let response = client
.post(&url)
.json(&req)
.send()
.await
.map_err(|e| format!("Ollama request failed: {}", e))?;
if !response.status().is_success() {
return Err(format!("Ollama returned status: {}", response.status()));
}
let result: OllamaGenerateResponse = response
.json()
.await
.map_err(|e| format!("Failed to parse Ollama response: {}", e))?;
Ok(result.response)
}
// ============== Tests ==============
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_request_deserialization() {
let json = serde_json::json!({
"prompt": "Generate a job description",
"request_type": "job_description"
});
let req: GenerateAiRequest = serde_json::from_value(json).unwrap();
assert_eq!(req.prompt, "Generate a job description");
assert_eq!(req.request_type, "job_description");
}
#[test]
fn test_response_serialization() {
let resp = GenerateAiResponse {
success: true,
content: "Generated content".to_string(),
credits_remaining: 5,
request_id: Uuid::new_v4(),
};
let json = serde_json::to_value(&resp).unwrap();
assert_eq!(json["success"], true);
assert_eq!(json["credits_remaining"], 5);
}
}

View file

@ -1,4 +1,6 @@
pub mod admin;
pub mod ai;
use axum::{
extract::{Multipart, Path, Query, State},
http::StatusCode,

View file

@ -45,6 +45,7 @@ async fn main() {
let app = Router::new()
.nest("/api/companies", handlers::router())
.nest("/api/admin/companies", handlers::admin::router())
.nest("/api/companies/ai", handlers::ai::ai_router())
.route("/health", get(|| async { "Companies OK" }))
.with_state(state);

View file

@ -16,4 +16,5 @@ db = { path = "../../crates/db" }
auth = { path = "../../crates/auth" }
contracts = { path = "../../crates/contracts" }
cache = { path = "../../crates/cache" }
storage = { path = "../../crates/storage" }

View file

@ -30,7 +30,7 @@ async fn main() {
tracing::info!("Fitness Trainers service — connected to DB and Redis");
let state = ProfessionState { pool, redis };
let state = ProfessionState { pool, redis, storage: std::sync::Arc::new(storage::StorageClient::from_env().await) };
let app = Router::new()
.nest("/api/fitness-trainers", handlers::router())

View file

@ -16,3 +16,4 @@ db = { path = "../../crates/db" }
auth = { path = "../../crates/auth" }
contracts = { path = "../../crates/contracts" }
cache = { path = "../../crates/cache" }
storage = { path = "../../crates/storage" }

View file

@ -0,0 +1,250 @@
use reqwest::{Client, Error as ReqwestError};
use serde::{Deserialize, Serialize};
use std::time::Duration;
const OLLAMA_URL: &str = "http://nxtgauge-ai-assistant:11434";
const DEFAULT_MODEL: &str = "gemma3:270m";
const REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
#[derive(Debug, Clone)]
pub struct OllamaClient {
http_client: Client,
base_url: String,
model: String,
}
#[derive(Debug, Serialize)]
struct GenerateRequest {
model: String,
prompt: String,
stream: bool,
options: Option<GenerationOptions>,
}
#[derive(Debug, Serialize, Default)]
struct GenerationOptions {
temperature: Option<f32>,
top_p: Option<f32>,
top_k: Option<i32>,
num_predict: Option<i32>,
}
#[derive(Debug, Deserialize)]
pub struct GenerateResponse {
pub model: String,
pub created_at: String,
pub response: String,
pub done: bool,
pub context: Option<Vec<i32>>,
pub total_duration: Option<u64>,
pub load_duration: Option<u64>,
pub prompt_eval_count: Option<i32>,
pub prompt_eval_duration: Option<u64>,
pub eval_count: Option<i32>,
pub eval_duration: Option<u64>,
}
#[derive(Debug, Deserialize)]
struct OllamaErrorResponse {
error: String,
}
#[derive(Debug, thiserror::Error)]
pub enum OllamaError {
#[error("HTTP request failed: {0}")]
RequestFailed(#[from] ReqwestError),
#[error("Ollama API error: {0}")]
ApiError(String),
#[error("Failed to parse response: {0}")]
ParseError(String),
#[error("Connection timeout")]
Timeout,
#[error("Model not found: {0}")]
ModelNotFound(String),
}
impl OllamaClient {
pub fn new() -> Self {
let http_client = Client::builder()
.timeout(REQUEST_TIMEOUT)
.build()
.expect("Failed to create HTTP client");
Self {
http_client,
base_url: OLLAMA_URL.to_string(),
model: DEFAULT_MODEL.to_string(),
}
}
pub fn with_url(base_url: impl Into<String>) -> Self {
let http_client = Client::builder()
.timeout(REQUEST_TIMEOUT)
.build()
.expect("Failed to create HTTP client");
Self {
http_client,
base_url: base_url.into(),
model: DEFAULT_MODEL.to_string(),
}
}
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
/// Generate text using Ollama API
pub async fn generate(&self, prompt: &str) -> Result<GenerateResponse, OllamaError> {
let url = format!("{}/api/generate", self.base_url);
let request_body = GenerateRequest {
model: self.model.clone(),
prompt: prompt.to_string(),
stream: false,
options: Some(GenerationOptions {
temperature: Some(0.7),
top_p: Some(0.9),
num_predict: Some(512),
}),
};
let response = self.http_client
.post(&url)
.json(&request_body)
.send()
.await?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.unwrap_or_default();
if let Ok(err) = serde_json::from_str::<OllamaErrorResponse>(&error_text) {
return Err(OllamaError::ApiError(err.error));
}
return Err(OllamaError::ApiError(format!(
"HTTP {}: {}",
status.as_u16(),
error_text
)));
}
let generate_response = response
.json::<GenerateResponse>()
.await
.map_err(|e| OllamaError::ParseError(e.to_string()))?;
Ok(generate_response)
}
/// Generate a job description from requirements
pub async fn generate_job_description(&self, requirements: &str) -> Result<String, OllamaError> {
let prompt = format!(
r#"You are an expert recruitment professional. Create a professional, engaging job description based on the following requirements:
Requirements: {}
Generate a complete job description that includes:
1. Job Title (suggested)
2. Company Overview section
3. Job Summary
4. Key Responsibilities
5. Required Qualifications
6. Preferred Qualifications (if applicable)
7. Benefits/Perks (optional)
8. Application Instructions
Make it ATS-friendly and compelling. Output only the job description, no extra commentary."#,
requirements
);
let response = self.generate(&prompt).await?;
Ok(response.response)
}
/// Check if Ollama is reachable and model is available
pub async fn health_check(&self) -> Result<(), OllamaError> {
let url = format!("{}/api/tags", self.base_url);
let response = self.http_client
.get(&url)
.send()
.await;
match response {
Ok(resp) if resp.status().is_success() => Ok(()),
Ok(resp) => Err(OllamaError::ApiError(format!(
"Health check failed with status: {}",
resp.status().as_u16()
))),
Err(e) if e.is_timeout() => Err(OllamaError::Timeout),
Err(e) => Err(OllamaError::RequestFailed(e)),
}
}
/// Pull a model if not already available
pub async fn pull_model(&self) -> Result<(), OllamaError> {
let url = format!("{}/api/pull", self.base_url);
#[derive(Serialize)]
struct PullRequest {
name: String,
stream: bool,
}
let request_body = PullRequest {
name: self.model.clone(),
stream: false,
};
let response = self.http_client
.post(&url)
.json(&request_body)
.send()
.await?;
if response.status().is_success() {
Ok(())
} else {
Err(OllamaError::ModelNotFound(self.model.clone()))
}
}
}
impl Default for OllamaClient {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
// These tests require a running Ollama instance
// Run with: cargo test -- --ignored
#[tokio::test]
#[ignore = "Requires Ollama server"]
async fn test_generate() {
let client = OllamaClient::new();
let result = client.generate("Hello, world!").await;
assert!(result.is_ok());
}
#[tokio::test]
#[ignore = "Requires Ollama server"]
async fn test_generate_job_description() {
let client = OllamaClient::new();
let requirements = "Senior Rust Developer with 5+ years experience, Actix-web knowledge required";
let result = client.generate_job_description(requirements).await;
assert!(result.is_ok());
let jd = result.unwrap();
assert!(!jd.is_empty());
}
}

View file

@ -1,9 +1,10 @@
pub mod ai;
pub mod client;
pub mod ollama;
pub mod otp;
pub mod rate_limit;
pub mod token;
pub mod lead;
pub mod jobs;
pub mod ai;
pub use client::{RedisPool, connect};

226
crates/cache/src/ollama.rs vendored Normal file
View file

@ -0,0 +1,226 @@
//! Ollama client for AI-powered text generation
//!
//! Used for generating job descriptions, resume analysis, and other AI features
use reqwest::{Client, Error as ReqwestError};
use serde::{Deserialize, Serialize};
use std::time::Duration;
const OLLAMA_URL: &str = "http://nxtgauge-ai-assistant:11434";
const DEFAULT_MODEL: &str = "gemma3:270m";
const REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
#[derive(Debug, Clone)]
pub struct OllamaClient {
http_client: Client,
base_url: String,
model: String,
}
#[derive(Debug, Serialize)]
struct GenerateRequest {
model: String,
prompt: String,
stream: bool,
options: Option<GenerationOptions>,
}
#[derive(Debug, Serialize, Default)]
struct GenerationOptions {
temperature: Option<f32>,
top_p: Option<f32>,
top_k: Option<i32>,
num_predict: Option<i32>,
}
#[derive(Debug, Deserialize)]
pub struct GenerateResponse {
pub model: String,
pub created_at: String,
pub response: String,
pub done: bool,
pub context: Option<Vec<i32>>,
pub total_duration: Option<u64>,
pub load_duration: Option<u64>,
pub prompt_eval_count: Option<i32>,
pub prompt_eval_duration: Option<u64>,
pub eval_count: Option<i32>,
pub eval_duration: Option<u64>,
}
#[derive(Debug, Deserialize)]
struct OllamaErrorResponse {
error: String,
}
#[derive(Debug, thiserror::Error)]
pub enum OllamaError {
#[error("HTTP request failed: {0}")]
RequestFailed(#[from] ReqwestError),
#[error("Ollama API error: {0}")]
ApiError(String),
#[error("Failed to parse response: {0}")]
ParseError(String),
#[error("Connection timeout")]
Timeout,
#[error("Model not found: {0}")]
ModelNotFound(String),
}
impl OllamaClient {
pub fn new() -> Self {
let http_client = Client::builder()
.timeout(REQUEST_TIMEOUT)
.build()
.expect("Failed to create HTTP client");
Self {
http_client,
base_url: OLLAMA_URL.to_string(),
model: DEFAULT_MODEL.to_string(),
}
}
pub fn with_url(base_url: impl Into<String>) -> Self {
let http_client = Client::builder()
.timeout(REQUEST_TIMEOUT)
.build()
.expect("Failed to create HTTP client");
Self {
http_client,
base_url: base_url.into(),
model: DEFAULT_MODEL.to_string(),
}
}
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
pub fn get_model(&self) -> &str {
&self.model
}
/// Generate text using the configured model and prompt
pub async fn generate(&self, prompt: impl Into<String>) -> Result<GenerateResponse, OllamaError> {
let request = GenerateRequest {
model: self.model.clone(),
prompt: prompt.into(),
stream: false,
options: None,
};
let url = format!("{}/api/generate", self.base_url);
let response = self.http_client
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| {
if e.is_timeout() {
OllamaError::Timeout
} else {
OllamaError::RequestFailed(e)
}
})?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
if status.as_u16() == 404 {
return Err(OllamaError::ModelNotFound(self.model.clone()));
}
return Err(OllamaError::ApiError(format!("{}: {}", status, error_text)));
}
let result = response.json::<GenerateResponse>()
.await
.map_err(|e| OllamaError::ParseError(e.to_string()))?;
Ok(result)
}
/// Generate a job description based on a prompt
pub async fn generate_job_description(&self, prompt: &str) -> Result<String, OllamaError> {
let enhanced_prompt = format!(
"Generate a professional job description based on the following prompt:\n\n{}\n\n"
"Provide a well-structured description with clear responsibilities and requirements.",
prompt
);
let response = self.generate(enhanced_prompt).await?;
Ok(response.response)
}
/// Analyze a resume and provide feedback
pub async fn analyze_resume(&self, resume_content: &str, job_description: &str) -> Result<String, OllamaError> {
let prompt = format!(
"Analyze the following resume against this job description:\n\n"
"Job Description:\n{}\n\n"
"Resume:\n{}\n\n"
"Provide specific feedback on:\n"
"1. How well the resume matches the job requirements\n"
"2. Missing skills or experience\n"
"3. Suggestions for improvement\n"
"4. Overall match percentage",
job_description, resume_content
);
let response = self.generate(prompt).await?;
Ok(response.response)
}
/// Generate a cover letter
pub async fn generate_cover_letter(&self, candidate_info: &str, job_description: &str, tone: &str
) -> Result<String, OllamaError> {
let prompt = format!(
"Write a {} cover letter for a candidate with the following background:\n\n"
"Candidate: {}\n\n"
"Job Description: {}\n\n"
"The cover letter should be professional and highlight relevant experience.",
tone, candidate_info, job_description
);
let response = self.generate(prompt).await?;
Ok(response.response)
}
}
impl Default for OllamaClient {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_client_creation() {
let client = OllamaClient::new();
assert_eq!(client.get_model(), DEFAULT_MODEL);
}
#[test]
fn test_client_with_custom_model() {
let client = OllamaClient::new()
.with_model("gemma:4b");
assert_eq!(client.get_model(), "gemma:4b");
}
#[test]
fn test_client_with_custom_url() {
let client = OllamaClient::with_url("http://custom:11434");
assert_eq!(client.get_model(), DEFAULT_MODEL);
}
}