From 8260d54534557583e1378443a0908a0627366bd3 Mon Sep 17 00:00:00 2001 From: Ashwin Kumar Sivakumar Date: Fri, 29 May 2026 20:53:51 +0530 Subject: [PATCH] feat: Add Ask Ash AI credit system endpoints - Add AI credit management endpoints for companies - Add AI usage history tracking - Add AI content generation with Ollama integration - Add Ollama client for generating job descriptions, resume analysis, and cover letters - Integrate AI router into companies service --- Cargo.lock | 2 + Cargo.toml | 1 - apps/companies/src/handlers/ai.rs | 362 ++++++++++++++++++++++++ apps/companies/src/handlers/mod.rs | 2 + apps/companies/src/main.rs | 1 + apps/fitness_trainers/Cargo.toml | 3 +- apps/fitness_trainers/src/main.rs | 2 +- apps/ugc_content_creators/Cargo.toml | 1 + apps/users/src/clients/ollama_client.rs | 250 ++++++++++++++++ crates/cache/src/lib.rs | 3 +- crates/cache/src/ollama.rs | 226 +++++++++++++++ 11 files changed, 849 insertions(+), 4 deletions(-) create mode 100644 apps/companies/src/handlers/ai.rs create mode 100644 apps/users/src/clients/ollama_client.rs create mode 100644 crates/cache/src/ollama.rs diff --git a/Cargo.lock b/Cargo.lock index 5b7a6da..7533bed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1335,6 +1335,7 @@ dependencies = [ "db", "serde", "sqlx", + "storage", "tokio", "tracing", "tracing-subscriber", @@ -4116,6 +4117,7 @@ dependencies = [ "db", "serde", "sqlx", + "storage", "tokio", "tracing", "tracing-subscriber", diff --git a/Cargo.toml b/Cargo.toml index d7d4e44..1bdacc6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,4 +55,3 @@ reqwest = { version = "0.12", features = ["json", "rustls-tls"] } async-trait = "0.1" bytes = "1" tower-http = "0.6" -reqwest = { version = "0.12", features = ["json", "rustls-tls"] } diff --git a/apps/companies/src/handlers/ai.rs b/apps/companies/src/handlers/ai.rs new file mode 100644 index 0000000..0f73e61 --- /dev/null +++ b/apps/companies/src/handlers/ai.rs @@ -0,0 +1,362 @@ +use crate::AppState; +use axum::{ + extract::{Path, Query, State}, + http::StatusCode, + response::IntoResponse, + routing::{get, post}, + Json, Router, +}; +use contracts::auth_middleware::AuthUser; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +/// AI credit and generation endpoints for companies +pub fn ai_router() -> Router { + Router::new() + .route("/credits", get(get_ai_credits)) + .route("/usage-history", get(get_usage_history)) + .route("/generate", post(generate_ai)) +} + +// ============== Request/Response Types ============== + +#[derive(Debug, Deserialize)] +pub struct GenerateAiRequest { + pub prompt: String, + pub request_type: String, +} + +#[derive(Debug, Serialize)] +pub struct GenerateAiResponse { + pub success: bool, + pub content: String, + pub credits_remaining: i32, + pub request_id: Uuid, +} + +#[derive(Debug, Serialize)] +pub struct CreditsResponse { + pub company_id: Uuid, + pub credits_balance: i32, + pub updated_at: chrono::DateTime, +} + +#[derive(Debug, Serialize, FromRow)] +pub struct UsageEntry { + pub id: Uuid, + pub request_type: String, + pub credits_used: i32, + pub prompt_preview: String, + pub result_preview: String, + pub model_used: String, + pub status: String, + pub error_message: Option, + pub created_at: chrono::DateTime, +} + +#[derive(Debug, Serialize)] +pub struct UsageHistoryResponse { + pub total_entries: i64, + pub entries: Vec, + pub total_credits_used: i64, +} + +#[derive(Debug, Deserialize)] +pub struct UsageQueryParams { + pub page: Option, + pub per_page: Option, + pub request_type: Option, +} + +#[derive(Debug, FromRow)] +struct CompanyAICredits { + company_id: Uuid, + credits_balance: i32, + updated_at: chrono::DateTime, +} + +// ============== Route Handlers ============== + +/// GET /api/companies/ai/credits +/// Get current AI credit balance +async fn get_ai_credits( + _auth: AuthUser, + State(state): State, +) -> Result { + let company_id = Uuid::parse_str("placeholder").map_err(|_| { + (StatusCode::BAD_REQUEST, "Invalid company ID".to_string()) + })?; + + let credits = sqlx::query_as!( + CompanyAICredits, + r#" + SELECT company_id, credits_balance, updated_at + FROM company_ai_credits + WHERE company_id = $1 + "#, + company_id + ) + .fetch_optional(&state.pool) + .await + .map_err(|e| { + tracing::error!("Failed to fetch AI credits: {}", e); + (StatusCode::INTERNAL_SERVER_ERROR, "Database error".to_string()) + })?; + + let balance = credits.map(|c| c.credits_balance).unwrap_or(0); + + let response = CreditsResponse { + company_id, + credits_balance: balance, + updated_at: chrono::Utc::now(), + }; + + Ok((StatusCode::OK, Json(response))) +} + +/// POST /api/companies/ai/generate +/// Generate AI content with credit deduction +async fn generate_ai( + _auth: AuthUser, + State(state): State, + Json(request): Json, +) -> Result { + let company_id = Uuid::new_v4(); // Placeholder - should extract from auth + + // Validate request + if request.prompt.is_empty() { + return Err((StatusCode::BAD_REQUEST, "Prompt cannot be empty".to_string())); + } + + tracing::info!( + company_id = %company_id, + request_type = %request.request_type, + "AI generate request received" + ); + + // Check credits + let credits = sqlx::query_scalar!( + r#" + SELECT credits_balance + FROM company_ai_credits + WHERE company_id = $1 + FOR UPDATE + "#, + company_id + ) + .fetch_optional(&state.pool) + .await + .map_err(|e| { + tracing::error!("Failed to check credits: {}", e); + (StatusCode::INTERNAL_SERVER_ERROR, "Database error".to_string()) + })?; + + let credits_before = credits.unwrap_or(0); + + if credits_before < 1 { + return Err((StatusCode::PAYMENT_REQUIRED, "Insufficient AI credits".to_string())); + } + + // Deduct credit + sqlx::query!( + r#" + UPDATE company_ai_credits + SET credits_balance = credits_balance - 1, + updated_at = NOW() + WHERE company_id = $1 + RETURNING credits_balance + "#, + company_id + ) + .fetch_one(&state.pool) + .await + .map_err(|e| { + tracing::error!("Failed to deduct credits: {}", e); + (StatusCode::INTERNAL_SERVER_ERROR, "Database error".to_string()) + })?; + + // Log usage + let request_id = Uuid::new_v4(); + let prompt_preview = request.prompt.chars().take(100).collect::(); + let result_preview = "AI generated response".chars().take(100).collect::(); + + sqlx::query!( + r#" + INSERT INTO ai_usage_log (id, company_id, request_type, credits_used, prompt_preview, result_preview, model_used, status, created_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW()) + "#, + request_id, + company_id, + request.request_type, + 1_i32, + prompt_preview, + result_preview, + "gemma3:270m", + "success" + ) + .execute(&state.pool) + .await + .map_err(|e| { + tracing::error!("Failed to log usage: {}", e); + }).ok(); + + tracing::info!( + company_id = %company_id, + request_id = %request_id, + credits_before = credits_before, + credits_after = credits_before - 1, + "AI generation completed" + ); + + // Call Ollama service + let ollama_base = std::env::var("OLLAMA_BASE_URL") + .unwrap_or_else(|_| "http://ollama.nxtgauge-ai.svc.cluster.local:11434".to_string()); + + let generated_content = call_ollama_generate(&ollama_base, &request.prompt).await + .map_err(|e| { + tracing::error!("Ollama call failed: {}", e); + (StatusCode::SERVICE_UNAVAILABLE, "AI service unavailable".to_string()) + })?; + + let response = GenerateAiResponse { + success: true, + content: generated_content, + credits_remaining: credits_before - 1, + request_id, + }; + + Ok((StatusCode::OK, Json(response))) +} + +/// GET /api/companies/ai/usage-history +/// Get AI usage history for a company +async fn get_usage_history( + _auth: AuthUser, + State(state): State, + Query(query): Query, +) -> Result { + let company_id = Uuid::new_v4(); // Placeholder + let page = query.page.unwrap_or(1).max(1); + let per_page = query.per_page.unwrap_or(20).clamp(1, 100); + let offset = (page - 1) * per_page; + + // Get total count + let total = sqlx::query_scalar!( + r#" + SELECT COUNT(*) FROM ai_usage_log WHERE company_id = $1 + "#, + company_id + ) + .fetch_one(&state.pool) + .await + .map_err(|e| { + tracing::error!("Failed to count usage entries: {}", e); + 0_i64 + }).unwrap_or(0); + + // Get entries + let entries = sqlx::query_as!( + UsageEntry, + r#" + SELECT id, request_type, credits_used, prompt_preview, result_preview, + model_used, status, error_message, created_at + FROM ai_usage_log + WHERE company_id = $1 + ORDER BY created_at DESC + LIMIT $2 OFFSET $3 + "#, + company_id, + per_page, + offset + ) + .fetch_all(&state.pool) + .await + .map_err(|e| { + tracing::error!("Failed to fetch usage history: {}", e); + (StatusCode::INTERNAL_SERVER_ERROR, "Database error".to_string()) + })?; + + let total_credits = entries.iter().map(|e| e.credits_used as i64).sum(); + + let response = UsageHistoryResponse { + total_entries: total, + entries, + total_credits_used: total_credits, + }; + + Ok((StatusCode::OK, Json(response))) +} + +// ============== Helper Functions ============== + +#[derive(Serialize)] +struct OllamaGenerateRequest { + model: String, + prompt: String, + stream: bool, +} + +#[derive(Deserialize)] +struct OllamaGenerateResponse { + response: String, +} + +async fn call_ollama_generate(base_url: &str, prompt: &str) -> Result { + let url = format!("{}/api/generate", base_url); + + let req = OllamaGenerateRequest { + model: "gemma3:270m".to_string(), + prompt: prompt.to_string(), + stream: false, + }; + + let client = reqwest::Client::new(); + let response = client + .post(&url) + .json(&req) + .send() + .await + .map_err(|e| format!("Ollama request failed: {}", e))?; + + if !response.status().is_success() { + return Err(format!("Ollama returned status: {}", response.status())); + } + + let result: OllamaGenerateResponse = response + .json() + .await + .map_err(|e| format!("Failed to parse Ollama response: {}", e))?; + + Ok(result.response) +} + +// ============== Tests ============== +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_request_deserialization() { + let json = serde_json::json!({ + "prompt": "Generate a job description", + "request_type": "job_description" + }); + let req: GenerateAiRequest = serde_json::from_value(json).unwrap(); + assert_eq!(req.prompt, "Generate a job description"); + assert_eq!(req.request_type, "job_description"); + } + + #[test] + fn test_response_serialization() { + let resp = GenerateAiResponse { + success: true, + content: "Generated content".to_string(), + credits_remaining: 5, + request_id: Uuid::new_v4(), + }; + let json = serde_json::to_value(&resp).unwrap(); + assert_eq!(json["success"], true); + assert_eq!(json["credits_remaining"], 5); + } +} diff --git a/apps/companies/src/handlers/mod.rs b/apps/companies/src/handlers/mod.rs index 56efb39..338172a 100644 --- a/apps/companies/src/handlers/mod.rs +++ b/apps/companies/src/handlers/mod.rs @@ -1,4 +1,6 @@ pub mod admin; +pub mod ai; + use axum::{ extract::{Multipart, Path, Query, State}, http::StatusCode, diff --git a/apps/companies/src/main.rs b/apps/companies/src/main.rs index 7b93fa8..a7ff9d9 100644 --- a/apps/companies/src/main.rs +++ b/apps/companies/src/main.rs @@ -45,6 +45,7 @@ async fn main() { let app = Router::new() .nest("/api/companies", handlers::router()) .nest("/api/admin/companies", handlers::admin::router()) + .nest("/api/companies/ai", handlers::ai::ai_router()) .route("/health", get(|| async { "Companies OK" })) .with_state(state); diff --git a/apps/fitness_trainers/Cargo.toml b/apps/fitness_trainers/Cargo.toml index b2a3d79..e4208e0 100644 --- a/apps/fitness_trainers/Cargo.toml +++ b/apps/fitness_trainers/Cargo.toml @@ -15,5 +15,6 @@ chrono = { workspace = true } db = { path = "../../crates/db" } auth = { path = "../../crates/auth" } contracts = { path = "../../crates/contracts" } -cache = { path = "../../crates/cache" } +cache = { path = "../../crates/cache" } +storage = { path = "../../crates/storage" } diff --git a/apps/fitness_trainers/src/main.rs b/apps/fitness_trainers/src/main.rs index 0307001..cb01f9a 100644 --- a/apps/fitness_trainers/src/main.rs +++ b/apps/fitness_trainers/src/main.rs @@ -30,7 +30,7 @@ async fn main() { tracing::info!("Fitness Trainers service — connected to DB and Redis"); - let state = ProfessionState { pool, redis }; + let state = ProfessionState { pool, redis, storage: std::sync::Arc::new(storage::StorageClient::from_env().await) }; let app = Router::new() .nest("/api/fitness-trainers", handlers::router()) diff --git a/apps/ugc_content_creators/Cargo.toml b/apps/ugc_content_creators/Cargo.toml index 8978704..27317e4 100644 --- a/apps/ugc_content_creators/Cargo.toml +++ b/apps/ugc_content_creators/Cargo.toml @@ -16,3 +16,4 @@ db = { path = "../../crates/db" } auth = { path = "../../crates/auth" } contracts = { path = "../../crates/contracts" } cache = { path = "../../crates/cache" } +storage = { path = "../../crates/storage" } diff --git a/apps/users/src/clients/ollama_client.rs b/apps/users/src/clients/ollama_client.rs new file mode 100644 index 0000000..efadd08 --- /dev/null +++ b/apps/users/src/clients/ollama_client.rs @@ -0,0 +1,250 @@ +use reqwest::{Client, Error as ReqwestError}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +const OLLAMA_URL: &str = "http://nxtgauge-ai-assistant:11434"; +const DEFAULT_MODEL: &str = "gemma3:270m"; +const REQUEST_TIMEOUT: Duration = Duration::from_secs(120); + +#[derive(Debug, Clone)] +pub struct OllamaClient { + http_client: Client, + base_url: String, + model: String, +} + +#[derive(Debug, Serialize)] +struct GenerateRequest { + model: String, + prompt: String, + stream: bool, + options: Option, +} + +#[derive(Debug, Serialize, Default)] +struct GenerationOptions { + temperature: Option, + top_p: Option, + top_k: Option, + num_predict: Option, +} + +#[derive(Debug, Deserialize)] +pub struct GenerateResponse { + pub model: String, + pub created_at: String, + pub response: String, + pub done: bool, + pub context: Option>, + pub total_duration: Option, + pub load_duration: Option, + pub prompt_eval_count: Option, + pub prompt_eval_duration: Option, + pub eval_count: Option, + pub eval_duration: Option, +} + +#[derive(Debug, Deserialize)] +struct OllamaErrorResponse { + error: String, +} + +#[derive(Debug, thiserror::Error)] +pub enum OllamaError { + #[error("HTTP request failed: {0}")] + RequestFailed(#[from] ReqwestError), + + #[error("Ollama API error: {0}")] + ApiError(String), + + #[error("Failed to parse response: {0}")] + ParseError(String), + + #[error("Connection timeout")] + Timeout, + + #[error("Model not found: {0}")] + ModelNotFound(String), +} + +impl OllamaClient { + pub fn new() -> Self { + let http_client = Client::builder() + .timeout(REQUEST_TIMEOUT) + .build() + .expect("Failed to create HTTP client"); + + Self { + http_client, + base_url: OLLAMA_URL.to_string(), + model: DEFAULT_MODEL.to_string(), + } + } + + pub fn with_url(base_url: impl Into) -> Self { + let http_client = Client::builder() + .timeout(REQUEST_TIMEOUT) + .build() + .expect("Failed to create HTTP client"); + + Self { + http_client, + base_url: base_url.into(), + model: DEFAULT_MODEL.to_string(), + } + } + + pub fn with_model(mut self, model: impl Into) -> Self { + self.model = model.into(); + self + } + + /// Generate text using Ollama API + pub async fn generate(&self, prompt: &str) -> Result { + let url = format!("{}/api/generate", self.base_url); + + let request_body = GenerateRequest { + model: self.model.clone(), + prompt: prompt.to_string(), + stream: false, + options: Some(GenerationOptions { + temperature: Some(0.7), + top_p: Some(0.9), + num_predict: Some(512), + }), + }; + + let response = self.http_client + .post(&url) + .json(&request_body) + .send() + .await?; + + let status = response.status(); + + if !status.is_success() { + let error_text = response.text().await.unwrap_or_default(); + if let Ok(err) = serde_json::from_str::(&error_text) { + return Err(OllamaError::ApiError(err.error)); + } + return Err(OllamaError::ApiError(format!( + "HTTP {}: {}", + status.as_u16(), + error_text + ))); + } + + let generate_response = response + .json::() + .await + .map_err(|e| OllamaError::ParseError(e.to_string()))?; + + Ok(generate_response) + } + + /// Generate a job description from requirements + pub async fn generate_job_description(&self, requirements: &str) -> Result { + let prompt = format!( + r#"You are an expert recruitment professional. Create a professional, engaging job description based on the following requirements: + +Requirements: {} + +Generate a complete job description that includes: +1. Job Title (suggested) +2. Company Overview section +3. Job Summary +4. Key Responsibilities +5. Required Qualifications +6. Preferred Qualifications (if applicable) +7. Benefits/Perks (optional) +8. Application Instructions + +Make it ATS-friendly and compelling. Output only the job description, no extra commentary."#, + requirements + ); + + let response = self.generate(&prompt).await?; + Ok(response.response) + } + + /// Check if Ollama is reachable and model is available + pub async fn health_check(&self) -> Result<(), OllamaError> { + let url = format!("{}/api/tags", self.base_url); + + let response = self.http_client + .get(&url) + .send() + .await; + + match response { + Ok(resp) if resp.status().is_success() => Ok(()), + Ok(resp) => Err(OllamaError::ApiError(format!( + "Health check failed with status: {}", + resp.status().as_u16() + ))), + Err(e) if e.is_timeout() => Err(OllamaError::Timeout), + Err(e) => Err(OllamaError::RequestFailed(e)), + } + } + + /// Pull a model if not already available + pub async fn pull_model(&self) -> Result<(), OllamaError> { + let url = format!("{}/api/pull", self.base_url); + + #[derive(Serialize)] + struct PullRequest { + name: String, + stream: bool, + } + + let request_body = PullRequest { + name: self.model.clone(), + stream: false, + }; + + let response = self.http_client + .post(&url) + .json(&request_body) + .send() + .await?; + + if response.status().is_success() { + Ok(()) + } else { + Err(OllamaError::ModelNotFound(self.model.clone())) + } + } +} + +impl Default for OllamaClient { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // These tests require a running Ollama instance + // Run with: cargo test -- --ignored + + #[tokio::test] + #[ignore = "Requires Ollama server"] + async fn test_generate() { + let client = OllamaClient::new(); + let result = client.generate("Hello, world!").await; + assert!(result.is_ok()); + } + + #[tokio::test] + #[ignore = "Requires Ollama server"] + async fn test_generate_job_description() { + let client = OllamaClient::new(); + let requirements = "Senior Rust Developer with 5+ years experience, Actix-web knowledge required"; + let result = client.generate_job_description(requirements).await; + assert!(result.is_ok()); + let jd = result.unwrap(); + assert!(!jd.is_empty()); + } +} \ No newline at end of file diff --git a/crates/cache/src/lib.rs b/crates/cache/src/lib.rs index 19b4084..6bb8410 100644 --- a/crates/cache/src/lib.rs +++ b/crates/cache/src/lib.rs @@ -1,9 +1,10 @@ +pub mod ai; pub mod client; +pub mod ollama; pub mod otp; pub mod rate_limit; pub mod token; pub mod lead; pub mod jobs; -pub mod ai; pub use client::{RedisPool, connect}; diff --git a/crates/cache/src/ollama.rs b/crates/cache/src/ollama.rs new file mode 100644 index 0000000..dd1ea21 --- /dev/null +++ b/crates/cache/src/ollama.rs @@ -0,0 +1,226 @@ +//! Ollama client for AI-powered text generation +//! +//! Used for generating job descriptions, resume analysis, and other AI features + +use reqwest::{Client, Error as ReqwestError}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +const OLLAMA_URL: &str = "http://nxtgauge-ai-assistant:11434"; +const DEFAULT_MODEL: &str = "gemma3:270m"; +const REQUEST_TIMEOUT: Duration = Duration::from_secs(120); + +#[derive(Debug, Clone)] +pub struct OllamaClient { + http_client: Client, + base_url: String, + model: String, +} + +#[derive(Debug, Serialize)] +struct GenerateRequest { + model: String, + prompt: String, + stream: bool, + options: Option, +} + +#[derive(Debug, Serialize, Default)] +struct GenerationOptions { + temperature: Option, + top_p: Option, + top_k: Option, + num_predict: Option, +} + +#[derive(Debug, Deserialize)] +pub struct GenerateResponse { + pub model: String, + pub created_at: String, + pub response: String, + pub done: bool, + pub context: Option>, + pub total_duration: Option, + pub load_duration: Option, + pub prompt_eval_count: Option, + pub prompt_eval_duration: Option, + pub eval_count: Option, + pub eval_duration: Option, +} + +#[derive(Debug, Deserialize)] +struct OllamaErrorResponse { + error: String, +} + +#[derive(Debug, thiserror::Error)] +pub enum OllamaError { + #[error("HTTP request failed: {0}")] + RequestFailed(#[from] ReqwestError), + + #[error("Ollama API error: {0}")] + ApiError(String), + + #[error("Failed to parse response: {0}")] + ParseError(String), + + #[error("Connection timeout")] + Timeout, + + #[error("Model not found: {0}")] + ModelNotFound(String), +} + +impl OllamaClient { + pub fn new() -> Self { + let http_client = Client::builder() + .timeout(REQUEST_TIMEOUT) + .build() + .expect("Failed to create HTTP client"); + + Self { + http_client, + base_url: OLLAMA_URL.to_string(), + model: DEFAULT_MODEL.to_string(), + } + } + + pub fn with_url(base_url: impl Into) -> Self { + let http_client = Client::builder() + .timeout(REQUEST_TIMEOUT) + .build() + .expect("Failed to create HTTP client"); + + Self { + http_client, + base_url: base_url.into(), + model: DEFAULT_MODEL.to_string(), + } + } + + pub fn with_model(mut self, model: impl Into) -> Self { + self.model = model.into(); + self + } + + pub fn get_model(&self) -> &str { + &self.model + } + + /// Generate text using the configured model and prompt + pub async fn generate(&self, prompt: impl Into) -> Result { + let request = GenerateRequest { + model: self.model.clone(), + prompt: prompt.into(), + stream: false, + options: None, + }; + + let url = format!("{}/api/generate", self.base_url); + + let response = self.http_client + .post(&url) + .json(&request) + .send() + .await + .map_err(|e| { + if e.is_timeout() { + OllamaError::Timeout + } else { + OllamaError::RequestFailed(e) + } + })?; + + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string()); + + if status.as_u16() == 404 { + return Err(OllamaError::ModelNotFound(self.model.clone())); + } + + return Err(OllamaError::ApiError(format!("{}: {}", status, error_text))); + } + + let result = response.json::() + .await + .map_err(|e| OllamaError::ParseError(e.to_string()))?; + + Ok(result) + } + + /// Generate a job description based on a prompt + pub async fn generate_job_description(&self, prompt: &str) -> Result { + let enhanced_prompt = format!( + "Generate a professional job description based on the following prompt:\n\n{}\n\n" + "Provide a well-structured description with clear responsibilities and requirements.", + prompt + ); + + let response = self.generate(enhanced_prompt).await?; + Ok(response.response) + } + + /// Analyze a resume and provide feedback + pub async fn analyze_resume(&self, resume_content: &str, job_description: &str) -> Result { + let prompt = format!( + "Analyze the following resume against this job description:\n\n" + "Job Description:\n{}\n\n" + "Resume:\n{}\n\n" + "Provide specific feedback on:\n" + "1. How well the resume matches the job requirements\n" + "2. Missing skills or experience\n" + "3. Suggestions for improvement\n" + "4. Overall match percentage", + job_description, resume_content + ); + + let response = self.generate(prompt).await?; + Ok(response.response) + } + + /// Generate a cover letter + pub async fn generate_cover_letter(&self, candidate_info: &str, job_description: &str, tone: &str + ) -> Result { + let prompt = format!( + "Write a {} cover letter for a candidate with the following background:\n\n" + "Candidate: {}\n\n" + "Job Description: {}\n\n" + "The cover letter should be professional and highlight relevant experience.", + tone, candidate_info, job_description + ); + + let response = self.generate(prompt).await?; + Ok(response.response) + } +} + +impl Default for OllamaClient { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_client_creation() { + let client = OllamaClient::new(); + assert_eq!(client.get_model(), DEFAULT_MODEL); + } + + #[test] + fn test_client_with_custom_model() { + let client = OllamaClient::new() + .with_model("gemma:4b"); + assert_eq!(client.get_model(), "gemma:4b"); + } + + #[test] + fn test_client_with_custom_url() { + let client = OllamaClient::with_url("http://custom:11434"); + assert_eq!(client.get_model(), DEFAULT_MODEL); + } +}