From f76c85b2535989548e65652af31a1005ec54423e Mon Sep 17 00:00:00 2001 From: Bill Date: Mon, 3 Nov 2025 23:09:02 -0500 Subject: [PATCH] feat: add database helper methods for trading_days schema Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- api/database.py | 224 ++++++++++++++++++++++ tests/unit/test_database_helpers.py | 288 ++++++++++++++++++++++++++++ 2 files changed, 512 insertions(+) create mode 100644 tests/unit/test_database_helpers.py diff --git a/api/database.py b/api/database.py index 8973bf3..256f7c1 100644 --- a/api/database.py +++ b/api/database.py @@ -540,3 +540,227 @@ def get_database_stats(db_path: str = "data/jobs.db") -> dict: conn.close() return stats + + +class Database: + """Database wrapper class with helper methods for trading_days schema.""" + + def __init__(self, db_path: str = None): + """Initialize database connection. + + Args: + db_path: Path to SQLite database file. + If None, uses default from deployment config. + """ + if db_path is None: + from tools.deployment_config import get_database_path + db_path = get_database_path() + + self.db_path = db_path + self.connection = sqlite3.connect(db_path, check_same_thread=False) + self.connection.row_factory = sqlite3.Row + + def create_trading_day( + self, + job_id: str, + model: str, + date: str, + starting_cash: float, + starting_portfolio_value: float, + daily_profit: float, + daily_return_pct: float, + ending_cash: float, + ending_portfolio_value: float, + reasoning_summary: str = None, + reasoning_full: str = None, + total_actions: int = 0, + session_duration_seconds: float = None, + days_since_last_trading: int = 1 + ) -> int: + """Create a new trading day record. + + Returns: + trading_day_id + """ + cursor = self.connection.execute( + """ + INSERT INTO trading_days ( + job_id, model, date, + starting_cash, starting_portfolio_value, + daily_profit, daily_return_pct, + ending_cash, ending_portfolio_value, + reasoning_summary, reasoning_full, + total_actions, session_duration_seconds, + days_since_last_trading, + completed_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP) + """, + ( + job_id, model, date, + starting_cash, starting_portfolio_value, + daily_profit, daily_return_pct, + ending_cash, ending_portfolio_value, + reasoning_summary, reasoning_full, + total_actions, session_duration_seconds, + days_since_last_trading + ) + ) + self.connection.commit() + return cursor.lastrowid + + def get_previous_trading_day( + self, + job_id: str, + model: str, + current_date: str + ) -> dict: + """Get the most recent trading day before current_date. + + Handles weekends/holidays by finding actual previous trading day. + + Returns: + dict with keys: id, date, ending_cash, ending_portfolio_value + or None if no previous day exists + """ + cursor = self.connection.execute( + """ + SELECT id, date, ending_cash, ending_portfolio_value + FROM trading_days + WHERE job_id = ? AND model = ? AND date < ? + ORDER BY date DESC + LIMIT 1 + """, + (job_id, model, current_date) + ) + + row = cursor.fetchone() + if row: + return { + "id": row[0], + "date": row[1], + "ending_cash": row[2], + "ending_portfolio_value": row[3] + } + return None + + def get_ending_holdings(self, trading_day_id: int) -> list: + """Get ending holdings for a trading day. + + Returns: + List of dicts with keys: symbol, quantity + """ + cursor = self.connection.execute( + """ + SELECT symbol, quantity + FROM holdings + WHERE trading_day_id = ? + ORDER BY symbol + """, + (trading_day_id,) + ) + + return [{"symbol": row[0], "quantity": row[1]} for row in cursor.fetchall()] + + def get_starting_holdings(self, trading_day_id: int) -> list: + """Get starting holdings from previous day's ending holdings. + + Returns: + List of dicts with keys: symbol, quantity + Empty list if first trading day + """ + # Get previous trading day + cursor = self.connection.execute( + """ + SELECT td_prev.id + FROM trading_days td_current + JOIN trading_days td_prev ON + td_prev.job_id = td_current.job_id AND + td_prev.model = td_current.model AND + td_prev.date < td_current.date + WHERE td_current.id = ? + ORDER BY td_prev.date DESC + LIMIT 1 + """, + (trading_day_id,) + ) + + row = cursor.fetchone() + if not row: + # First trading day - no previous holdings + return [] + + previous_day_id = row[0] + + # Get previous day's ending holdings + return self.get_ending_holdings(previous_day_id) + + def create_holding( + self, + trading_day_id: int, + symbol: str, + quantity: int + ) -> int: + """Create a holding record. + + Returns: + holding_id + """ + cursor = self.connection.execute( + """ + INSERT INTO holdings (trading_day_id, symbol, quantity) + VALUES (?, ?, ?) + """, + (trading_day_id, symbol, quantity) + ) + self.connection.commit() + return cursor.lastrowid + + def create_action( + self, + trading_day_id: int, + action_type: str, + symbol: str = None, + quantity: int = None, + price: float = None + ) -> int: + """Create an action record. + + Returns: + action_id + """ + cursor = self.connection.execute( + """ + INSERT INTO actions (trading_day_id, action_type, symbol, quantity, price) + VALUES (?, ?, ?, ?, ?) + """, + (trading_day_id, action_type, symbol, quantity, price) + ) + self.connection.commit() + return cursor.lastrowid + + def get_actions(self, trading_day_id: int) -> list: + """Get all actions for a trading day. + + Returns: + List of dicts with keys: action_type, symbol, quantity, price, created_at + """ + cursor = self.connection.execute( + """ + SELECT action_type, symbol, quantity, price, created_at + FROM actions + WHERE trading_day_id = ? + ORDER BY created_at + """, + (trading_day_id,) + ) + + return [ + { + "action_type": row[0], + "symbol": row[1], + "quantity": row[2], + "price": row[3], + "created_at": row[4] + } + for row in cursor.fetchall() + ] diff --git a/tests/unit/test_database_helpers.py b/tests/unit/test_database_helpers.py new file mode 100644 index 0000000..17a5b11 --- /dev/null +++ b/tests/unit/test_database_helpers.py @@ -0,0 +1,288 @@ +import pytest +from datetime import datetime +from api.database import Database + + +class TestDatabaseHelpers: + + @pytest.fixture + def db(self, tmp_path): + """Create test database with schema.""" + import importlib + migration_module = importlib.import_module('api.migrations.001_trading_days_schema') + create_trading_days_schema = migration_module.create_trading_days_schema + + db_path = tmp_path / "test.db" + db = Database(str(db_path)) + + # Create jobs table (prerequisite) + db.connection.execute(""" + CREATE TABLE IF NOT EXISTS jobs ( + job_id TEXT PRIMARY KEY, + status TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + create_trading_days_schema(db) + return db + + def test_create_trading_day(self, db): + """Test creating a new trading day record.""" + # Insert job first + db.connection.execute( + "INSERT INTO jobs (job_id, status) VALUES (?, ?)", + ("test-job", "running") + ) + + trading_day_id = db.create_trading_day( + job_id="test-job", + model="gpt-4", + date="2025-01-15", + starting_cash=10000.0, + starting_portfolio_value=10000.0, + daily_profit=0.0, + daily_return_pct=0.0, + ending_cash=9500.0, + ending_portfolio_value=9500.0 + ) + + assert trading_day_id is not None + + # Verify record created + cursor = db.connection.execute( + "SELECT * FROM trading_days WHERE id = ?", + (trading_day_id,) + ) + row = cursor.fetchone() + assert row is not None + + def test_get_previous_trading_day(self, db): + """Test retrieving previous trading day.""" + # Setup: Create job and two trading days + db.connection.execute( + "INSERT INTO jobs (job_id, status) VALUES (?, ?)", + ("test-job", "running") + ) + + day1_id = db.create_trading_day( + job_id="test-job", + model="gpt-4", + date="2025-01-15", + starting_cash=10000.0, + starting_portfolio_value=10000.0, + daily_profit=0.0, + daily_return_pct=0.0, + ending_cash=9500.0, + ending_portfolio_value=9500.0 + ) + + day2_id = db.create_trading_day( + job_id="test-job", + model="gpt-4", + date="2025-01-16", + starting_cash=9500.0, + starting_portfolio_value=9500.0, + daily_profit=-500.0, + daily_return_pct=-5.0, + ending_cash=9700.0, + ending_portfolio_value=9700.0 + ) + + # Test: Get previous day from day2 + previous = db.get_previous_trading_day( + job_id="test-job", + model="gpt-4", + current_date="2025-01-16" + ) + + assert previous is not None + assert previous["date"] == "2025-01-15" + assert previous["ending_cash"] == 9500.0 + + def test_get_previous_trading_day_with_weekend_gap(self, db): + """Test retrieving previous trading day across weekend.""" + db.connection.execute( + "INSERT INTO jobs (job_id, status) VALUES (?, ?)", + ("test-job", "running") + ) + + # Friday + db.create_trading_day( + job_id="test-job", + model="gpt-4", + date="2025-01-17", # Friday + starting_cash=10000.0, + starting_portfolio_value=10000.0, + daily_profit=0.0, + daily_return_pct=0.0, + ending_cash=9500.0, + ending_portfolio_value=9500.0 + ) + + # Test: Get previous from Monday (should find Friday) + previous = db.get_previous_trading_day( + job_id="test-job", + model="gpt-4", + current_date="2025-01-20" # Monday + ) + + assert previous is not None + assert previous["date"] == "2025-01-17" + + def test_get_ending_holdings(self, db): + """Test retrieving ending holdings for a trading day.""" + db.connection.execute( + "INSERT INTO jobs (job_id, status) VALUES (?, ?)", + ("test-job", "running") + ) + + trading_day_id = db.create_trading_day( + job_id="test-job", + model="gpt-4", + date="2025-01-15", + starting_cash=10000.0, + starting_portfolio_value=10000.0, + daily_profit=0.0, + daily_return_pct=0.0, + ending_cash=9000.0, + ending_portfolio_value=10000.0 + ) + + # Add holdings + db.create_holding(trading_day_id, "AAPL", 10) + db.create_holding(trading_day_id, "MSFT", 5) + + # Test + holdings = db.get_ending_holdings(trading_day_id) + + assert len(holdings) == 2 + assert {"symbol": "AAPL", "quantity": 10} in holdings + assert {"symbol": "MSFT", "quantity": 5} in holdings + + def test_get_starting_holdings_first_day(self, db): + """Test starting holdings for first trading day (should be empty).""" + db.connection.execute( + "INSERT INTO jobs (job_id, status) VALUES (?, ?)", + ("test-job", "running") + ) + + trading_day_id = db.create_trading_day( + job_id="test-job", + model="gpt-4", + date="2025-01-15", + starting_cash=10000.0, + starting_portfolio_value=10000.0, + daily_profit=0.0, + daily_return_pct=0.0, + ending_cash=9500.0, + ending_portfolio_value=9500.0 + ) + + holdings = db.get_starting_holdings(trading_day_id) + + assert holdings == [] + + def test_get_starting_holdings_from_previous_day(self, db): + """Test starting holdings derived from previous day's ending.""" + db.connection.execute( + "INSERT INTO jobs (job_id, status) VALUES (?, ?)", + ("test-job", "running") + ) + + # Day 1 + day1_id = db.create_trading_day( + job_id="test-job", + model="gpt-4", + date="2025-01-15", + starting_cash=10000.0, + starting_portfolio_value=10000.0, + daily_profit=0.0, + daily_return_pct=0.0, + ending_cash=9000.0, + ending_portfolio_value=10000.0 + ) + db.create_holding(day1_id, "AAPL", 10) + + # Day 2 + day2_id = db.create_trading_day( + job_id="test-job", + model="gpt-4", + date="2025-01-16", + starting_cash=9000.0, + starting_portfolio_value=10000.0, + daily_profit=0.0, + daily_return_pct=0.0, + ending_cash=8500.0, + ending_portfolio_value=9500.0 + ) + + # Test: Day 2 starting = Day 1 ending + holdings = db.get_starting_holdings(day2_id) + + assert len(holdings) == 1 + assert holdings[0]["symbol"] == "AAPL" + assert holdings[0]["quantity"] == 10 + + def test_create_action(self, db): + """Test creating an action record.""" + db.connection.execute( + "INSERT INTO jobs (job_id, status) VALUES (?, ?)", + ("test-job", "running") + ) + + trading_day_id = db.create_trading_day( + job_id="test-job", + model="gpt-4", + date="2025-01-15", + starting_cash=10000.0, + starting_portfolio_value=10000.0, + daily_profit=0.0, + daily_return_pct=0.0, + ending_cash=9500.0, + ending_portfolio_value=9500.0 + ) + + action_id = db.create_action( + trading_day_id=trading_day_id, + action_type="buy", + symbol="AAPL", + quantity=10, + price=100.0 + ) + + assert action_id is not None + + # Verify + cursor = db.connection.execute( + "SELECT * FROM actions WHERE id = ?", + (action_id,) + ) + row = cursor.fetchone() + assert row is not None + + def test_get_actions(self, db): + """Test retrieving all actions for a trading day.""" + db.connection.execute( + "INSERT INTO jobs (job_id, status) VALUES (?, ?)", + ("test-job", "running") + ) + + trading_day_id = db.create_trading_day( + job_id="test-job", + model="gpt-4", + date="2025-01-15", + starting_cash=10000.0, + starting_portfolio_value=10000.0, + daily_profit=0.0, + daily_return_pct=0.0, + ending_cash=9500.0, + ending_portfolio_value=9500.0 + ) + + db.create_action(trading_day_id, "buy", "AAPL", 10, 100.0) + db.create_action(trading_day_id, "sell", "MSFT", 5, 50.0) + + actions = db.get_actions(trading_day_id) + + assert len(actions) == 2