From 8fb2ead8fffea71b9d5f3cca2228e54608b697e4 Mon Sep 17 00:00:00 2001 From: Bill Date: Sat, 1 Nov 2025 11:20:15 -0400 Subject: [PATCH] feat: add dev database initialization and cleanup functions --- api/database.py | 51 ++++++++++++++++++++ tests/unit/test_dev_database.py | 84 +++++++++++++++++++++++++++++++++ 2 files changed, 135 insertions(+) create mode 100644 tests/unit/test_dev_database.py diff --git a/api/database.py b/api/database.py index d1fe94c..7235f63 100644 --- a/api/database.py +++ b/api/database.py @@ -213,6 +213,57 @@ def initialize_database(db_path: str = "data/jobs.db") -> None: conn.close() +def initialize_dev_database(db_path: str = "data/trading_dev.db") -> None: + """ + Initialize dev database with clean schema + + Deletes and recreates dev database unless PRESERVE_DEV_DATA=true. + Used at startup in DEV mode to ensure clean testing environment. + + Args: + db_path: Path to dev database file + """ + from tools.deployment_config import should_preserve_dev_data + + if should_preserve_dev_data(): + print(f"â„šī¸ PRESERVE_DEV_DATA=true, keeping existing dev database: {db_path}") + # Ensure schema exists even if preserving data + if not Path(db_path).exists(): + print(f"📁 Dev database doesn't exist, creating: {db_path}") + initialize_database(db_path) + return + + # Delete existing dev database + if Path(db_path).exists(): + print(f"đŸ—‘ī¸ Removing existing dev database: {db_path}") + Path(db_path).unlink() + + # Create fresh dev database + print(f"📁 Creating fresh dev database: {db_path}") + initialize_database(db_path) + + +def cleanup_dev_database(db_path: str = "data/trading_dev.db", data_path: str = "./data/dev_agent_data") -> None: + """ + Cleanup dev database and data files + + Args: + db_path: Path to dev database file + data_path: Path to dev data directory + """ + import shutil + + # Remove dev database + if Path(db_path).exists(): + print(f"đŸ—‘ī¸ Removing dev database: {db_path}") + Path(db_path).unlink() + + # Remove dev data directory + if Path(data_path).exists(): + print(f"đŸ—‘ī¸ Removing dev data directory: {data_path}") + shutil.rmtree(data_path) + + def _migrate_schema(cursor: sqlite3.Cursor) -> None: """Migrate existing database schema to latest version.""" # Check if positions table exists and has simulation_run_id column diff --git a/tests/unit/test_dev_database.py b/tests/unit/test_dev_database.py new file mode 100644 index 0000000..8447d7a --- /dev/null +++ b/tests/unit/test_dev_database.py @@ -0,0 +1,84 @@ +import os +import pytest +from pathlib import Path +from api.database import initialize_dev_database, cleanup_dev_database + + +def test_initialize_dev_database_creates_fresh_db(tmp_path): + """Test dev database initialization creates clean schema""" + db_path = str(tmp_path / "test_dev.db") + + # Create initial database with some data + from api.database import get_db_connection, initialize_database + initialize_database(db_path) + conn = get_db_connection(db_path) + conn.execute("INSERT INTO jobs (job_id, config_path, status, date_range, models, created_at) VALUES (?, ?, ?, ?, ?, ?)", + ("test-job", "config.json", "completed", "2025-01-01:2025-01-31", '["model1"]', "2025-01-01T00:00:00")) + conn.commit() + conn.close() + + # Verify data exists + conn = get_db_connection(db_path) + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) FROM jobs") + assert cursor.fetchone()[0] == 1 + conn.close() + + # Initialize dev database (should reset) + initialize_dev_database(db_path) + + # Verify data is cleared + conn = get_db_connection(db_path) + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) FROM jobs") + assert cursor.fetchone()[0] == 0 + conn.close() + + +def test_cleanup_dev_database_removes_files(tmp_path): + """Test dev cleanup removes database and data files""" + # Setup dev files + db_path = str(tmp_path / "test_dev.db") + data_path = str(tmp_path / "dev_agent_data") + + Path(db_path).touch() + Path(data_path).mkdir(parents=True, exist_ok=True) + (Path(data_path) / "test_file.jsonl").touch() + + # Verify files exist + assert Path(db_path).exists() + assert Path(data_path).exists() + + # Cleanup + cleanup_dev_database(db_path, data_path) + + # Verify files removed + assert not Path(db_path).exists() + assert not Path(data_path).exists() + + +def test_initialize_dev_respects_preserve_flag(tmp_path): + """Test that PRESERVE_DEV_DATA flag prevents cleanup""" + os.environ["PRESERVE_DEV_DATA"] = "true" + db_path = str(tmp_path / "test_dev.db") + + # Create database with data + from api.database import get_db_connection, initialize_database + initialize_database(db_path) + conn = get_db_connection(db_path) + conn.execute("INSERT INTO jobs (job_id, config_path, status, date_range, models, created_at) VALUES (?, ?, ?, ?, ?, ?)", + ("test-job", "config.json", "completed", "2025-01-01:2025-01-31", '["model1"]', "2025-01-01T00:00:00")) + conn.commit() + conn.close() + + # Initialize with preserve flag + initialize_dev_database(db_path) + + # Verify data is preserved + conn = get_db_connection(db_path) + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) FROM jobs") + assert cursor.fetchone()[0] == 1 + conn.close() + + os.environ.pop("PRESERVE_DEV_DATA")