diff --git a/crates/db-migrate/src/main.rs b/crates/db-migrate/src/main.rs index 343171b..a250a20 100644 --- a/crates/db-migrate/src/main.rs +++ b/crates/db-migrate/src/main.rs @@ -21,6 +21,14 @@ async fn main() -> Result<()> { .context("Failed to connect to database")?; tracing::info!("Connected to database"); + let drop_existing = std::env::var("DROP_EXISTING_TABLES") + .unwrap_or_default() + .to_lowercase(); + + if drop_existing == "true" || drop_existing == "1" || drop_existing == "yes" { + drop_all_tables(&pool).await?; + } + let migrations_dir = std::env::var("MIGRATIONS_DIR") .unwrap_or_else(|_| "/migrations".to_string()); @@ -30,6 +38,30 @@ async fn main() -> Result<()> { Ok(()) } +async fn drop_all_tables(pool: &sqlx::PgPool) -> Result<()> { + tracing::warn!("DROP_EXISTING_TABLES is enabled - dropping all tables!"); + + let rows: Vec<(String,)> = sqlx::query_as("SELECT tablename FROM pg_tables WHERE schemaname = 'public'") + .fetch_all(pool) + .await?; + + if rows.is_empty() { + tracing::info!("No tables to drop"); + return Ok(()); + } + + tracing::info!("Found {} tables to drop", rows.len()); + + for (table_name,) in rows { + let sql = format!("DROP TABLE IF EXISTS \"{}\" CASCADE", table_name); + tracing::info!("Dropping table: {}", table_name); + sqlx::raw_sql(&sql).execute(pool).await?; + } + + tracing::info!("All tables dropped successfully"); + Ok(()) +} + async fn run_migrations(pool: &sqlx::PgPool, migrations_dir: &str) -> Result<()> { let migrations_path = Path::new(migrations_dir);