diff --git a/api/database.py b/api/database.py index 7235f63..d6932a7 100644 --- a/api/database.py +++ b/api/database.py @@ -9,14 +9,16 @@ This module provides: import sqlite3 from pathlib import Path -from typing import Optional import os +from tools.deployment_config import get_db_path def get_db_connection(db_path: str = "data/jobs.db") -> sqlite3.Connection: """ Get SQLite database connection with proper configuration. + Automatically resolves to dev database if DEPLOYMENT_MODE=DEV. + Args: db_path: Path to SQLite database file @@ -28,17 +30,35 @@ def get_db_connection(db_path: str = "data/jobs.db") -> sqlite3.Connection: - Row factory for dict-like access - Check same thread disabled for FastAPI async compatibility """ + # Resolve path based on deployment mode + resolved_path = get_db_path(db_path) + # Ensure data directory exists - db_path_obj = Path(db_path) + db_path_obj = Path(resolved_path) db_path_obj.parent.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(db_path, check_same_thread=False) + conn = sqlite3.connect(resolved_path, check_same_thread=False) conn.execute("PRAGMA foreign_keys = ON") conn.row_factory = sqlite3.Row return conn +def resolve_db_path(db_path: str) -> str: + """ + Resolve database path based on deployment mode + + Convenience function for testing. + + Args: + db_path: Base database path + + Returns: + Resolved path (dev or prod) + """ + return get_db_path(db_path) + + def initialize_database(db_path: str = "data/jobs.db") -> None: """ Create all database tables with enhanced schema. diff --git a/tests/unit/test_dev_database.py b/tests/unit/test_dev_database.py index 8447d7a..56a9de3 100644 --- a/tests/unit/test_dev_database.py +++ b/tests/unit/test_dev_database.py @@ -82,3 +82,20 @@ def test_initialize_dev_respects_preserve_flag(tmp_path): conn.close() os.environ.pop("PRESERVE_DEV_DATA") + + +def test_get_db_connection_resolves_dev_path(): + """Test that get_db_connection uses dev path in DEV mode""" + import os + os.environ["DEPLOYMENT_MODE"] = "DEV" + + # This should automatically resolve to dev database + # We're just testing the path logic, not actually creating DB + from api.database import resolve_db_path + + prod_path = "data/trading.db" + dev_path = resolve_db_path(prod_path) + + assert dev_path == "data/trading_dev.db" + + os.environ["DEPLOYMENT_MODE"] = "PROD"