feat: migrate trade tools to write to actions table (new schema)

This commit implements Task 1 from the schema migration plan:
- Trade tools (buy/sell) now write to actions table instead of old positions table
- Added trading_day_id parameter to buy/sell functions
- Updated ContextInjector to inject trading_day_id
- Updated RuntimeConfigManager to include TRADING_DAY_ID in config
- Removed P&L calculation from trade functions (now done at trading_days level)
- Added tests verifying correct behavior with new schema

Changes:
- agent_tools/tool_trade.py: Modified _buy_impl and _sell_impl to write to actions table
- agent/context_injector.py: Added trading_day_id parameter and injection logic
- api/model_day_executor.py: Updated to read trading_day_id from runtime config
- api/runtime_manager.py: Added trading_day_id to config initialization
- tests/unit/test_trade_tools_new_schema.py: New tests for new schema compliance

All tests passing.
This commit is contained in:
2025-11-04 09:18:35 -05:00
parent faa2135668
commit 7d9d093d6c
5 changed files with 282 additions and 114 deletions

View File

@@ -17,7 +17,8 @@ class ContextInjector:
client = MultiServerMCPClient(config, tool_interceptors=[interceptor]) client = MultiServerMCPClient(config, tool_interceptors=[interceptor])
""" """
def __init__(self, signature: str, today_date: str, job_id: str = None, session_id: int = None): def __init__(self, signature: str, today_date: str, job_id: str = None,
session_id: int = None, trading_day_id: int = None):
""" """
Initialize context injector. Initialize context injector.
@@ -25,12 +26,14 @@ class ContextInjector:
signature: Model signature to inject signature: Model signature to inject
today_date: Trading date to inject today_date: Trading date to inject
job_id: Job UUID to inject (optional) job_id: Job UUID to inject (optional)
session_id: Trading session ID to inject (optional, updated during execution) session_id: Trading session ID to inject (optional, DEPRECATED)
trading_day_id: Trading day ID to inject (optional)
""" """
self.signature = signature self.signature = signature
self.today_date = today_date self.today_date = today_date
self.job_id = job_id self.job_id = job_id
self.session_id = session_id self.session_id = session_id # Deprecated but kept for compatibility
self.trading_day_id = trading_day_id
async def __call__( async def __call__(
self, self,
@@ -50,7 +53,7 @@ class ContextInjector:
# Inject context parameters for trade tools # Inject context parameters for trade tools
if request.name in ["buy", "sell"]: if request.name in ["buy", "sell"]:
# Debug: Log self attributes BEFORE injection # Debug: Log self attributes BEFORE injection
print(f"[ContextInjector.__call__] ENTRY: id={id(self)}, self.signature={self.signature}, self.today_date={self.today_date}, self.job_id={self.job_id}, self.session_id={self.session_id}") print(f"[ContextInjector.__call__] ENTRY: id={id(self)}, self.signature={self.signature}, self.today_date={self.today_date}, self.job_id={self.job_id}, self.session_id={self.session_id}, self.trading_day_id={self.trading_day_id}")
print(f"[ContextInjector.__call__] Args BEFORE injection: {request.args}") print(f"[ContextInjector.__call__] Args BEFORE injection: {request.args}")
# ALWAYS inject/override context parameters (don't trust AI-provided values) # ALWAYS inject/override context parameters (don't trust AI-provided values)
@@ -60,6 +63,8 @@ class ContextInjector:
request.args["job_id"] = self.job_id request.args["job_id"] = self.job_id
if self.session_id: if self.session_id:
request.args["session_id"] = self.session_id request.args["session_id"] = self.session_id
if self.trading_day_id:
request.args["trading_day_id"] = self.trading_day_id
# Debug logging # Debug logging
print(f"[ContextInjector] Tool: {request.name}, Args after injection: {request.args}") print(f"[ContextInjector] Tool: {request.name}, Args after injection: {request.args}")

View File

@@ -91,12 +91,21 @@ def get_current_position_from_db(job_id: str, model: str, date: str) -> Tuple[Di
def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str = None, def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str = None,
job_id: str = None, session_id: int = None) -> Dict[str, Any]: job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]:
""" """
Internal buy implementation - accepts injected context parameters. Internal buy implementation - accepts injected context parameters.
Args:
symbol: Stock symbol
amount: Number of shares
signature: Model signature (injected)
today_date: Trading date (injected)
job_id: Job ID (injected)
session_id: Session ID (injected, DEPRECATED)
trading_day_id: Trading day ID (injected)
This function is not exposed to the AI model. It receives runtime context This function is not exposed to the AI model. It receives runtime context
(signature, today_date, job_id, session_id) from the ContextInjector. (signature, today_date, job_id, session_id, trading_day_id) from the ContextInjector.
""" """
# Validate required parameters # Validate required parameters
if not job_id: if not job_id:
@@ -139,61 +148,29 @@ def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str =
new_position["CASH"] = cash_left new_position["CASH"] = cash_left
new_position[symbol] = new_position.get(symbol, 0) + amount new_position[symbol] = new_position.get(symbol, 0) + amount
# Step 5: Calculate portfolio value and P&L # Step 5: Write to actions table (NEW SCHEMA)
portfolio_value = cash_left # NOTE: P&L is now calculated at the trading_days level, not per-trade
for sym, qty in new_position.items(): if trading_day_id is None:
if sym != "CASH": # Get trading_day_id from runtime config if not provided
try: from tools.general_tools import get_config_value
price = get_open_prices(today_date, [sym])[f'{sym}_price'] trading_day_id = get_config_value('TRADING_DAY_ID')
portfolio_value += qty * price
except KeyError:
pass # Symbol price not available, skip
# Get start-of-day portfolio value (action_id=0 for today) for P&L calculation if trading_day_id is None:
cursor.execute(""" raise ValueError("trading_day_id not found in runtime config")
SELECT portfolio_value
FROM positions
WHERE job_id = ? AND model = ? AND date = ? AND action_id = 0
LIMIT 1
""", (job_id, signature, today_date))
row = cursor.fetchone()
if row:
# Compare to start of day (action_id=0)
start_of_day_value = row[0]
daily_profit = portfolio_value - start_of_day_value
daily_return_pct = (daily_profit / start_of_day_value * 100) if start_of_day_value > 0 else 0
else:
# First action of first day - no baseline yet
daily_profit = 0.0
daily_return_pct = 0.0
# Step 6: Write to positions table
created_at = datetime.utcnow().isoformat() + "Z" created_at = datetime.utcnow().isoformat() + "Z"
cursor.execute(""" cursor.execute("""
INSERT INTO positions ( INSERT INTO actions (
job_id, date, model, action_id, action_type, symbol, trading_day_id, action_type, symbol, quantity, price, created_at
amount, price, cash, portfolio_value, daily_profit,
daily_return_pct, session_id, created_at
) )
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?)
""", ( """, (
job_id, today_date, signature, next_action_id, "buy", symbol, trading_day_id, "buy", symbol, amount, this_symbol_price, created_at
amount, this_symbol_price, cash_left, portfolio_value, daily_profit,
daily_return_pct, session_id, created_at
)) ))
position_id = cursor.lastrowid # NOTE: Holdings are written by BaseAgent at end of day, not per-trade
# This keeps the data model clean (one holdings snapshot per day)
# Step 7: Write to holdings table
for sym, qty in new_position.items():
if sym != "CASH":
cursor.execute("""
INSERT INTO holdings (position_id, symbol, quantity)
VALUES (?, ?, ?)
""", (position_id, sym, qty))
conn.commit() conn.commit()
print(f"[buy] {signature} bought {amount} shares of {symbol} at ${this_symbol_price}") print(f"[buy] {signature} bought {amount} shares of {symbol} at ${this_symbol_price}")
@@ -209,7 +186,7 @@ def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str =
@mcp.tool() @mcp.tool()
def buy(symbol: str, amount: int, signature: str = None, today_date: str = None, def buy(symbol: str, amount: int, signature: str = None, today_date: str = None,
job_id: str = None, session_id: int = None) -> Dict[str, Any]: job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]:
""" """
Buy stock shares. Buy stock shares.
@@ -222,15 +199,14 @@ def buy(symbol: str, amount: int, signature: str = None, today_date: str = None,
- Success: {"CASH": remaining_cash, "SYMBOL": shares, ...} - Success: {"CASH": remaining_cash, "SYMBOL": shares, ...}
- Failure: {"error": error_message, ...} - Failure: {"error": error_message, ...}
Note: signature, today_date, job_id, session_id are automatically injected by the system. Note: signature, today_date, job_id, session_id, trading_day_id are
Do not provide these parameters - they will be added automatically. automatically injected by the system. Do not provide these parameters.
""" """
# Delegate to internal implementation return _buy_impl(symbol, amount, signature, today_date, job_id, session_id, trading_day_id)
return _buy_impl(symbol, amount, signature, today_date, job_id, session_id)
def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str = None, def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str = None,
job_id: str = None, session_id: int = None) -> Dict[str, Any]: job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]:
""" """
Sell stock function - writes to SQLite database. Sell stock function - writes to SQLite database.
@@ -240,7 +216,8 @@ def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str
signature: Model signature (injected by ContextInjector) signature: Model signature (injected by ContextInjector)
today_date: Trading date YYYY-MM-DD (injected by ContextInjector) today_date: Trading date YYYY-MM-DD (injected by ContextInjector)
job_id: Job UUID (injected by ContextInjector) job_id: Job UUID (injected by ContextInjector)
session_id: Trading session ID (injected by ContextInjector) session_id: Trading session ID (injected by ContextInjector, DEPRECATED)
trading_day_id: Trading day ID (injected by ContextInjector)
Returns: Returns:
Dict[str, Any]: Dict[str, Any]:
@@ -287,62 +264,26 @@ def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str
new_position[symbol] -= amount new_position[symbol] -= amount
new_position["CASH"] = new_position.get("CASH", 0) + (this_symbol_price * amount) new_position["CASH"] = new_position.get("CASH", 0) + (this_symbol_price * amount)
# Step 5: Calculate portfolio value and P&L # Step 5: Write to actions table (NEW SCHEMA)
portfolio_value = new_position["CASH"] # NOTE: P&L is now calculated at the trading_days level, not per-trade
for sym, qty in new_position.items(): if trading_day_id is None:
if sym != "CASH": from tools.general_tools import get_config_value
try: trading_day_id = get_config_value('TRADING_DAY_ID')
price = get_open_prices(today_date, [sym])[f'{sym}_price']
portfolio_value += qty * price
except KeyError:
pass
# Get start-of-day portfolio value (action_id=0 for today) for P&L calculation if trading_day_id is None:
cursor.execute(""" raise ValueError("trading_day_id not found in runtime config")
SELECT portfolio_value
FROM positions
WHERE job_id = ? AND model = ? AND date = ? AND action_id = 0
LIMIT 1
""", (job_id, signature, today_date))
row = cursor.fetchone()
if row:
# Compare to start of day (action_id=0)
start_of_day_value = row[0]
daily_profit = portfolio_value - start_of_day_value
daily_return_pct = (daily_profit / start_of_day_value * 100) if start_of_day_value > 0 else 0
else:
# First action of first day - no baseline yet
daily_profit = 0.0
daily_return_pct = 0.0
# Step 6: Write to positions table
created_at = datetime.utcnow().isoformat() + "Z" created_at = datetime.utcnow().isoformat() + "Z"
cursor.execute(""" cursor.execute("""
INSERT INTO positions ( INSERT INTO actions (
job_id, date, model, action_id, action_type, symbol, trading_day_id, action_type, symbol, quantity, price, created_at
amount, price, cash, portfolio_value, daily_profit,
daily_return_pct, session_id, created_at
) )
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?)
""", ( """, (
job_id, today_date, signature, next_action_id, "sell", symbol, trading_day_id, "sell", symbol, amount, this_symbol_price, created_at
amount, this_symbol_price, new_position["CASH"], portfolio_value, daily_profit,
daily_return_pct, session_id, created_at
)) ))
position_id = cursor.lastrowid
# Step 7: Write to holdings table
for sym, qty in new_position.items():
if sym != "CASH":
cursor.execute("""
INSERT INTO holdings (position_id, symbol, quantity)
VALUES (?, ?, ?)
""", (position_id, sym, qty))
conn.commit() conn.commit()
print(f"[sell] {signature} sold {amount} shares of {symbol} at ${this_symbol_price}") print(f"[sell] {signature} sold {amount} shares of {symbol} at ${this_symbol_price}")
return new_position return new_position
@@ -357,7 +298,7 @@ def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str
@mcp.tool() @mcp.tool()
def sell(symbol: str, amount: int, signature: str = None, today_date: str = None, def sell(symbol: str, amount: int, signature: str = None, today_date: str = None,
job_id: str = None, session_id: int = None) -> Dict[str, Any]: job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]:
""" """
Sell stock shares. Sell stock shares.
@@ -370,11 +311,10 @@ def sell(symbol: str, amount: int, signature: str = None, today_date: str = None
- Success: {"CASH": remaining_cash, "SYMBOL": shares, ...} - Success: {"CASH": remaining_cash, "SYMBOL": shares, ...}
- Failure: {"error": error_message, ...} - Failure: {"error": error_message, ...}
Note: signature, today_date, job_id, session_id are automatically injected by the system. Note: signature, today_date, job_id, session_id, trading_day_id are
Do not provide these parameters - they will be added automatically. automatically injected by the system. Do not provide these parameters.
""" """
# Delegate to internal implementation return _sell_impl(symbol, amount, signature, today_date, job_id, session_id, trading_day_id)
return _sell_impl(symbol, amount, signature, today_date, job_id, session_id)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -138,11 +138,15 @@ class ModelDayExecutor:
# Create and inject context with correct values # Create and inject context with correct values
from agent.context_injector import ContextInjector from agent.context_injector import ContextInjector
from tools.general_tools import get_config_value
trading_day_id = get_config_value('TRADING_DAY_ID') # Get from runtime config
context_injector = ContextInjector( context_injector = ContextInjector(
signature=self.model_sig, signature=self.model_sig,
today_date=self.date, # Current trading day today_date=self.date, # Current trading day
job_id=self.job_id, job_id=self.job_id,
session_id=session_id session_id=session_id,
trading_day_id=trading_day_id
) )
logger.info(f"[DEBUG] ModelDayExecutor: Created ContextInjector with signature={self.model_sig}, date={self.date}, job_id={self.job_id}, session_id={session_id}") logger.info(f"[DEBUG] ModelDayExecutor: Created ContextInjector with signature={self.model_sig}, date={self.date}, job_id={self.job_id}, session_id={session_id}")
logger.info(f"[DEBUG] ModelDayExecutor: Calling await agent.set_context()") logger.info(f"[DEBUG] ModelDayExecutor: Calling await agent.set_context()")

View File

@@ -48,7 +48,8 @@ class RuntimeConfigManager:
self, self,
job_id: str, job_id: str,
model_sig: str, model_sig: str,
date: str date: str,
trading_day_id: int = None
) -> str: ) -> str:
""" """
Create isolated runtime config file for this execution. Create isolated runtime config file for this execution.
@@ -57,6 +58,7 @@ class RuntimeConfigManager:
job_id: Job UUID job_id: Job UUID
model_sig: Model signature model_sig: Model signature
date: Trading date (YYYY-MM-DD) date: Trading date (YYYY-MM-DD)
trading_day_id: Trading day record ID (optional, can be set later)
Returns: Returns:
Path to created runtime config file Path to created runtime config file
@@ -79,7 +81,8 @@ class RuntimeConfigManager:
"TODAY_DATE": date, "TODAY_DATE": date,
"SIGNATURE": model_sig, "SIGNATURE": model_sig,
"IF_TRADE": False, "IF_TRADE": False,
"JOB_ID": job_id "JOB_ID": job_id,
"TRADING_DAY_ID": trading_day_id
} }
with open(config_path, "w", encoding="utf-8") as f: with open(config_path, "w", encoding="utf-8") as f:

