From 8aedb058e24ff86bf599b4096ec3ab3123f6ba98 Mon Sep 17 00:00:00 2001 From: Bill Date: Tue, 4 Nov 2025 09:47:02 -0500 Subject: [PATCH] refactor: update get_current_position_from_db to query new schema --- agent_tools/tool_trade.py | 76 +++++++++-------- tests/unit/test_get_position_new_schema.py | 97 ++++++++++++++++++++++ 2 files changed, 134 insertions(+), 39 deletions(-) create mode 100644 tests/unit/test_get_position_new_schema.py diff --git a/agent_tools/tool_trade.py b/agent_tools/tool_trade.py index 7266a9f..6fa9429 100644 --- a/agent_tools/tool_trade.py +++ b/agent_tools/tool_trade.py @@ -20,71 +20,69 @@ from datetime import datetime, timezone mcp = FastMCP("TradeTools") -def get_current_position_from_db(job_id: str, model: str, date: str) -> Tuple[Dict[str, float], int]: +def get_current_position_from_db( + job_id: str, + model: str, + date: str, + initial_cash: float = 10000.0 +) -> Tuple[Dict[str, float], int]: """ - Query current position from SQLite database. + Get current position from database (new schema). + + Queries most recent trading_day record for this job+model up to date. + Returns ending holdings and cash from that day. Args: job_id: Job UUID model: Model signature - date: Trading date (YYYY-MM-DD) + date: Current trading date + initial_cash: Initial cash if no prior data Returns: - Tuple of (position_dict, next_action_id) - - position_dict: {symbol: quantity, "CASH": amount} - - next_action_id: Next available action_id for this job+model - - Raises: - Exception: If database query fails + (position_dict, action_count) where: + - position_dict: {"AAPL": 10, "MSFT": 5, "CASH": 8500.0} + - action_count: Number of holdings (for action_id tracking) """ - db_path = "data/jobs.db" + db_path = "data/trading.db" conn = get_db_connection(db_path) cursor = conn.cursor() try: - # Get most recent position on or before this date + # Query most recent trading_day up to date cursor.execute(""" - SELECT p.id, p.cash - FROM positions p - WHERE p.job_id = ? AND p.model = ? AND p.date <= ? - ORDER BY p.date DESC, p.action_id DESC + SELECT id, ending_cash + FROM trading_days + WHERE job_id = ? AND model = ? AND date <= ? + ORDER BY date DESC LIMIT 1 """, (job_id, model, date)) - position_row = cursor.fetchone() + row = cursor.fetchone() - if not position_row: - # No position found - this shouldn't happen if ModelDayExecutor initializes properly - raise Exception(f"No position found for job_id={job_id}, model={model}, date={date}") + if row is None: + # First day - return initial position + return {"CASH": initial_cash}, 0 - position_id = position_row[0] - cash = position_row[1] + trading_day_id, ending_cash = row - # Build position dict starting with CASH - position_dict = {"CASH": cash} - - # Get holdings for this position + # Query holdings for that day cursor.execute(""" SELECT symbol, quantity FROM holdings - WHERE position_id = ? - """, (position_id,)) + WHERE trading_day_id = ? + """, (trading_day_id,)) - for row in cursor.fetchall(): - symbol = row[0] - quantity = row[1] - position_dict[symbol] = quantity + holdings_rows = cursor.fetchall() - # Get next action_id - cursor.execute(""" - SELECT COALESCE(MAX(action_id), -1) + 1 as next_action_id - FROM positions - WHERE job_id = ? AND model = ? - """, (job_id, model)) + # Build position dict + position = {"CASH": ending_cash} + for symbol, quantity in holdings_rows: + position[symbol] = quantity - next_action_id = cursor.fetchone()[0] + # Action count is number of holdings (used for action_id) + action_count = len(holdings_rows) - return position_dict, next_action_id + return position, action_count finally: conn.close() diff --git a/tests/unit/test_get_position_new_schema.py b/tests/unit/test_get_position_new_schema.py new file mode 100644 index 0000000..6bfd3ab --- /dev/null +++ b/tests/unit/test_get_position_new_schema.py @@ -0,0 +1,97 @@ +"""Test get_current_position_from_db queries new schema.""" + +import pytest +from agent_tools.tool_trade import get_current_position_from_db +from api.database import Database + + +def test_get_position_from_new_schema(): + """Test position retrieval from trading_days + holdings.""" + + # Create test database + db = Database(":memory:") + + # Create prerequisite: jobs record + db.connection.execute(""" + INSERT INTO jobs (job_id, config_path, status, date_range, models, created_at) + VALUES ('test-job-123', 'test_config.json', 'running', '2025-01-15 to 2025-01-15', 'test-model', '2025-01-15T10:00:00Z') + """) + db.connection.commit() + + # Create trading_day with holdings + trading_day_id = db.create_trading_day( + job_id='test-job-123', + model='test-model', + date='2025-01-15', + starting_cash=10000.0, + starting_portfolio_value=10000.0, + daily_profit=0.0, + daily_return_pct=0.0, + ending_cash=8000.0, + ending_portfolio_value=9500.0, + days_since_last_trading=0 + ) + + # Add ending holdings + db.create_holding(trading_day_id, 'AAPL', 10) + db.create_holding(trading_day_id, 'MSFT', 5) + + db.connection.commit() + + # Mock get_db_connection to return our test db + import agent_tools.tool_trade as trade_module + original_get_db_connection = trade_module.get_db_connection + + def mock_get_db_connection(path): + return db.connection + + trade_module.get_db_connection = mock_get_db_connection + + try: + # Query position + position, action_id = get_current_position_from_db( + job_id='test-job-123', + model='test-model', + date='2025-01-15' + ) + + # Verify + assert position['AAPL'] == 10 + assert position['MSFT'] == 5 + assert position['CASH'] == 8000.0 + assert action_id == 2 # 2 holdings = 2 actions + finally: + # Restore original function + trade_module.get_db_connection = original_get_db_connection + db.connection.close() + + +def test_get_position_first_day(): + """Test position retrieval on first day (no prior data).""" + + db = Database(":memory:") + + # Mock get_db_connection to return our test db + import agent_tools.tool_trade as trade_module + original_get_db_connection = trade_module.get_db_connection + + def mock_get_db_connection(path): + return db.connection + + trade_module.get_db_connection = mock_get_db_connection + + try: + # Query position (no data exists) + position, action_id = get_current_position_from_db( + job_id='test-job-123', + model='test-model', + date='2025-01-15' + ) + + # Should return initial position + assert position['CASH'] == 10000.0 # Default initial cash + assert action_id == 0 + finally: + # Restore original function + trade_module.get_db_connection = original_get_db_connection + db.connection.close()