From 7d9d093d6cc903fd93901e1152d4ff03cd53c700 Mon Sep 17 00:00:00 2001 From: Bill Date: Tue, 4 Nov 2025 09:18:35 -0500 Subject: [PATCH] feat: migrate trade tools to write to actions table (new schema) This commit implements Task 1 from the schema migration plan: - Trade tools (buy/sell) now write to actions table instead of old positions table - Added trading_day_id parameter to buy/sell functions - Updated ContextInjector to inject trading_day_id - Updated RuntimeConfigManager to include TRADING_DAY_ID in config - Removed P&L calculation from trade functions (now done at trading_days level) - Added tests verifying correct behavior with new schema Changes: - agent_tools/tool_trade.py: Modified _buy_impl and _sell_impl to write to actions table - agent/context_injector.py: Added trading_day_id parameter and injection logic - api/model_day_executor.py: Updated to read trading_day_id from runtime config - api/runtime_manager.py: Added trading_day_id to config initialization - tests/unit/test_trade_tools_new_schema.py: New tests for new schema compliance All tests passing. --- agent/context_injector.py | 13 +- agent_tools/tool_trade.py | 154 +++++---------- api/model_day_executor.py | 6 +- api/runtime_manager.py | 7 +- tests/unit/test_trade_tools_new_schema.py | 216 ++++++++++++++++++++++ 5 files changed, 282 insertions(+), 114 deletions(-) create mode 100644 tests/unit/test_trade_tools_new_schema.py diff --git a/agent/context_injector.py b/agent/context_injector.py index 0873d2d..d9a9972 100644 --- a/agent/context_injector.py +++ b/agent/context_injector.py @@ -17,7 +17,8 @@ class ContextInjector: client = MultiServerMCPClient(config, tool_interceptors=[interceptor]) """ - def __init__(self, signature: str, today_date: str, job_id: str = None, session_id: int = None): + def __init__(self, signature: str, today_date: str, job_id: str = None, + session_id: int = None, trading_day_id: int = None): """ Initialize context injector. @@ -25,12 +26,14 @@ class ContextInjector: signature: Model signature to inject today_date: Trading date to inject job_id: Job UUID to inject (optional) - session_id: Trading session ID to inject (optional, updated during execution) + session_id: Trading session ID to inject (optional, DEPRECATED) + trading_day_id: Trading day ID to inject (optional) """ self.signature = signature self.today_date = today_date self.job_id = job_id - self.session_id = session_id + self.session_id = session_id # Deprecated but kept for compatibility + self.trading_day_id = trading_day_id async def __call__( self, @@ -50,7 +53,7 @@ class ContextInjector: # Inject context parameters for trade tools if request.name in ["buy", "sell"]: # Debug: Log self attributes BEFORE injection - print(f"[ContextInjector.__call__] ENTRY: id={id(self)}, self.signature={self.signature}, self.today_date={self.today_date}, self.job_id={self.job_id}, self.session_id={self.session_id}") + print(f"[ContextInjector.__call__] ENTRY: id={id(self)}, self.signature={self.signature}, self.today_date={self.today_date}, self.job_id={self.job_id}, self.session_id={self.session_id}, self.trading_day_id={self.trading_day_id}") print(f"[ContextInjector.__call__] Args BEFORE injection: {request.args}") # ALWAYS inject/override context parameters (don't trust AI-provided values) @@ -60,6 +63,8 @@ class ContextInjector: request.args["job_id"] = self.job_id if self.session_id: request.args["session_id"] = self.session_id + if self.trading_day_id: + request.args["trading_day_id"] = self.trading_day_id # Debug logging print(f"[ContextInjector] Tool: {request.name}, Args after injection: {request.args}") diff --git a/agent_tools/tool_trade.py b/agent_tools/tool_trade.py index 5404386..a634d0f 100644 --- a/agent_tools/tool_trade.py +++ b/agent_tools/tool_trade.py @@ -91,12 +91,21 @@ def get_current_position_from_db(job_id: str, model: str, date: str) -> Tuple[Di def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str = None, - job_id: str = None, session_id: int = None) -> Dict[str, Any]: + job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]: """ Internal buy implementation - accepts injected context parameters. + Args: + symbol: Stock symbol + amount: Number of shares + signature: Model signature (injected) + today_date: Trading date (injected) + job_id: Job ID (injected) + session_id: Session ID (injected, DEPRECATED) + trading_day_id: Trading day ID (injected) + This function is not exposed to the AI model. It receives runtime context - (signature, today_date, job_id, session_id) from the ContextInjector. + (signature, today_date, job_id, session_id, trading_day_id) from the ContextInjector. """ # Validate required parameters if not job_id: @@ -139,61 +148,29 @@ def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str = new_position["CASH"] = cash_left new_position[symbol] = new_position.get(symbol, 0) + amount - # Step 5: Calculate portfolio value and P&L - portfolio_value = cash_left - for sym, qty in new_position.items(): - if sym != "CASH": - try: - price = get_open_prices(today_date, [sym])[f'{sym}_price'] - portfolio_value += qty * price - except KeyError: - pass # Symbol price not available, skip + # Step 5: Write to actions table (NEW SCHEMA) + # NOTE: P&L is now calculated at the trading_days level, not per-trade + if trading_day_id is None: + # Get trading_day_id from runtime config if not provided + from tools.general_tools import get_config_value + trading_day_id = get_config_value('TRADING_DAY_ID') - # Get start-of-day portfolio value (action_id=0 for today) for P&L calculation - cursor.execute(""" - SELECT portfolio_value - FROM positions - WHERE job_id = ? AND model = ? AND date = ? AND action_id = 0 - LIMIT 1 - """, (job_id, signature, today_date)) + if trading_day_id is None: + raise ValueError("trading_day_id not found in runtime config") - row = cursor.fetchone() - - if row: - # Compare to start of day (action_id=0) - start_of_day_value = row[0] - daily_profit = portfolio_value - start_of_day_value - daily_return_pct = (daily_profit / start_of_day_value * 100) if start_of_day_value > 0 else 0 - else: - # First action of first day - no baseline yet - daily_profit = 0.0 - daily_return_pct = 0.0 - - # Step 6: Write to positions table created_at = datetime.utcnow().isoformat() + "Z" cursor.execute(""" - INSERT INTO positions ( - job_id, date, model, action_id, action_type, symbol, - amount, price, cash, portfolio_value, daily_profit, - daily_return_pct, session_id, created_at + INSERT INTO actions ( + trading_day_id, action_type, symbol, quantity, price, created_at ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, ?) """, ( - job_id, today_date, signature, next_action_id, "buy", symbol, - amount, this_symbol_price, cash_left, portfolio_value, daily_profit, - daily_return_pct, session_id, created_at + trading_day_id, "buy", symbol, amount, this_symbol_price, created_at )) - position_id = cursor.lastrowid - - # Step 7: Write to holdings table - for sym, qty in new_position.items(): - if sym != "CASH": - cursor.execute(""" - INSERT INTO holdings (position_id, symbol, quantity) - VALUES (?, ?, ?) - """, (position_id, sym, qty)) + # NOTE: Holdings are written by BaseAgent at end of day, not per-trade + # This keeps the data model clean (one holdings snapshot per day) conn.commit() print(f"[buy] {signature} bought {amount} shares of {symbol} at ${this_symbol_price}") @@ -209,7 +186,7 @@ def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str = @mcp.tool() def buy(symbol: str, amount: int, signature: str = None, today_date: str = None, - job_id: str = None, session_id: int = None) -> Dict[str, Any]: + job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]: """ Buy stock shares. @@ -222,15 +199,14 @@ def buy(symbol: str, amount: int, signature: str = None, today_date: str = None, - Success: {"CASH": remaining_cash, "SYMBOL": shares, ...} - Failure: {"error": error_message, ...} - Note: signature, today_date, job_id, session_id are automatically injected by the system. - Do not provide these parameters - they will be added automatically. + Note: signature, today_date, job_id, session_id, trading_day_id are + automatically injected by the system. Do not provide these parameters. """ - # Delegate to internal implementation - return _buy_impl(symbol, amount, signature, today_date, job_id, session_id) + return _buy_impl(symbol, amount, signature, today_date, job_id, session_id, trading_day_id) def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str = None, - job_id: str = None, session_id: int = None) -> Dict[str, Any]: + job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]: """ Sell stock function - writes to SQLite database. @@ -240,7 +216,8 @@ def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str signature: Model signature (injected by ContextInjector) today_date: Trading date YYYY-MM-DD (injected by ContextInjector) job_id: Job UUID (injected by ContextInjector) - session_id: Trading session ID (injected by ContextInjector) + session_id: Trading session ID (injected by ContextInjector, DEPRECATED) + trading_day_id: Trading day ID (injected by ContextInjector) Returns: Dict[str, Any]: @@ -287,62 +264,26 @@ def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str new_position[symbol] -= amount new_position["CASH"] = new_position.get("CASH", 0) + (this_symbol_price * amount) - # Step 5: Calculate portfolio value and P&L - portfolio_value = new_position["CASH"] - for sym, qty in new_position.items(): - if sym != "CASH": - try: - price = get_open_prices(today_date, [sym])[f'{sym}_price'] - portfolio_value += qty * price - except KeyError: - pass + # Step 5: Write to actions table (NEW SCHEMA) + # NOTE: P&L is now calculated at the trading_days level, not per-trade + if trading_day_id is None: + from tools.general_tools import get_config_value + trading_day_id = get_config_value('TRADING_DAY_ID') - # Get start-of-day portfolio value (action_id=0 for today) for P&L calculation - cursor.execute(""" - SELECT portfolio_value - FROM positions - WHERE job_id = ? AND model = ? AND date = ? AND action_id = 0 - LIMIT 1 - """, (job_id, signature, today_date)) + if trading_day_id is None: + raise ValueError("trading_day_id not found in runtime config") - row = cursor.fetchone() - - if row: - # Compare to start of day (action_id=0) - start_of_day_value = row[0] - daily_profit = portfolio_value - start_of_day_value - daily_return_pct = (daily_profit / start_of_day_value * 100) if start_of_day_value > 0 else 0 - else: - # First action of first day - no baseline yet - daily_profit = 0.0 - daily_return_pct = 0.0 - - # Step 6: Write to positions table created_at = datetime.utcnow().isoformat() + "Z" cursor.execute(""" - INSERT INTO positions ( - job_id, date, model, action_id, action_type, symbol, - amount, price, cash, portfolio_value, daily_profit, - daily_return_pct, session_id, created_at + INSERT INTO actions ( + trading_day_id, action_type, symbol, quantity, price, created_at ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, ?) """, ( - job_id, today_date, signature, next_action_id, "sell", symbol, - amount, this_symbol_price, new_position["CASH"], portfolio_value, daily_profit, - daily_return_pct, session_id, created_at + trading_day_id, "sell", symbol, amount, this_symbol_price, created_at )) - position_id = cursor.lastrowid - - # Step 7: Write to holdings table - for sym, qty in new_position.items(): - if sym != "CASH": - cursor.execute(""" - INSERT INTO holdings (position_id, symbol, quantity) - VALUES (?, ?, ?) - """, (position_id, sym, qty)) - conn.commit() print(f"[sell] {signature} sold {amount} shares of {symbol} at ${this_symbol_price}") return new_position @@ -357,7 +298,7 @@ def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str @mcp.tool() def sell(symbol: str, amount: int, signature: str = None, today_date: str = None, - job_id: str = None, session_id: int = None) -> Dict[str, Any]: + job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]: """ Sell stock shares. @@ -370,11 +311,10 @@ def sell(symbol: str, amount: int, signature: str = None, today_date: str = None - Success: {"CASH": remaining_cash, "SYMBOL": shares, ...} - Failure: {"error": error_message, ...} - Note: signature, today_date, job_id, session_id are automatically injected by the system. - Do not provide these parameters - they will be added automatically. + Note: signature, today_date, job_id, session_id, trading_day_id are + automatically injected by the system. Do not provide these parameters. """ - # Delegate to internal implementation - return _sell_impl(symbol, amount, signature, today_date, job_id, session_id) + return _sell_impl(symbol, amount, signature, today_date, job_id, session_id, trading_day_id) if __name__ == "__main__": diff --git a/api/model_day_executor.py b/api/model_day_executor.py index 8381f57..07bf888 100644 --- a/api/model_day_executor.py +++ b/api/model_day_executor.py @@ -138,11 +138,15 @@ class ModelDayExecutor: # Create and inject context with correct values from agent.context_injector import ContextInjector + from tools.general_tools import get_config_value + trading_day_id = get_config_value('TRADING_DAY_ID') # Get from runtime config + context_injector = ContextInjector( signature=self.model_sig, today_date=self.date, # Current trading day job_id=self.job_id, - session_id=session_id + session_id=session_id, + trading_day_id=trading_day_id ) logger.info(f"[DEBUG] ModelDayExecutor: Created ContextInjector with signature={self.model_sig}, date={self.date}, job_id={self.job_id}, session_id={session_id}") logger.info(f"[DEBUG] ModelDayExecutor: Calling await agent.set_context()") diff --git a/api/runtime_manager.py b/api/runtime_manager.py index d7880ac..eaa0608 100644 --- a/api/runtime_manager.py +++ b/api/runtime_manager.py @@ -48,7 +48,8 @@ class RuntimeConfigManager: self, job_id: str, model_sig: str, - date: str + date: str, + trading_day_id: int = None ) -> str: """ Create isolated runtime config file for this execution. @@ -57,6 +58,7 @@ class RuntimeConfigManager: job_id: Job UUID model_sig: Model signature date: Trading date (YYYY-MM-DD) + trading_day_id: Trading day record ID (optional, can be set later) Returns: Path to created runtime config file @@ -79,7 +81,8 @@ class RuntimeConfigManager: "TODAY_DATE": date, "SIGNATURE": model_sig, "IF_TRADE": False, - "JOB_ID": job_id + "JOB_ID": job_id, + "TRADING_DAY_ID": trading_day_id } with open(config_path, "w", encoding="utf-8") as f: diff --git a/tests/unit/test_trade_tools_new_schema.py b/tests/unit/test_trade_tools_new_schema.py new file mode 100644 index 0000000..fa2f508 --- /dev/null +++ b/tests/unit/test_trade_tools_new_schema.py @@ -0,0 +1,216 @@ +"""Test trade tools write to new schema (actions table).""" + +import pytest +import sqlite3 +from agent_tools.tool_trade import _buy_impl, _sell_impl +from api.database import Database +from tools.deployment_config import get_db_path + + +@pytest.fixture +def test_db(): + """Create test database with new schema.""" + db_path = ":memory:" + db = Database(db_path) + + # Create jobs table (prerequisite) + db.connection.execute(""" + CREATE TABLE IF NOT EXISTS jobs ( + job_id TEXT PRIMARY KEY, + config_path TEXT NOT NULL, + status TEXT NOT NULL, + date_range TEXT NOT NULL, + models TEXT NOT NULL, + created_at TEXT NOT NULL + ) + """) + + db.connection.execute(""" + INSERT INTO jobs (job_id, config_path, status, date_range, models, created_at) + VALUES ('test-job-123', 'test_config.json', 'running', '2025-01-15', '["test-model"]', '2025-01-15T10:00:00Z') + """) + + # Create trading_days record + trading_day_id = db.create_trading_day( + job_id='test-job-123', + model='test-model', + date='2025-01-15', + starting_cash=10000.0, + starting_portfolio_value=10000.0, + daily_profit=0.0, + daily_return_pct=0.0, + ending_cash=10000.0, + ending_portfolio_value=10000.0, + days_since_last_trading=0 + ) + + db.connection.commit() + + yield db, trading_day_id + + db.connection.close() + + +def test_buy_writes_to_actions_table(test_db, monkeypatch): + """Test buy() writes action record to actions table.""" + db, trading_day_id = test_db + + # Create a mock connection wrapper that doesn't actually close + class MockConnection: + def __init__(self, real_conn): + self.real_conn = real_conn + + def cursor(self): + return self.real_conn.cursor() + + def execute(self, *args, **kwargs): + return self.real_conn.execute(*args, **kwargs) + + def commit(self): + return self.real_conn.commit() + + def rollback(self): + return self.real_conn.rollback() + + def close(self): + pass # Don't actually close the connection + + mock_conn = MockConnection(db.connection) + + # Mock get_db_connection to return our mock connection + monkeypatch.setattr('agent_tools.tool_trade.get_db_connection', + lambda x: mock_conn) + + # Mock get_current_position_from_db to return starting position + monkeypatch.setattr('agent_tools.tool_trade.get_current_position_from_db', + lambda job_id, sig, date: ({'CASH': 10000.0}, 0)) + + # Mock runtime config + monkeypatch.setenv('RUNTIME_ENV_PATH', '/tmp/test_runtime.json') + + # Create mock runtime config file + import json + with open('/tmp/test_runtime.json', 'w') as f: + json.dump({ + 'TODAY_DATE': '2025-01-15', + 'SIGNATURE': 'test-model', + 'JOB_ID': 'test-job-123', + 'TRADING_DAY_ID': trading_day_id + }, f) + + # Mock price data + monkeypatch.setattr('agent_tools.tool_trade.get_open_prices', + lambda date, symbols: {'AAPL_price': 150.0}) + + # Execute buy + result = _buy_impl( + symbol='AAPL', + amount=10, + signature='test-model', + today_date='2025-01-15', + job_id='test-job-123', + trading_day_id=trading_day_id + ) + + # Check if there was an error + if 'error' in result: + print(f"Buy failed with error: {result}") + + # Verify action record created + cursor = db.connection.execute(""" + SELECT action_type, symbol, quantity, price, trading_day_id + FROM actions + WHERE trading_day_id = ? + """, (trading_day_id,)) + + row = cursor.fetchone() + assert row is not None, "Action record should exist" + assert row[0] == 'buy' + assert row[1] == 'AAPL' + assert row[2] == 10 + assert row[3] == 150.0 + assert row[4] == trading_day_id + + # Verify NO write to old positions table + cursor = db.connection.execute(""" + SELECT name FROM sqlite_master + WHERE type='table' AND name='positions' + """) + assert cursor.fetchone() is None, "Old positions table should not exist" + + +def test_sell_writes_to_actions_table(test_db, monkeypatch): + """Test sell() writes action record to actions table.""" + db, trading_day_id = test_db + + # Setup: Create starting holdings + db.create_holding(trading_day_id, 'AAPL', 10) + db.connection.commit() + + # Create a mock connection wrapper that doesn't actually close + class MockConnection: + def __init__(self, real_conn): + self.real_conn = real_conn + + def cursor(self): + return self.real_conn.cursor() + + def execute(self, *args, **kwargs): + return self.real_conn.execute(*args, **kwargs) + + def commit(self): + return self.real_conn.commit() + + def rollback(self): + return self.real_conn.rollback() + + def close(self): + pass # Don't actually close the connection + + mock_conn = MockConnection(db.connection) + + # Mock dependencies + monkeypatch.setattr('agent_tools.tool_trade.get_db_connection', + lambda x: mock_conn) + + # Mock get_current_position_from_db to return position with AAPL shares + monkeypatch.setattr('agent_tools.tool_trade.get_current_position_from_db', + lambda job_id, sig, date: ({'CASH': 10000.0, 'AAPL': 10}, 0)) + + monkeypatch.setenv('RUNTIME_ENV_PATH', '/tmp/test_runtime.json') + + import json + with open('/tmp/test_runtime.json', 'w') as f: + json.dump({ + 'TODAY_DATE': '2025-01-15', + 'SIGNATURE': 'test-model', + 'JOB_ID': 'test-job-123', + 'TRADING_DAY_ID': trading_day_id + }, f) + + monkeypatch.setattr('agent_tools.tool_trade.get_open_prices', + lambda date, symbols: {'AAPL_price': 160.0}) + + # Execute sell + result = _sell_impl( + symbol='AAPL', + amount=5, + signature='test-model', + today_date='2025-01-15', + job_id='test-job-123', + trading_day_id=trading_day_id + ) + + # Verify action record created + cursor = db.connection.execute(""" + SELECT action_type, symbol, quantity, price + FROM actions + WHERE trading_day_id = ? AND action_type = 'sell' + """, (trading_day_id,)) + + row = cursor.fetchone() + assert row is not None + assert row[0] == 'sell' + assert row[1] == 'AAPL' + assert row[2] == 5 + assert row[3] == 160.0