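//! Authentication routes and handlers: registration, login/logout, refresh-token
//! rotation, session lookup, email verification, and password reset/change.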
use auth::{
|
||
|
|
crypto::{hash_password, verify_password},
|
||
|
|
jwt::generate_tokens,
|
||
|
|
};
|
||
|
|
use ax_um_state_alias::AppState; // I'll use crate::AppState
|
||
|
|
use axum::{
|
||
|
|
extract::State,
|
||
|
|
http::{header::SET_COOKIE, StatusCode},
|
||
|
|
response::IntoResponse,
|
||
|
|
routing::{get, post},
|
||
|
|
Json, Router,
|
||
|
|
};
|
||
|
|
use chrono::{Duration, Utc};
|
||
|
|
use db::models::user::{CreateUserPayload, UserRepository};
|
||
|
|
use serde::{Deserialize, Serialize};
|
||
|
|
use contracts::auth_middleware::AuthUser;
|
||
|
|
use crate::AppState;
|
||
|
|
|
||
|
|
pub fn router() -> Router<AppState> {
|
||
|
|
Router::new()
|
||
|
|
.route("/register", post(register))
|
||
|
|
.route("/login", post(login))
|
||
|
|
.route("/logout", post(logout))
|
||
|
|
.route("/refresh", post(refresh))
|
||
|
|
.route("/session", get(session))
|
||
|
|
.route("/verify-email", post(verify_email))
|
||
|
|
.route("/resend-otp", post(resend_otp))
|
||
|
|
.route("/forgot-password", post(forgot_password))
|
||
|
|
.route("/reset-password", post(reset_password))
|
||
|
|
.route("/change-password", post(change_password))
|
||
|
|
}
|
||
|
|
|
||
|
|
|
||
|
|
// ── DTOs ──────────────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
#[derive(Deserialize)]
|
||
|
|
pub struct RegisterPayload {
|
||
|
|
pub full_name: String,
|
||
|
|
pub email: String,
|
||
|
|
pub phone: String,
|
||
|
|
pub password: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Deserialize)]
|
||
|
|
pub struct LoginPayload {
|
||
|
|
pub email: String,
|
||
|
|
pub password: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Deserialize)]
|
||
|
|
pub struct VerifyEmailPayload {
|
||
|
|
pub otp: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Deserialize)]
|
||
|
|
pub struct ForgotPasswordPayload {
|
||
|
|
pub email: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Deserialize)]
|
||
|
|
pub struct ResetPasswordPayload {
|
||
|
|
pub token: String,
|
||
|
|
pub new_password: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Deserialize)]
|
||
|
|
pub struct ChangePasswordPayload {
|
||
|
|
pub current_password: String,
|
||
|
|
pub new_password: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Serialize)]
|
||
|
|
pub struct RegisterResponse {
|
||
|
|
pub user_id: String,
|
||
|
|
pub email: String,
|
||
|
|
pub phone: String,
|
||
|
|
pub full_name: String,
|
||
|
|
pub status: String,
|
||
|
|
pub email_verified: bool,
|
||
|
|
pub created_at: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Serialize)]
|
||
|
|
pub struct LoginResponse {
|
||
|
|
pub access_token: String,
|
||
|
|
pub token_type: String,
|
||
|
|
pub expires_in: u64,
|
||
|
|
pub user: SessionUser,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Serialize)]
|
||
|
|
pub struct SessionUser {
|
||
|
|
pub id: String,
|
||
|
|
pub email: String,
|
||
|
|
pub full_name: String,
|
||
|
|
pub email_verified: bool,
|
||
|
|
pub roles: Vec<String>,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Serialize)]
|
||
|
|
pub struct ErrorResponse {
|
||
|
|
pub error: String,
|
||
|
|
pub code: String,
|
||
|
|
#[serde(rename = "statusCode")]
|
||
|
|
pub status_code: u16,
|
||
|
|
}
|
||
|
|
|
||
|
|
fn err(status: StatusCode, msg: &str, code: &str) -> (StatusCode, Json<ErrorResponse>) {
|
||
|
|
(
|
||
|
|
status,
|
||
|
|
Json(ErrorResponse {
|
||
|
|
error: msg.to_string(),
|
||
|
|
code: code.to_string(),
|
||
|
|
status_code: status.as_u16(),
|
||
|
|
}),
|
||
|
|
)
|
||
|
|
}
|
||
|
|
|
||
|
|
// ── Handlers ──────────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
async fn register(
|
||
|
|
State(state): State<AppState>,
|
||
|
|
Json(payload): Json<RegisterPayload>,
|
||
|
|
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
||
|
|
// Basic validation
|
||
|
|
if payload.password.len() < 8 {
|
||
|
|
return Err(err(
|
||
|
|
StatusCode::UNPROCESSABLE_ENTITY,
|
||
|
|
"Password minimum 8 characters",
|
||
|
|
"VALIDATION_ERROR",
|
||
|
|
));
|
||
|
|
}
|
||
|
|
|
||
|
|
let password_hash = hash_password(&payload.password).map_err(|e| {
|
||
|
|
err(
|
||
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||
|
|
&e.to_string(),
|
||
|
|
"INTERNAL_ERROR",
|
||
|
|
)
|
||
|
|
})?;
|
||
|
|
|
||
|
|
let user = UserRepository::create(
|
||
|
|
&state.pool,
|
||
|
|
CreateUserPayload {
|
||
|
|
full_name: payload.full_name,
|
||
|
|
email: payload.email.to_lowercase(),
|
||
|
|
phone: payload.phone,
|
||
|
|
password_hash,
|
||
|
|
},
|
||
|
|
)
|
||
|
|
.await
|
||
|
|
.map_err(|e| {
|
||
|
|
let msg = e.to_string();
|
||
|
|
if msg.contains("users_email_key") || msg.contains("email") && msg.contains("unique") {
|
||
|
|
err(StatusCode::CONFLICT, "Email already registered", "EMAIL_EXISTS")
|
||
|
|
} else if msg.contains("users_phone_key") || msg.contains("phone") && msg.contains("unique") {
|
||
|
|
err(StatusCode::CONFLICT, "Phone already registered", "PHONE_EXISTS")
|
||
|
|
} else {
|
||
|
|
err(StatusCode::INTERNAL_SERVER_ERROR, &msg, "DB_ERROR")
|
||
|
|
}
|
||
|
|
})?;
|
||
|
|
|
||
|
|
// Generate and send email OTP for verification
|
||
|
|
let otp = format!("{:06}", rand::random::<u32>() % 1000000);
|
||
|
|
let expires_at = Utc::now() + Duration::minutes(15);
|
||
|
|
|
||
|
|
UserRepository::set_email_verification_token(&state.pool, user.id, &otp, expires_at)
|
||
|
|
.await
|
||
|
|
.map_err(|e| err(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string(), "DB_ERROR"))?;
|
||
|
|
|
||
|
|
let _ = state.mail.send_verification_email(&user.email, &user.full_name.unwrap_or_default(), &otp).await;
|
||
|
|
|
||
|
|
Ok((
|
||
|
|
StatusCode::CREATED,
|
||
|
|
Json(RegisterResponse {
|
||
|
|
user_id: user.id.to_string(),
|
||
|
|
email: user.email,
|
||
|
|
phone: user.phone.unwrap_or_default(),
|
||
|
|
full_name: user.full_name.unwrap_or_default(),
|
||
|
|
status: user.status,
|
||
|
|
email_verified: user.email_verified,
|
||
|
|
created_at: user.created_at.to_rfc3339(),
|
||
|
|
}),
|
||
|
|
))
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn login(
|
||
|
|
State(state): State<AppState>,
|
||
|
|
Json(payload): Json<LoginPayload>,
|
||
|
|
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
||
|
|
let user = UserRepository::get_by_email(&state.pool, &payload.email.to_lowercase())
|
||
|
|
.await
|
||
|
|
.map_err(|_| err(StatusCode::UNAUTHORIZED, "Invalid credentials", "INVALID_CREDENTIALS"))?;
|
||
|
|
|
||
|
|
// Check account status
|
||
|
|
if user.status == "SUSPENDED" {
|
||
|
|
return Err(err(StatusCode::FORBIDDEN, "Account suspended", "ACCOUNT_SUSPENDED"));
|
||
|
|
}
|
||
|
|
|
||
|
|
// Email verification check
|
||
|
|
if !user.email_verified {
|
||
|
|
return Err(err(StatusCode::UNAUTHORIZED, "Email not verified", "EMAIL_NOT_VERIFIED"));
|
||
|
|
}
|
||
|
|
|
||
|
|
let is_valid = verify_password(&payload.password, &user.password_hash).map_err(|e| {
|
||
|
|
err(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string(), "INTERNAL_ERROR")
|
||
|
|
})?;
|
||
|
|
|
||
|
|
if !is_valid {
|
||
|
|
return Err(err(StatusCode::UNAUTHORIZED, "Invalid credentials", "INVALID_CREDENTIALS"));
|
||
|
|
}
|
||
|
|
|
||
|
|
// Fetch user's active roles
|
||
|
|
let user_roles = UserRepository::get_user_role_keys(&state.pool, user.id)
|
||
|
|
.await
|
||
|
|
.unwrap_or_default();
|
||
|
|
|
||
|
|
let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "changeme".to_string());
|
||
|
|
let tokens = generate_tokens(user.id.to_string(), user_roles.first().cloned(), &jwt_secret)
|
||
|
|
.map_err(|e| err(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string(), "TOKEN_ERROR"))?;
|
||
|
|
|
||
|
|
UserRepository::store_refresh_token(
|
||
|
|
&state.pool,
|
||
|
|
user.id,
|
||
|
|
&tokens.refresh_token,
|
||
|
|
Utc::now() + Duration::days(30),
|
||
|
|
)
|
||
|
|
.await
|
||
|
|
.map_err(|e| err(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string(), "DB_ERROR"))?;
|
||
|
|
|
||
|
|
// Set refresh token as httpOnly cookie
|
||
|
|
let cookie = format!(
|
||
|
|
"nxtgauge_refresh_token={}; HttpOnly; Secure; SameSite=Strict; Path=/; Max-Age=2592000",
|
||
|
|
tokens.refresh_token
|
||
|
|
);
|
||
|
|
|
||
|
|
let response = Json(LoginResponse {
|
||
|
|
access_token: tokens.access_token,
|
||
|
|
token_type: "Bearer".to_string(),
|
||
|
|
expires_in: 900,
|
||
|
|
user: SessionUser {
|
||
|
|
id: user.id.to_string(),
|
||
|
|
email: user.email,
|
||
|
|
full_name: user.full_name.unwrap_or_default(),
|
||
|
|
email_verified: user.email_verified,
|
||
|
|
roles: user_roles,
|
||
|
|
},
|
||
|
|
});
|
||
|
|
|
||
|
|
Ok((
|
||
|
|
StatusCode::OK,
|
||
|
|
[(SET_COOKIE, cookie)],
|
||
|
|
response,
|
||
|
|
))
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn logout(
|
||
|
|
State(state): State<AppState>,
|
||
|
|
// In real implementation: extract refresh token from cookie header
|
||
|
|
) -> impl IntoResponse {
|
||
|
|
// TODO: Revoke refresh token from cookie
|
||
|
|
let _ = &state.pool;
|
||
|
|
(StatusCode::OK, Json(serde_json::json!({ "message": "Logged out successfully" })))
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn refresh(
|
||
|
|
State(state): State<AppState>,
|
||
|
|
// In real impl: read httpOnly cookie, not body
|
||
|
|
Json(payload): Json<serde_json::Value>,
|
||
|
|
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
||
|
|
let token = payload["refresh_token"]
|
||
|
|
.as_str()
|
||
|
|
.ok_or_else(|| err(StatusCode::UNAUTHORIZED, "Refresh token missing", "REFRESH_TOKEN_INVALID"))?;
|
||
|
|
|
||
|
|
let rt = UserRepository::get_valid_refresh_token(&state.pool, token)
|
||
|
|
.await
|
||
|
|
.map_err(|_| err(StatusCode::UNAUTHORIZED, "Refresh token invalid", "REFRESH_TOKEN_INVALID"))?;
|
||
|
|
|
||
|
|
let user = UserRepository::get_by_id(&state.pool, rt.user_id)
|
||
|
|
.await
|
||
|
|
.map_err(|_| err(StatusCode::UNAUTHORIZED, "User not found", "INVALID_CREDENTIALS"))?;
|
||
|
|
|
||
|
|
let _ = UserRepository::revoke_refresh_token(&state.pool, token).await;
|
||
|
|
|
||
|
|
let user_roles = UserRepository::get_user_role_keys(&state.pool, user.id)
|
||
|
|
.await
|
||
|
|
.unwrap_or_default();
|
||
|
|
|
||
|
|
let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "changeme".to_string());
|
||
|
|
let tokens = generate_tokens(user.id.to_string(), user_roles.first().cloned(), &jwt_secret)
|
||
|
|
.map_err(|e| err(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string(), "TOKEN_ERROR"))?;
|
||
|
|
|
||
|
|
UserRepository::store_refresh_token(
|
||
|
|
&state.pool,
|
||
|
|
user.id,
|
||
|
|
&tokens.refresh_token,
|
||
|
|
Utc::now() + Duration::days(30),
|
||
|
|
)
|
||
|
|
.await
|
||
|
|
.map_err(|e| err(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string(), "DB_ERROR"))?;
|
||
|
|
|
||
|
|
Ok((
|
||
|
|
StatusCode::OK,
|
||
|
|
Json(serde_json::json!({
|
||
|
|
"access_token": tokens.access_token,
|
||
|
|
"expires_in": 900
|
||
|
|
})),
|
||
|
|
))
|
||
|
|
}
|
||
|
|
|
||
|
|
|
||
|
|
async fn session(
|
||
|
|
auth: AuthUser,
|
||
|
|
State(state): State<AppState>,
|
||
|
|
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
||
|
|
let user = UserRepository::get_by_id(&state.pool, auth.user_id)
|
||
|
|
.await
|
||
|
|
.map_err(|_| err(StatusCode::UNAUTHORIZED, "User not found", "USER_NOT_FOUND"))?;
|
||
|
|
|
||
|
|
let user_roles = UserRepository::get_user_role_keys(&state.pool, user.id)
|
||
|
|
.await
|
||
|
|
.unwrap_or_default();
|
||
|
|
|
||
|
|
Ok(Json(SessionUser {
|
||
|
|
id: user.id.to_string(),
|
||
|
|
email: user.email,
|
||
|
|
full_name: user.full_name.unwrap_or_default(),
|
||
|
|
email_verified: user.email_verified,
|
||
|
|
roles: user_roles,
|
||
|
|
}))
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn verify_email(
|
||
|
|
State(state): State<AppState>,
|
||
|
|
Json(payload): Json<VerifyEmailPayload>,
|
||
|
|
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
||
|
|
let user = UserRepository::get_by_verification_token(&state.pool, &payload.otp)
|
||
|
|
.await
|
||
|
|
.map_err(|_| err(StatusCode::UNAUTHORIZED, "Invalid verification code", "INVALID_CODE"))?;
|
||
|
|
|
||
|
|
if let Some(expires_at) = user.email_verification_expires_at {
|
||
|
|
if expires_at < Utc::now() {
|
||
|
|
return Err(err(StatusCode::UNAUTHORIZED, "Verification code expired", "CODE_EXPIRED"));
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
UserRepository::set_email_verified(&state.pool, user.id)
|
||
|
|
.await
|
||
|
|
.map_err(|e| err(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string(), "DB_ERROR"))?;
|
||
|
|
|
||
|
|
Ok((StatusCode::OK, Json(serde_json::json!({ "message": "Email verified successfully" }))))
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Deserialize)]
|
||
|
|
pub struct ResendOtpPayload {
|
||
|
|
pub email: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn resend_otp(
|
||
|
|
State(state): State<AppState>,
|
||
|
|
Json(payload): Json<ResendOtpPayload>,
|
||
|
|
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
||
|
|
let user = UserRepository::get_by_email(&state.pool, &payload.email)
|
||
|
|
.await
|
||
|
|
.map_err(|_| (StatusCode::OK, Json(serde_json::json!({ "message": "If email exists, a new OTP has been sent" }))))?;
|
||
|
|
|
||
|
|
let otp = format!("{:06}", rand::random::<u32>() % 1000000);
|
||
|
|
let expires_at = Utc::now() + Duration::minutes(15);
|
||
|
|
|
||
|
|
UserRepository::set_email_verification_token(&state.pool, user.id, &otp, expires_at)
|
||
|
|
.await
|
||
|
|
.map_err(|e| err(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string(), "DB_ERROR"))?;
|
||
|
|
|
||
|
|
let _ = state.mail.send_verification_email(&user.email, &user.full_name.unwrap_or_default(), &otp).await;
|
||
|
|
|
||
|
|
Ok((StatusCode::OK, Json(serde_json::json!({ "message": "If email exists, a new OTP has been sent" }))))
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn forgot_password(
|
||
|
|
State(state): State<AppState>,
|
||
|
|
Json(payload): Json<ForgotPasswordPayload>,
|
||
|
|
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
||
|
|
let user = UserRepository::get_by_email(&state.pool, &payload.email)
|
||
|
|
.await
|
||
|
|
.map_err(|_| (StatusCode::OK, Json(serde_json::json!({ "message": "Reset link sent if email exists" }))))?;
|
||
|
|
|
||
|
|
let token: String = uuid::Uuid::new_v4().to_string();
|
||
|
|
let expires_at = Utc::now() + Duration::hours(1);
|
||
|
|
|
||
|
|
UserRepository::set_reset_token(&state.pool, user.id, &token, expires_at)
|
||
|
|
.await
|
||
|
|
.map_err(|e| err(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string(), "DB_ERROR"))?;
|
||
|
|
|
||
|
|
let _ = state.mail.send_password_reset_email(&user.email, &user.full_name.unwrap_or_default(), &token).await;
|
||
|
|
|
||
|
|
Ok((StatusCode::OK, Json(serde_json::json!({ "message": "Reset link sent if email exists" }))))
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn reset_password(
|
||
|
|
State(state): State<AppState>,
|
||
|
|
Json(payload): Json<ResetPasswordPayload>,
|
||
|
|
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
||
|
|
let user = UserRepository::get_by_reset_token(&state.pool, &payload.token)
|
||
|
|
.await
|
||
|
|
.map_err(|_| err(StatusCode::UNAUTHORIZED, "Invalid or expired reset token", "INVALID_TOKEN"))?;
|
||
|
|
|
||
|
|
if let Some(expires_at) = user.reset_password_expires_at {
|
||
|
|
if expires_at < Utc::now() {
|
||
|
|
return Err(err(StatusCode::UNAUTHORIZED, "Reset token expired", "TOKEN_EXPIRED"));
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if payload.new_password.len() < 8 {
|
||
|
|
return Err(err(StatusCode::UNPROCESSABLE_ENTITY, "Password minimum 8 characters", "VALIDATION_ERROR"));
|
||
|
|
}
|
||
|
|
|
||
|
|
let password_hash = hash_password(&payload.new_password).map_err(|e| {
|
||
|
|
err(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string(), "INTERNAL_ERROR")
|
||
|
|
})?;
|
||
|
|
|
||
|
|
UserRepository::update_password(&state.pool, user.id, &password_hash)
|
||
|
|
.await
|
||
|
|
.map_err(|e| err(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string(), "DB_ERROR"))?;
|
||
|
|
|
||
|
|
UserRepository::clear_reset_token(&state.pool, user.id)
|
||
|
|
.await
|
||
|
|
.map_err(|e| err(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string(), "DB_ERROR"))?;
|
||
|
|
|
||
|
|
Ok((StatusCode::OK, Json(serde_json::json!({ "message": "Password reset successfully" }))))
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn change_password(
|
||
|
|
auth: AuthUser,
|
||
|
|
State(state): State<AppState>,
|
||
|
|
Json(payload): Json<ChangePasswordPayload>,
|
||
|
|
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
||
|
|
let user = UserRepository::get_by_id(&state.pool, auth.user_id)
|
||
|
|
.await
|
||
|
|
.map_err(|_| err(StatusCode::UNAUTHORIZED, "User not found", "USER_NOT_FOUND"))?;
|
||
|
|
|
||
|
|
if !verify_password(&payload.current_password, &user.password_hash).map_err(|e| err(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string(), "AUTH_ERROR"))? {
|
||
|
|
return Err(err(StatusCode::UNAUTHORIZED, "Incorrect current password", "INVALID_PASSWORD"));
|
||
|
|
}
|
||
|
|
|
||
|
|
if payload.new_password.len() < 8 {
|
||
|
|
return Err(err(StatusCode::UNPROCESSABLE_ENTITY, "Password minimum 8 characters", "VALIDATION_ERROR"));
|
||
|
|
}
|
||
|
|
|
||
|
|
let password_hash = hash_password(&payload.new_password).map_err(|e| {
|
||
|
|
err(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string(), "INTERNAL_ERROR")
|
||
|
|
})?;
|
||
|
|
|
||
|
|
UserRepository::update_password(&state.pool, user.id, &password_hash)
|
||
|
|
.await
|
||
|
|
.map_err(|e| err(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string(), "DB_ERROR"))?;
|
||
|
|
|
||
|
|
Ok((StatusCode::OK, Json(serde_json::json!({ "message": "Password changed successfully" }))))
|
||
|
|
}
|