View File

@@ -0,0 +1,216 @@
"""Test trade tools write to new schema (actions table)."""
import pytest
import sqlite3
from agent_tools.tool_trade import _buy_impl, _sell_impl
from api.database import Database
from tools.deployment_config import get_db_path
@pytest.fixture
def test_db():
"""Create test database with new schema."""
db_path = ":memory:"
db = Database(db_path)
# Create jobs table (prerequisite)
db.connection.execute("""
CREATE TABLE IF NOT EXISTS jobs (
job_id TEXT PRIMARY KEY,
config_path TEXT NOT NULL,
status TEXT NOT NULL,
date_range TEXT NOT NULL,
models TEXT NOT NULL,
created_at TEXT NOT NULL
)
""")
db.connection.execute("""
INSERT INTO jobs (job_id, config_path, status, date_range, models, created_at)
VALUES ('test-job-123', 'test_config.json', 'running', '2025-01-15', '["test-model"]', '2025-01-15T10:00:00Z')
""")
# Create trading_days record
trading_day_id = db.create_trading_day(
job_id='test-job-123',
model='test-model',
date='2025-01-15',
starting_cash=10000.0,
starting_portfolio_value=10000.0,
daily_profit=0.0,
daily_return_pct=0.0,
ending_cash=10000.0,
ending_portfolio_value=10000.0,
days_since_last_trading=0
)
db.connection.commit()
yield db, trading_day_id
db.connection.close()
def test_buy_writes_to_actions_table(test_db, monkeypatch):
"""Test buy() writes action record to actions table."""
db, trading_day_id = test_db
# Create a mock connection wrapper that doesn't actually close
class MockConnection:
def __init__(self, real_conn):
self.real_conn = real_conn
def cursor(self):
return self.real_conn.cursor()
def execute(self, *args, **kwargs):
return self.real_conn.execute(*args, **kwargs)
def commit(self):
return self.real_conn.commit()
def rollback(self):
return self.real_conn.rollback()
def close(self):
pass # Don't actually close the connection
mock_conn = MockConnection(db.connection)
# Mock get_db_connection to return our mock connection
monkeypatch.setattr('agent_tools.tool_trade.get_db_connection',
lambda x: mock_conn)
# Mock get_current_position_from_db to return starting position
monkeypatch.setattr('agent_tools.tool_trade.get_current_position_from_db',
lambda job_id, sig, date: ({'CASH': 10000.0}, 0))
# Mock runtime config
monkeypatch.setenv('RUNTIME_ENV_PATH', '/tmp/test_runtime.json')
# Create mock runtime config file
import json
with open('/tmp/test_runtime.json', 'w') as f:
json.dump({
'TODAY_DATE': '2025-01-15',
'SIGNATURE': 'test-model',
'JOB_ID': 'test-job-123',
'TRADING_DAY_ID': trading_day_id
}, f)
# Mock price data
monkeypatch.setattr('agent_tools.tool_trade.get_open_prices',
lambda date, symbols: {'AAPL_price': 150.0})
# Execute buy
result = _buy_impl(
symbol='AAPL',
amount=10,
signature='test-model',
today_date='2025-01-15',
job_id='test-job-123',
trading_day_id=trading_day_id
)
# Check if there was an error
if 'error' in result:
print(f"Buy failed with error: {result}")
# Verify action record created
cursor = db.connection.execute("""
SELECT action_type, symbol, quantity, price, trading_day_id
FROM actions
WHERE trading_day_id = ?
""", (trading_day_id,))
row = cursor.fetchone()
assert row is not None, "Action record should exist"
assert row[0] == 'buy'
assert row[1] == 'AAPL'
assert row[2] == 10
assert row[3] == 150.0
assert row[4] == trading_day_id
# Verify NO write to old positions table
cursor = db.connection.execute("""
SELECT name FROM sqlite_master
WHERE type='table' AND name='positions'
""")
assert cursor.fetchone() is None, "Old positions table should not exist"
def test_sell_writes_to_actions_table(test_db, monkeypatch):
"""Test sell() writes action record to actions table."""
db, trading_day_id = test_db
# Setup: Create starting holdings
db.create_holding(trading_day_id, 'AAPL', 10)
db.connection.commit()
# Create a mock connection wrapper that doesn't actually close
class MockConnection:
def __init__(self, real_conn):
self.real_conn = real_conn
def cursor(self):
return self.real_conn.cursor()
def execute(self, *args, **kwargs):
return self.real_conn.execute(*args, **kwargs)
def commit(self):
return self.real_conn.commit()
def rollback(self):
return self.real_conn.rollback()
def close(self):
pass # Don't actually close the connection
mock_conn = MockConnection(db.connection)
# Mock dependencies
monkeypatch.setattr('agent_tools.tool_trade.get_db_connection',
lambda x: mock_conn)
# Mock get_current_position_from_db to return position with AAPL shares
monkeypatch.setattr('agent_tools.tool_trade.get_current_position_from_db',
lambda job_id, sig, date: ({'CASH': 10000.0, 'AAPL': 10}, 0))
monkeypatch.setenv('RUNTIME_ENV_PATH', '/tmp/test_runtime.json')
import json
with open('/tmp/test_runtime.json', 'w') as f:
json.dump({
'TODAY_DATE': '2025-01-15',
'SIGNATURE': 'test-model',
'JOB_ID': 'test-job-123',
'TRADING_DAY_ID': trading_day_id
}, f)
monkeypatch.setattr('agent_tools.tool_trade.get_open_prices',
lambda date, symbols: {'AAPL_price': 160.0})
# Execute sell
result = _sell_impl(
symbol='AAPL',
amount=5,
signature='test-model',
today_date='2025-01-15',
job_id='test-job-123',
trading_day_id=trading_day_id
)
# Verify action record created
cursor = db.connection.execute("""
SELECT action_type, symbol, quantity, price
FROM actions
WHERE trading_day_id = ? AND action_type = 'sell'
""", (trading_day_id,))
row = cursor.fetchone()
assert row is not None
assert row[0] == 'sell'
assert row[1] == 'AAPL'
assert row[2] == 5
assert row[3] == 160.0