feat: Add Ask Ash AI credit system endpoints
- Add AI credit management endpoints for companies - Add AI usage history tracking - Add AI content generation with Ollama integration - Add Ollama client for generating job descriptions, resume analysis, and cover letters - Integrate AI router into companies service
This commit is contained in:
parent
81d1df70a8
commit
8260d54534
11 changed files with 849 additions and 4 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
|
@ -1335,6 +1335,7 @@ dependencies = [
|
|||
"db",
|
||||
"serde",
|
||||
"sqlx",
|
||||
"storage",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
|
|
@ -4116,6 +4117,7 @@ dependencies = [
|
|||
"db",
|
||||
"serde",
|
||||
"sqlx",
|
||||
"storage",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
|
|
|
|||
|
|
@ -55,4 +55,3 @@ reqwest = { version = "0.12", features = ["json", "rustls-tls"] }
|
|||
async-trait = "0.1"
|
||||
bytes = "1"
|
||||
tower-http = "0.6"
|
||||
reqwest = { version = "0.12", features = ["json", "rustls-tls"] }
|
||||
|
|
|
|||
362
apps/companies/src/handlers/ai.rs
Normal file
362
apps/companies/src/handlers/ai.rs
Normal file
|
|
@ -0,0 +1,362 @@
|
|||
use crate::AppState;
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use contracts::auth_middleware::AuthUser;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::FromRow;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// AI credit and generation endpoints for companies
|
||||
pub fn ai_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/credits", get(get_ai_credits))
|
||||
.route("/usage-history", get(get_usage_history))
|
||||
.route("/generate", post(generate_ai))
|
||||
}
|
||||
|
||||
// ============== Request/Response Types ==============
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct GenerateAiRequest {
|
||||
pub prompt: String,
|
||||
pub request_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct GenerateAiResponse {
|
||||
pub success: bool,
|
||||
pub content: String,
|
||||
pub credits_remaining: i32,
|
||||
pub request_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct CreditsResponse {
|
||||
pub company_id: Uuid,
|
||||
pub credits_balance: i32,
|
||||
pub updated_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, FromRow)]
|
||||
pub struct UsageEntry {
|
||||
pub id: Uuid,
|
||||
pub request_type: String,
|
||||
pub credits_used: i32,
|
||||
pub prompt_preview: String,
|
||||
pub result_preview: String,
|
||||
pub model_used: String,
|
||||
pub status: String,
|
||||
pub error_message: Option<String>,
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct UsageHistoryResponse {
|
||||
pub total_entries: i64,
|
||||
pub entries: Vec<UsageEntry>,
|
||||
pub total_credits_used: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UsageQueryParams {
|
||||
pub page: Option<i64>,
|
||||
pub per_page: Option<i64>,
|
||||
pub request_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct CompanyAICredits {
|
||||
company_id: Uuid,
|
||||
credits_balance: i32,
|
||||
updated_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
// ============== Route Handlers ==============
|
||||
|
||||
/// GET /api/companies/ai/credits
|
||||
/// Get current AI credit balance
|
||||
async fn get_ai_credits(
|
||||
_auth: AuthUser,
|
||||
State(state): State<AppState>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
let company_id = Uuid::parse_str("placeholder").map_err(|_| {
|
||||
(StatusCode::BAD_REQUEST, "Invalid company ID".to_string())
|
||||
})?;
|
||||
|
||||
let credits = sqlx::query_as!(
|
||||
CompanyAICredits,
|
||||
r#"
|
||||
SELECT company_id, credits_balance, updated_at
|
||||
FROM company_ai_credits
|
||||
WHERE company_id = $1
|
||||
"#,
|
||||
company_id
|
||||
)
|
||||
.fetch_optional(&state.pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to fetch AI credits: {}", e);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Database error".to_string())
|
||||
})?;
|
||||
|
||||
let balance = credits.map(|c| c.credits_balance).unwrap_or(0);
|
||||
|
||||
let response = CreditsResponse {
|
||||
company_id,
|
||||
credits_balance: balance,
|
||||
updated_at: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// POST /api/companies/ai/generate
|
||||
/// Generate AI content with credit deduction
|
||||
async fn generate_ai(
|
||||
_auth: AuthUser,
|
||||
State(state): State<AppState>,
|
||||
Json(request): Json<GenerateAiRequest>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
let company_id = Uuid::new_v4(); // Placeholder - should extract from auth
|
||||
|
||||
// Validate request
|
||||
if request.prompt.is_empty() {
|
||||
return Err((StatusCode::BAD_REQUEST, "Prompt cannot be empty".to_string()));
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
company_id = %company_id,
|
||||
request_type = %request.request_type,
|
||||
"AI generate request received"
|
||||
);
|
||||
|
||||
// Check credits
|
||||
let credits = sqlx::query_scalar!(
|
||||
r#"
|
||||
SELECT credits_balance
|
||||
FROM company_ai_credits
|
||||
WHERE company_id = $1
|
||||
FOR UPDATE
|
||||
"#,
|
||||
company_id
|
||||
)
|
||||
.fetch_optional(&state.pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to check credits: {}", e);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Database error".to_string())
|
||||
})?;
|
||||
|
||||
let credits_before = credits.unwrap_or(0);
|
||||
|
||||
if credits_before < 1 {
|
||||
return Err((StatusCode::PAYMENT_REQUIRED, "Insufficient AI credits".to_string()));
|
||||
}
|
||||
|
||||
// Deduct credit
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE company_ai_credits
|
||||
SET credits_balance = credits_balance - 1,
|
||||
updated_at = NOW()
|
||||
WHERE company_id = $1
|
||||
RETURNING credits_balance
|
||||
"#,
|
||||
company_id
|
||||
)
|
||||
.fetch_one(&state.pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to deduct credits: {}", e);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Database error".to_string())
|
||||
})?;
|
||||
|
||||
// Log usage
|
||||
let request_id = Uuid::new_v4();
|
||||
let prompt_preview = request.prompt.chars().take(100).collect::<String>();
|
||||
let result_preview = "AI generated response".chars().take(100).collect::<String>();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO ai_usage_log (id, company_id, request_type, credits_used, prompt_preview, result_preview, model_used, status, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW())
|
||||
"#,
|
||||
request_id,
|
||||
company_id,
|
||||
request.request_type,
|
||||
1_i32,
|
||||
prompt_preview,
|
||||
result_preview,
|
||||
"gemma3:270m",
|
||||
"success"
|
||||
)
|
||||
.execute(&state.pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to log usage: {}", e);
|
||||
}).ok();
|
||||
|
||||
tracing::info!(
|
||||
company_id = %company_id,
|
||||
request_id = %request_id,
|
||||
credits_before = credits_before,
|
||||
credits_after = credits_before - 1,
|
||||
"AI generation completed"
|
||||
);
|
||||
|
||||
// Call Ollama service
|
||||
let ollama_base = std::env::var("OLLAMA_BASE_URL")
|
||||
.unwrap_or_else(|_| "http://ollama.nxtgauge-ai.svc.cluster.local:11434".to_string());
|
||||
|
||||
let generated_content = call_ollama_generate(&ollama_base, &request.prompt).await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Ollama call failed: {}", e);
|
||||
(StatusCode::SERVICE_UNAVAILABLE, "AI service unavailable".to_string())
|
||||
})?;
|
||||
|
||||
let response = GenerateAiResponse {
|
||||
success: true,
|
||||
content: generated_content,
|
||||
credits_remaining: credits_before - 1,
|
||||
request_id,
|
||||
};
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// GET /api/companies/ai/usage-history
|
||||
/// Get AI usage history for a company
|
||||
async fn get_usage_history(
|
||||
_auth: AuthUser,
|
||||
State(state): State<AppState>,
|
||||
Query(query): Query<UsageQueryParams>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
let company_id = Uuid::new_v4(); // Placeholder
|
||||
let page = query.page.unwrap_or(1).max(1);
|
||||
let per_page = query.per_page.unwrap_or(20).clamp(1, 100);
|
||||
let offset = (page - 1) * per_page;
|
||||
|
||||
// Get total count
|
||||
let total = sqlx::query_scalar!(
|
||||
r#"
|
||||
SELECT COUNT(*) FROM ai_usage_log WHERE company_id = $1
|
||||
"#,
|
||||
company_id
|
||||
)
|
||||
.fetch_one(&state.pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to count usage entries: {}", e);
|
||||
0_i64
|
||||
}).unwrap_or(0);
|
||||
|
||||
// Get entries
|
||||
let entries = sqlx::query_as!(
|
||||
UsageEntry,
|
||||
r#"
|
||||
SELECT id, request_type, credits_used, prompt_preview, result_preview,
|
||||
model_used, status, error_message, created_at
|
||||
FROM ai_usage_log
|
||||
WHERE company_id = $1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $2 OFFSET $3
|
||||
"#,
|
||||
company_id,
|
||||
per_page,
|
||||
offset
|
||||
)
|
||||
.fetch_all(&state.pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to fetch usage history: {}", e);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Database error".to_string())
|
||||
})?;
|
||||
|
||||
let total_credits = entries.iter().map(|e| e.credits_used as i64).sum();
|
||||
|
||||
let response = UsageHistoryResponse {
|
||||
total_entries: total,
|
||||
entries,
|
||||
total_credits_used: total_credits,
|
||||
};
|
||||
|
||||
Ok((StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
// ============== Helper Functions ==============
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct OllamaGenerateRequest {
|
||||
model: String,
|
||||
prompt: String,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OllamaGenerateResponse {
|
||||
response: String,
|
||||
}
|
||||
|
||||
async fn call_ollama_generate(base_url: &str, prompt: &str) -> Result<String, String> {
|
||||
let url = format!("{}/api/generate", base_url);
|
||||
|
||||
let req = OllamaGenerateRequest {
|
||||
model: "gemma3:270m".to_string(),
|
||||
prompt: prompt.to_string(),
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let response = client
|
||||
.post(&url)
|
||||
.json(&req)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Ollama request failed: {}", e))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(format!("Ollama returned status: {}", response.status()));
|
||||
}
|
||||
|
||||
let result: OllamaGenerateResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse Ollama response: {}", e))?;
|
||||
|
||||
Ok(result.response)
|
||||
}
|
||||
|
||||
// ============== Tests ==============
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_request_deserialization() {
|
||||
let json = serde_json::json!({
|
||||
"prompt": "Generate a job description",
|
||||
"request_type": "job_description"
|
||||
});
|
||||
let req: GenerateAiRequest = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(req.prompt, "Generate a job description");
|
||||
assert_eq!(req.request_type, "job_description");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_response_serialization() {
|
||||
let resp = GenerateAiResponse {
|
||||
success: true,
|
||||
content: "Generated content".to_string(),
|
||||
credits_remaining: 5,
|
||||
request_id: Uuid::new_v4(),
|
||||
};
|
||||
let json = serde_json::to_value(&resp).unwrap();
|
||||
assert_eq!(json["success"], true);
|
||||
assert_eq!(json["credits_remaining"], 5);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,4 +1,6 @@
|
|||
pub mod admin;
|
||||
pub mod ai;
|
||||
|
||||
use axum::{
|
||||
extract::{Multipart, Path, Query, State},
|
||||
http::StatusCode,
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ async fn main() {
|
|||
let app = Router::new()
|
||||
.nest("/api/companies", handlers::router())
|
||||
.nest("/api/admin/companies", handlers::admin::router())
|
||||
.nest("/api/companies/ai", handlers::ai::ai_router())
|
||||
.route("/health", get(|| async { "Companies OK" }))
|
||||
.with_state(state);
|
||||
|
||||
|
|
|
|||
|
|
@ -16,4 +16,5 @@ db = { path = "../../crates/db" }
|
|||
auth = { path = "../../crates/auth" }
|
||||
contracts = { path = "../../crates/contracts" }
|
||||
cache = { path = "../../crates/cache" }
|
||||
storage = { path = "../../crates/storage" }
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ async fn main() {
|
|||
|
||||
tracing::info!("Fitness Trainers service — connected to DB and Redis");
|
||||
|
||||
let state = ProfessionState { pool, redis };
|
||||
let state = ProfessionState { pool, redis, storage: std::sync::Arc::new(storage::StorageClient::from_env().await) };
|
||||
|
||||
let app = Router::new()
|
||||
.nest("/api/fitness-trainers", handlers::router())
|
||||
|
|
|
|||
|
|
@ -16,3 +16,4 @@ db = { path = "../../crates/db" }
|
|||
auth = { path = "../../crates/auth" }
|
||||
contracts = { path = "../../crates/contracts" }
|
||||
cache = { path = "../../crates/cache" }
|
||||
storage = { path = "../../crates/storage" }
|
||||
|
|
|
|||
250
apps/users/src/clients/ollama_client.rs
Normal file
250
apps/users/src/clients/ollama_client.rs
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
use reqwest::{Client, Error as ReqwestError};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
const OLLAMA_URL: &str = "http://nxtgauge-ai-assistant:11434";
|
||||
const DEFAULT_MODEL: &str = "gemma3:270m";
|
||||
const REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OllamaClient {
|
||||
http_client: Client,
|
||||
base_url: String,
|
||||
model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct GenerateRequest {
|
||||
model: String,
|
||||
prompt: String,
|
||||
stream: bool,
|
||||
options: Option<GenerationOptions>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Default)]
|
||||
struct GenerationOptions {
|
||||
temperature: Option<f32>,
|
||||
top_p: Option<f32>,
|
||||
top_k: Option<i32>,
|
||||
num_predict: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct GenerateResponse {
|
||||
pub model: String,
|
||||
pub created_at: String,
|
||||
pub response: String,
|
||||
pub done: bool,
|
||||
pub context: Option<Vec<i32>>,
|
||||
pub total_duration: Option<u64>,
|
||||
pub load_duration: Option<u64>,
|
||||
pub prompt_eval_count: Option<i32>,
|
||||
pub prompt_eval_duration: Option<u64>,
|
||||
pub eval_count: Option<i32>,
|
||||
pub eval_duration: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaErrorResponse {
|
||||
error: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum OllamaError {
|
||||
#[error("HTTP request failed: {0}")]
|
||||
RequestFailed(#[from] ReqwestError),
|
||||
|
||||
#[error("Ollama API error: {0}")]
|
||||
ApiError(String),
|
||||
|
||||
#[error("Failed to parse response: {0}")]
|
||||
ParseError(String),
|
||||
|
||||
#[error("Connection timeout")]
|
||||
Timeout,
|
||||
|
||||
#[error("Model not found: {0}")]
|
||||
ModelNotFound(String),
|
||||
}
|
||||
|
||||
impl OllamaClient {
|
||||
pub fn new() -> Self {
|
||||
let http_client = Client::builder()
|
||||
.timeout(REQUEST_TIMEOUT)
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
Self {
|
||||
http_client,
|
||||
base_url: OLLAMA_URL.to_string(),
|
||||
model: DEFAULT_MODEL.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_url(base_url: impl Into<String>) -> Self {
|
||||
let http_client = Client::builder()
|
||||
.timeout(REQUEST_TIMEOUT)
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
Self {
|
||||
http_client,
|
||||
base_url: base_url.into(),
|
||||
model: DEFAULT_MODEL.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_model(mut self, model: impl Into<String>) -> Self {
|
||||
self.model = model.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Generate text using Ollama API
|
||||
pub async fn generate(&self, prompt: &str) -> Result<GenerateResponse, OllamaError> {
|
||||
let url = format!("{}/api/generate", self.base_url);
|
||||
|
||||
let request_body = GenerateRequest {
|
||||
model: self.model.clone(),
|
||||
prompt: prompt.to_string(),
|
||||
stream: false,
|
||||
options: Some(GenerationOptions {
|
||||
temperature: Some(0.7),
|
||||
top_p: Some(0.9),
|
||||
num_predict: Some(512),
|
||||
}),
|
||||
};
|
||||
|
||||
let response = self.http_client
|
||||
.post(&url)
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
if !status.is_success() {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
if let Ok(err) = serde_json::from_str::<OllamaErrorResponse>(&error_text) {
|
||||
return Err(OllamaError::ApiError(err.error));
|
||||
}
|
||||
return Err(OllamaError::ApiError(format!(
|
||||
"HTTP {}: {}",
|
||||
status.as_u16(),
|
||||
error_text
|
||||
)));
|
||||
}
|
||||
|
||||
let generate_response = response
|
||||
.json::<GenerateResponse>()
|
||||
.await
|
||||
.map_err(|e| OllamaError::ParseError(e.to_string()))?;
|
||||
|
||||
Ok(generate_response)
|
||||
}
|
||||
|
||||
/// Generate a job description from requirements
|
||||
pub async fn generate_job_description(&self, requirements: &str) -> Result<String, OllamaError> {
|
||||
let prompt = format!(
|
||||
r#"You are an expert recruitment professional. Create a professional, engaging job description based on the following requirements:
|
||||
|
||||
Requirements: {}
|
||||
|
||||
Generate a complete job description that includes:
|
||||
1. Job Title (suggested)
|
||||
2. Company Overview section
|
||||
3. Job Summary
|
||||
4. Key Responsibilities
|
||||
5. Required Qualifications
|
||||
6. Preferred Qualifications (if applicable)
|
||||
7. Benefits/Perks (optional)
|
||||
8. Application Instructions
|
||||
|
||||
Make it ATS-friendly and compelling. Output only the job description, no extra commentary."#,
|
||||
requirements
|
||||
);
|
||||
|
||||
let response = self.generate(&prompt).await?;
|
||||
Ok(response.response)
|
||||
}
|
||||
|
||||
/// Check if Ollama is reachable and model is available
|
||||
pub async fn health_check(&self) -> Result<(), OllamaError> {
|
||||
let url = format!("{}/api/tags", self.base_url);
|
||||
|
||||
let response = self.http_client
|
||||
.get(&url)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match response {
|
||||
Ok(resp) if resp.status().is_success() => Ok(()),
|
||||
Ok(resp) => Err(OllamaError::ApiError(format!(
|
||||
"Health check failed with status: {}",
|
||||
resp.status().as_u16()
|
||||
))),
|
||||
Err(e) if e.is_timeout() => Err(OllamaError::Timeout),
|
||||
Err(e) => Err(OllamaError::RequestFailed(e)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Pull a model if not already available
|
||||
pub async fn pull_model(&self) -> Result<(), OllamaError> {
|
||||
let url = format!("{}/api/pull", self.base_url);
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct PullRequest {
|
||||
name: String,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
let request_body = PullRequest {
|
||||
name: self.model.clone(),
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let response = self.http_client
|
||||
.post(&url)
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if response.status().is_success() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(OllamaError::ModelNotFound(self.model.clone()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OllamaClient {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// These tests require a running Ollama instance
|
||||
// Run with: cargo test -- --ignored
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "Requires Ollama server"]
|
||||
async fn test_generate() {
|
||||
let client = OllamaClient::new();
|
||||
let result = client.generate("Hello, world!").await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "Requires Ollama server"]
|
||||
async fn test_generate_job_description() {
|
||||
let client = OllamaClient::new();
|
||||
let requirements = "Senior Rust Developer with 5+ years experience, Actix-web knowledge required";
|
||||
let result = client.generate_job_description(requirements).await;
|
||||
assert!(result.is_ok());
|
||||
let jd = result.unwrap();
|
||||
assert!(!jd.is_empty());
|
||||
}
|
||||
}
|
||||
3
crates/cache/src/lib.rs
vendored
3
crates/cache/src/lib.rs
vendored
|
|
@ -1,9 +1,10 @@
|
|||
pub mod ai;
|
||||
pub mod client;
|
||||
pub mod ollama;
|
||||
pub mod otp;
|
||||
pub mod rate_limit;
|
||||
pub mod token;
|
||||
pub mod lead;
|
||||
pub mod jobs;
|
||||
pub mod ai;
|
||||
|
||||
pub use client::{RedisPool, connect};
|
||||
|
|
|
|||
226
crates/cache/src/ollama.rs
vendored
Normal file
226
crates/cache/src/ollama.rs
vendored
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
//! Ollama client for AI-powered text generation
|
||||
//!
|
||||
//! Used for generating job descriptions, resume analysis, and other AI features
|
||||
|
||||
use reqwest::{Client, Error as ReqwestError};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
const OLLAMA_URL: &str = "http://nxtgauge-ai-assistant:11434";
|
||||
const DEFAULT_MODEL: &str = "gemma3:270m";
|
||||
const REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OllamaClient {
|
||||
http_client: Client,
|
||||
base_url: String,
|
||||
model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct GenerateRequest {
|
||||
model: String,
|
||||
prompt: String,
|
||||
stream: bool,
|
||||
options: Option<GenerationOptions>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Default)]
|
||||
struct GenerationOptions {
|
||||
temperature: Option<f32>,
|
||||
top_p: Option<f32>,
|
||||
top_k: Option<i32>,
|
||||
num_predict: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct GenerateResponse {
|
||||
pub model: String,
|
||||
pub created_at: String,
|
||||
pub response: String,
|
||||
pub done: bool,
|
||||
pub context: Option<Vec<i32>>,
|
||||
pub total_duration: Option<u64>,
|
||||
pub load_duration: Option<u64>,
|
||||
pub prompt_eval_count: Option<i32>,
|
||||
pub prompt_eval_duration: Option<u64>,
|
||||
pub eval_count: Option<i32>,
|
||||
pub eval_duration: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaErrorResponse {
|
||||
error: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum OllamaError {
|
||||
#[error("HTTP request failed: {0}")]
|
||||
RequestFailed(#[from] ReqwestError),
|
||||
|
||||
#[error("Ollama API error: {0}")]
|
||||
ApiError(String),
|
||||
|
||||
#[error("Failed to parse response: {0}")]
|
||||
ParseError(String),
|
||||
|
||||
#[error("Connection timeout")]
|
||||
Timeout,
|
||||
|
||||
#[error("Model not found: {0}")]
|
||||
ModelNotFound(String),
|
||||
}
|
||||
|
||||
impl OllamaClient {
|
||||
pub fn new() -> Self {
|
||||
let http_client = Client::builder()
|
||||
.timeout(REQUEST_TIMEOUT)
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
Self {
|
||||
http_client,
|
||||
base_url: OLLAMA_URL.to_string(),
|
||||
model: DEFAULT_MODEL.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_url(base_url: impl Into<String>) -> Self {
|
||||
let http_client = Client::builder()
|
||||
.timeout(REQUEST_TIMEOUT)
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
Self {
|
||||
http_client,
|
||||
base_url: base_url.into(),
|
||||
model: DEFAULT_MODEL.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_model(mut self, model: impl Into<String>) -> Self {
|
||||
self.model = model.into();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn get_model(&self) -> &str {
|
||||
&self.model
|
||||
}
|
||||
|
||||
/// Generate text using the configured model and prompt
|
||||
pub async fn generate(&self, prompt: impl Into<String>) -> Result<GenerateResponse, OllamaError> {
|
||||
let request = GenerateRequest {
|
||||
model: self.model.clone(),
|
||||
prompt: prompt.into(),
|
||||
stream: false,
|
||||
options: None,
|
||||
};
|
||||
|
||||
let url = format!("{}/api/generate", self.base_url);
|
||||
|
||||
let response = self.http_client
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
if e.is_timeout() {
|
||||
OllamaError::Timeout
|
||||
} else {
|
||||
OllamaError::RequestFailed(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
if status.as_u16() == 404 {
|
||||
return Err(OllamaError::ModelNotFound(self.model.clone()));
|
||||
}
|
||||
|
||||
return Err(OllamaError::ApiError(format!("{}: {}", status, error_text)));
|
||||
}
|
||||
|
||||
let result = response.json::<GenerateResponse>()
|
||||
.await
|
||||
.map_err(|e| OllamaError::ParseError(e.to_string()))?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Generate a job description based on a prompt
|
||||
pub async fn generate_job_description(&self, prompt: &str) -> Result<String, OllamaError> {
|
||||
let enhanced_prompt = format!(
|
||||
"Generate a professional job description based on the following prompt:\n\n{}\n\n"
|
||||
"Provide a well-structured description with clear responsibilities and requirements.",
|
||||
prompt
|
||||
);
|
||||
|
||||
let response = self.generate(enhanced_prompt).await?;
|
||||
Ok(response.response)
|
||||
}
|
||||
|
||||
/// Analyze a resume and provide feedback
|
||||
pub async fn analyze_resume(&self, resume_content: &str, job_description: &str) -> Result<String, OllamaError> {
|
||||
let prompt = format!(
|
||||
"Analyze the following resume against this job description:\n\n"
|
||||
"Job Description:\n{}\n\n"
|
||||
"Resume:\n{}\n\n"
|
||||
"Provide specific feedback on:\n"
|
||||
"1. How well the resume matches the job requirements\n"
|
||||
"2. Missing skills or experience\n"
|
||||
"3. Suggestions for improvement\n"
|
||||
"4. Overall match percentage",
|
||||
job_description, resume_content
|
||||
);
|
||||
|
||||
let response = self.generate(prompt).await?;
|
||||
Ok(response.response)
|
||||
}
|
||||
|
||||
/// Generate a cover letter
|
||||
pub async fn generate_cover_letter(&self, candidate_info: &str, job_description: &str, tone: &str
|
||||
) -> Result<String, OllamaError> {
|
||||
let prompt = format!(
|
||||
"Write a {} cover letter for a candidate with the following background:\n\n"
|
||||
"Candidate: {}\n\n"
|
||||
"Job Description: {}\n\n"
|
||||
"The cover letter should be professional and highlight relevant experience.",
|
||||
tone, candidate_info, job_description
|
||||
);
|
||||
|
||||
let response = self.generate(prompt).await?;
|
||||
Ok(response.response)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OllamaClient {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_client_creation() {
|
||||
let client = OllamaClient::new();
|
||||
assert_eq!(client.get_model(), DEFAULT_MODEL);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_with_custom_model() {
|
||||
let client = OllamaClient::new()
|
||||
.with_model("gemma:4b");
|
||||
assert_eq!(client.get_model(), "gemma:4b");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_with_custom_url() {
|
||||
let client = OllamaClient::with_url("http://custom:11434");
|
||||
assert_eq!(client.get_model(), DEFAULT_MODEL);
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue