mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-11 21:17:25 -04:00
feat: migrate trade tools to write to actions table (new schema)
This commit implements Task 1 from the schema migration plan: - Trade tools (buy/sell) now write to actions table instead of old positions table - Added trading_day_id parameter to buy/sell functions - Updated ContextInjector to inject trading_day_id - Updated RuntimeConfigManager to include TRADING_DAY_ID in config - Removed P&L calculation from trade functions (now done at trading_days level) - Added tests verifying correct behavior with new schema Changes: - agent_tools/tool_trade.py: Modified _buy_impl and _sell_impl to write to actions table - agent/context_injector.py: Added trading_day_id parameter and injection logic - api/model_day_executor.py: Updated to read trading_day_id from runtime config - api/runtime_manager.py: Added trading_day_id to config initialization - tests/unit/test_trade_tools_new_schema.py: New tests for new schema compliance All tests passing.
This commit is contained in:
@@ -17,7 +17,8 @@ class ContextInjector:
|
|||||||
client = MultiServerMCPClient(config, tool_interceptors=[interceptor])
|
client = MultiServerMCPClient(config, tool_interceptors=[interceptor])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, signature: str, today_date: str, job_id: str = None, session_id: int = None):
|
def __init__(self, signature: str, today_date: str, job_id: str = None,
|
||||||
|
session_id: int = None, trading_day_id: int = None):
|
||||||
"""
|
"""
|
||||||
Initialize context injector.
|
Initialize context injector.
|
||||||
|
|
||||||
@@ -25,12 +26,14 @@ class ContextInjector:
|
|||||||
signature: Model signature to inject
|
signature: Model signature to inject
|
||||||
today_date: Trading date to inject
|
today_date: Trading date to inject
|
||||||
job_id: Job UUID to inject (optional)
|
job_id: Job UUID to inject (optional)
|
||||||
session_id: Trading session ID to inject (optional, updated during execution)
|
session_id: Trading session ID to inject (optional, DEPRECATED)
|
||||||
|
trading_day_id: Trading day ID to inject (optional)
|
||||||
"""
|
"""
|
||||||
self.signature = signature
|
self.signature = signature
|
||||||
self.today_date = today_date
|
self.today_date = today_date
|
||||||
self.job_id = job_id
|
self.job_id = job_id
|
||||||
self.session_id = session_id
|
self.session_id = session_id # Deprecated but kept for compatibility
|
||||||
|
self.trading_day_id = trading_day_id
|
||||||
|
|
||||||
async def __call__(
|
async def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -50,7 +53,7 @@ class ContextInjector:
|
|||||||
# Inject context parameters for trade tools
|
# Inject context parameters for trade tools
|
||||||
if request.name in ["buy", "sell"]:
|
if request.name in ["buy", "sell"]:
|
||||||
# Debug: Log self attributes BEFORE injection
|
# Debug: Log self attributes BEFORE injection
|
||||||
print(f"[ContextInjector.__call__] ENTRY: id={id(self)}, self.signature={self.signature}, self.today_date={self.today_date}, self.job_id={self.job_id}, self.session_id={self.session_id}")
|
print(f"[ContextInjector.__call__] ENTRY: id={id(self)}, self.signature={self.signature}, self.today_date={self.today_date}, self.job_id={self.job_id}, self.session_id={self.session_id}, self.trading_day_id={self.trading_day_id}")
|
||||||
print(f"[ContextInjector.__call__] Args BEFORE injection: {request.args}")
|
print(f"[ContextInjector.__call__] Args BEFORE injection: {request.args}")
|
||||||
|
|
||||||
# ALWAYS inject/override context parameters (don't trust AI-provided values)
|
# ALWAYS inject/override context parameters (don't trust AI-provided values)
|
||||||
@@ -60,6 +63,8 @@ class ContextInjector:
|
|||||||
request.args["job_id"] = self.job_id
|
request.args["job_id"] = self.job_id
|
||||||
if self.session_id:
|
if self.session_id:
|
||||||
request.args["session_id"] = self.session_id
|
request.args["session_id"] = self.session_id
|
||||||
|
if self.trading_day_id:
|
||||||
|
request.args["trading_day_id"] = self.trading_day_id
|
||||||
|
|
||||||
# Debug logging
|
# Debug logging
|
||||||
print(f"[ContextInjector] Tool: {request.name}, Args after injection: {request.args}")
|
print(f"[ContextInjector] Tool: {request.name}, Args after injection: {request.args}")
|
||||||
|
|||||||
@@ -91,12 +91,21 @@ def get_current_position_from_db(job_id: str, model: str, date: str) -> Tuple[Di
|
|||||||
|
|
||||||
|
|
||||||
def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
||||||
job_id: str = None, session_id: int = None) -> Dict[str, Any]:
|
job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Internal buy implementation - accepts injected context parameters.
|
Internal buy implementation - accepts injected context parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Stock symbol
|
||||||
|
amount: Number of shares
|
||||||
|
signature: Model signature (injected)
|
||||||
|
today_date: Trading date (injected)
|
||||||
|
job_id: Job ID (injected)
|
||||||
|
session_id: Session ID (injected, DEPRECATED)
|
||||||
|
trading_day_id: Trading day ID (injected)
|
||||||
|
|
||||||
This function is not exposed to the AI model. It receives runtime context
|
This function is not exposed to the AI model. It receives runtime context
|
||||||
(signature, today_date, job_id, session_id) from the ContextInjector.
|
(signature, today_date, job_id, session_id, trading_day_id) from the ContextInjector.
|
||||||
"""
|
"""
|
||||||
# Validate required parameters
|
# Validate required parameters
|
||||||
if not job_id:
|
if not job_id:
|
||||||
@@ -139,61 +148,29 @@ def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str =
|
|||||||
new_position["CASH"] = cash_left
|
new_position["CASH"] = cash_left
|
||||||
new_position[symbol] = new_position.get(symbol, 0) + amount
|
new_position[symbol] = new_position.get(symbol, 0) + amount
|
||||||
|
|
||||||
# Step 5: Calculate portfolio value and P&L
|
# Step 5: Write to actions table (NEW SCHEMA)
|
||||||
portfolio_value = cash_left
|
# NOTE: P&L is now calculated at the trading_days level, not per-trade
|
||||||
for sym, qty in new_position.items():
|
if trading_day_id is None:
|
||||||
if sym != "CASH":
|
# Get trading_day_id from runtime config if not provided
|
||||||
try:
|
from tools.general_tools import get_config_value
|
||||||
price = get_open_prices(today_date, [sym])[f'{sym}_price']
|
trading_day_id = get_config_value('TRADING_DAY_ID')
|
||||||
portfolio_value += qty * price
|
|
||||||
except KeyError:
|
|
||||||
pass # Symbol price not available, skip
|
|
||||||
|
|
||||||
# Get start-of-day portfolio value (action_id=0 for today) for P&L calculation
|
if trading_day_id is None:
|
||||||
cursor.execute("""
|
raise ValueError("trading_day_id not found in runtime config")
|
||||||
SELECT portfolio_value
|
|
||||||
FROM positions
|
|
||||||
WHERE job_id = ? AND model = ? AND date = ? AND action_id = 0
|
|
||||||
LIMIT 1
|
|
||||||
""", (job_id, signature, today_date))
|
|
||||||
|
|
||||||
row = cursor.fetchone()
|
|
||||||
|
|
||||||
if row:
|
|
||||||
# Compare to start of day (action_id=0)
|
|
||||||
start_of_day_value = row[0]
|
|
||||||
daily_profit = portfolio_value - start_of_day_value
|
|
||||||
daily_return_pct = (daily_profit / start_of_day_value * 100) if start_of_day_value > 0 else 0
|
|
||||||
else:
|
|
||||||
# First action of first day - no baseline yet
|
|
||||||
daily_profit = 0.0
|
|
||||||
daily_return_pct = 0.0
|
|
||||||
|
|
||||||
# Step 6: Write to positions table
|
|
||||||
created_at = datetime.utcnow().isoformat() + "Z"
|
created_at = datetime.utcnow().isoformat() + "Z"
|
||||||
|
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
INSERT INTO positions (
|
INSERT INTO actions (
|
||||||
job_id, date, model, action_id, action_type, symbol,
|
trading_day_id, action_type, symbol, quantity, price, created_at
|
||||||
amount, price, cash, portfolio_value, daily_profit,
|
|
||||||
daily_return_pct, session_id, created_at
|
|
||||||
)
|
)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?)
|
||||||
""", (
|
""", (
|
||||||
job_id, today_date, signature, next_action_id, "buy", symbol,
|
trading_day_id, "buy", symbol, amount, this_symbol_price, created_at
|
||||||
amount, this_symbol_price, cash_left, portfolio_value, daily_profit,
|
|
||||||
daily_return_pct, session_id, created_at
|
|
||||||
))
|
))
|
||||||
|
|
||||||
position_id = cursor.lastrowid
|
# NOTE: Holdings are written by BaseAgent at end of day, not per-trade
|
||||||
|
# This keeps the data model clean (one holdings snapshot per day)
|
||||||
# Step 7: Write to holdings table
|
|
||||||
for sym, qty in new_position.items():
|
|
||||||
if sym != "CASH":
|
|
||||||
cursor.execute("""
|
|
||||||
INSERT INTO holdings (position_id, symbol, quantity)
|
|
||||||
VALUES (?, ?, ?)
|
|
||||||
""", (position_id, sym, qty))
|
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
print(f"[buy] {signature} bought {amount} shares of {symbol} at ${this_symbol_price}")
|
print(f"[buy] {signature} bought {amount} shares of {symbol} at ${this_symbol_price}")
|
||||||
@@ -209,7 +186,7 @@ def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str =
|
|||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def buy(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
def buy(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
||||||
job_id: str = None, session_id: int = None) -> Dict[str, Any]:
|
job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Buy stock shares.
|
Buy stock shares.
|
||||||
|
|
||||||
@@ -222,15 +199,14 @@ def buy(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
|||||||
- Success: {"CASH": remaining_cash, "SYMBOL": shares, ...}
|
- Success: {"CASH": remaining_cash, "SYMBOL": shares, ...}
|
||||||
- Failure: {"error": error_message, ...}
|
- Failure: {"error": error_message, ...}
|
||||||
|
|
||||||
Note: signature, today_date, job_id, session_id are automatically injected by the system.
|
Note: signature, today_date, job_id, session_id, trading_day_id are
|
||||||
Do not provide these parameters - they will be added automatically.
|
automatically injected by the system. Do not provide these parameters.
|
||||||
"""
|
"""
|
||||||
# Delegate to internal implementation
|
return _buy_impl(symbol, amount, signature, today_date, job_id, session_id, trading_day_id)
|
||||||
return _buy_impl(symbol, amount, signature, today_date, job_id, session_id)
|
|
||||||
|
|
||||||
|
|
||||||
def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
||||||
job_id: str = None, session_id: int = None) -> Dict[str, Any]:
|
job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Sell stock function - writes to SQLite database.
|
Sell stock function - writes to SQLite database.
|
||||||
|
|
||||||
@@ -240,7 +216,8 @@ def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str
|
|||||||
signature: Model signature (injected by ContextInjector)
|
signature: Model signature (injected by ContextInjector)
|
||||||
today_date: Trading date YYYY-MM-DD (injected by ContextInjector)
|
today_date: Trading date YYYY-MM-DD (injected by ContextInjector)
|
||||||
job_id: Job UUID (injected by ContextInjector)
|
job_id: Job UUID (injected by ContextInjector)
|
||||||
session_id: Trading session ID (injected by ContextInjector)
|
session_id: Trading session ID (injected by ContextInjector, DEPRECATED)
|
||||||
|
trading_day_id: Trading day ID (injected by ContextInjector)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, Any]:
|
Dict[str, Any]:
|
||||||
@@ -287,62 +264,26 @@ def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str
|
|||||||
new_position[symbol] -= amount
|
new_position[symbol] -= amount
|
||||||
new_position["CASH"] = new_position.get("CASH", 0) + (this_symbol_price * amount)
|
new_position["CASH"] = new_position.get("CASH", 0) + (this_symbol_price * amount)
|
||||||
|
|
||||||
# Step 5: Calculate portfolio value and P&L
|
# Step 5: Write to actions table (NEW SCHEMA)
|
||||||
portfolio_value = new_position["CASH"]
|
# NOTE: P&L is now calculated at the trading_days level, not per-trade
|
||||||
for sym, qty in new_position.items():
|
if trading_day_id is None:
|
||||||
if sym != "CASH":
|
from tools.general_tools import get_config_value
|
||||||
try:
|
trading_day_id = get_config_value('TRADING_DAY_ID')
|
||||||
price = get_open_prices(today_date, [sym])[f'{sym}_price']
|
|
||||||
portfolio_value += qty * price
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Get start-of-day portfolio value (action_id=0 for today) for P&L calculation
|
if trading_day_id is None:
|
||||||
cursor.execute("""
|
raise ValueError("trading_day_id not found in runtime config")
|
||||||
SELECT portfolio_value
|
|
||||||
FROM positions
|
|
||||||
WHERE job_id = ? AND model = ? AND date = ? AND action_id = 0
|
|
||||||
LIMIT 1
|
|
||||||
""", (job_id, signature, today_date))
|
|
||||||
|
|
||||||
row = cursor.fetchone()
|
|
||||||
|
|
||||||
if row:
|
|
||||||
# Compare to start of day (action_id=0)
|
|
||||||
start_of_day_value = row[0]
|
|
||||||
daily_profit = portfolio_value - start_of_day_value
|
|
||||||
daily_return_pct = (daily_profit / start_of_day_value * 100) if start_of_day_value > 0 else 0
|
|
||||||
else:
|
|
||||||
# First action of first day - no baseline yet
|
|
||||||
daily_profit = 0.0
|
|
||||||
daily_return_pct = 0.0
|
|
||||||
|
|
||||||
# Step 6: Write to positions table
|
|
||||||
created_at = datetime.utcnow().isoformat() + "Z"
|
created_at = datetime.utcnow().isoformat() + "Z"
|
||||||
|
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
INSERT INTO positions (
|
INSERT INTO actions (
|
||||||
job_id, date, model, action_id, action_type, symbol,
|
trading_day_id, action_type, symbol, quantity, price, created_at
|
||||||
amount, price, cash, portfolio_value, daily_profit,
|
|
||||||
daily_return_pct, session_id, created_at
|
|
||||||
)
|
)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?)
|
||||||
""", (
|
""", (
|
||||||
job_id, today_date, signature, next_action_id, "sell", symbol,
|
trading_day_id, "sell", symbol, amount, this_symbol_price, created_at
|
||||||
amount, this_symbol_price, new_position["CASH"], portfolio_value, daily_profit,
|
|
||||||
daily_return_pct, session_id, created_at
|
|
||||||
))
|
))
|
||||||
|
|
||||||
position_id = cursor.lastrowid
|
|
||||||
|
|
||||||
# Step 7: Write to holdings table
|
|
||||||
for sym, qty in new_position.items():
|
|
||||||
if sym != "CASH":
|
|
||||||
cursor.execute("""
|
|
||||||
INSERT INTO holdings (position_id, symbol, quantity)
|
|
||||||
VALUES (?, ?, ?)
|
|
||||||
""", (position_id, sym, qty))
|
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
print(f"[sell] {signature} sold {amount} shares of {symbol} at ${this_symbol_price}")
|
print(f"[sell] {signature} sold {amount} shares of {symbol} at ${this_symbol_price}")
|
||||||
return new_position
|
return new_position
|
||||||
@@ -357,7 +298,7 @@ def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str
|
|||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def sell(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
def sell(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
||||||
job_id: str = None, session_id: int = None) -> Dict[str, Any]:
|
job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Sell stock shares.
|
Sell stock shares.
|
||||||
|
|
||||||
@@ -370,11 +311,10 @@ def sell(symbol: str, amount: int, signature: str = None, today_date: str = None
|
|||||||
- Success: {"CASH": remaining_cash, "SYMBOL": shares, ...}
|
- Success: {"CASH": remaining_cash, "SYMBOL": shares, ...}
|
||||||
- Failure: {"error": error_message, ...}
|
- Failure: {"error": error_message, ...}
|
||||||
|
|
||||||
Note: signature, today_date, job_id, session_id are automatically injected by the system.
|
Note: signature, today_date, job_id, session_id, trading_day_id are
|
||||||
Do not provide these parameters - they will be added automatically.
|
automatically injected by the system. Do not provide these parameters.
|
||||||
"""
|
"""
|
||||||
# Delegate to internal implementation
|
return _sell_impl(symbol, amount, signature, today_date, job_id, session_id, trading_day_id)
|
||||||
return _sell_impl(symbol, amount, signature, today_date, job_id, session_id)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -138,11 +138,15 @@ class ModelDayExecutor:
|
|||||||
|
|
||||||
# Create and inject context with correct values
|
# Create and inject context with correct values
|
||||||
from agent.context_injector import ContextInjector
|
from agent.context_injector import ContextInjector
|
||||||
|
from tools.general_tools import get_config_value
|
||||||
|
trading_day_id = get_config_value('TRADING_DAY_ID') # Get from runtime config
|
||||||
|
|
||||||
context_injector = ContextInjector(
|
context_injector = ContextInjector(
|
||||||
signature=self.model_sig,
|
signature=self.model_sig,
|
||||||
today_date=self.date, # Current trading day
|
today_date=self.date, # Current trading day
|
||||||
job_id=self.job_id,
|
job_id=self.job_id,
|
||||||
session_id=session_id
|
session_id=session_id,
|
||||||
|
trading_day_id=trading_day_id
|
||||||
)
|
)
|
||||||
logger.info(f"[DEBUG] ModelDayExecutor: Created ContextInjector with signature={self.model_sig}, date={self.date}, job_id={self.job_id}, session_id={session_id}")
|
logger.info(f"[DEBUG] ModelDayExecutor: Created ContextInjector with signature={self.model_sig}, date={self.date}, job_id={self.job_id}, session_id={session_id}")
|
||||||
logger.info(f"[DEBUG] ModelDayExecutor: Calling await agent.set_context()")
|
logger.info(f"[DEBUG] ModelDayExecutor: Calling await agent.set_context()")
|
||||||
|
|||||||
@@ -48,7 +48,8 @@ class RuntimeConfigManager:
|
|||||||
self,
|
self,
|
||||||
job_id: str,
|
job_id: str,
|
||||||
model_sig: str,
|
model_sig: str,
|
||||||
date: str
|
date: str,
|
||||||
|
trading_day_id: int = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Create isolated runtime config file for this execution.
|
Create isolated runtime config file for this execution.
|
||||||
@@ -57,6 +58,7 @@ class RuntimeConfigManager:
|
|||||||
job_id: Job UUID
|
job_id: Job UUID
|
||||||
model_sig: Model signature
|
model_sig: Model signature
|
||||||
date: Trading date (YYYY-MM-DD)
|
date: Trading date (YYYY-MM-DD)
|
||||||
|
trading_day_id: Trading day record ID (optional, can be set later)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Path to created runtime config file
|
Path to created runtime config file
|
||||||
@@ -79,7 +81,8 @@ class RuntimeConfigManager:
|
|||||||
"TODAY_DATE": date,
|
"TODAY_DATE": date,
|
||||||
"SIGNATURE": model_sig,
|
"SIGNATURE": model_sig,
|
||||||
"IF_TRADE": False,
|
"IF_TRADE": False,
|
||||||
"JOB_ID": job_id
|
"JOB_ID": job_id,
|
||||||
|
"TRADING_DAY_ID": trading_day_id
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(config_path, "w", encoding="utf-8") as f:
|
with open(config_path, "w", encoding="utf-8") as f:
|
||||||
|
|||||||
216
tests/unit/test_trade_tools_new_schema.py
Normal file
216
tests/unit/test_trade_tools_new_schema.py
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
"""Test trade tools write to new schema (actions table)."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import sqlite3
|
||||||
|
from agent_tools.tool_trade import _buy_impl, _sell_impl
|
||||||
|
from api.database import Database
|
||||||
|
from tools.deployment_config import get_db_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_db():
|
||||||
|
"""Create test database with new schema."""
|
||||||
|
db_path = ":memory:"
|
||||||
|
db = Database(db_path)
|
||||||
|
|
||||||
|
# Create jobs table (prerequisite)
|
||||||
|
db.connection.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS jobs (
|
||||||
|
job_id TEXT PRIMARY KEY,
|
||||||
|
config_path TEXT NOT NULL,
|
||||||
|
status TEXT NOT NULL,
|
||||||
|
date_range TEXT NOT NULL,
|
||||||
|
models TEXT NOT NULL,
|
||||||
|
created_at TEXT NOT NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
db.connection.execute("""
|
||||||
|
INSERT INTO jobs (job_id, config_path, status, date_range, models, created_at)
|
||||||
|
VALUES ('test-job-123', 'test_config.json', 'running', '2025-01-15', '["test-model"]', '2025-01-15T10:00:00Z')
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create trading_days record
|
||||||
|
trading_day_id = db.create_trading_day(
|
||||||
|
job_id='test-job-123',
|
||||||
|
model='test-model',
|
||||||
|
date='2025-01-15',
|
||||||
|
starting_cash=10000.0,
|
||||||
|
starting_portfolio_value=10000.0,
|
||||||
|
daily_profit=0.0,
|
||||||
|
daily_return_pct=0.0,
|
||||||
|
ending_cash=10000.0,
|
||||||
|
ending_portfolio_value=10000.0,
|
||||||
|
days_since_last_trading=0
|
||||||
|
)
|
||||||
|
|
||||||
|
db.connection.commit()
|
||||||
|
|
||||||
|
yield db, trading_day_id
|
||||||
|
|
||||||
|
db.connection.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_buy_writes_to_actions_table(test_db, monkeypatch):
|
||||||
|
"""Test buy() writes action record to actions table."""
|
||||||
|
db, trading_day_id = test_db
|
||||||
|
|
||||||
|
# Create a mock connection wrapper that doesn't actually close
|
||||||
|
class MockConnection:
|
||||||
|
def __init__(self, real_conn):
|
||||||
|
self.real_conn = real_conn
|
||||||
|
|
||||||
|
def cursor(self):
|
||||||
|
return self.real_conn.cursor()
|
||||||
|
|
||||||
|
def execute(self, *args, **kwargs):
|
||||||
|
return self.real_conn.execute(*args, **kwargs)
|
||||||
|
|
||||||
|
def commit(self):
|
||||||
|
return self.real_conn.commit()
|
||||||
|
|
||||||
|
def rollback(self):
|
||||||
|
return self.real_conn.rollback()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
pass # Don't actually close the connection
|
||||||
|
|
||||||
|
mock_conn = MockConnection(db.connection)
|
||||||
|
|
||||||
|
# Mock get_db_connection to return our mock connection
|
||||||
|
monkeypatch.setattr('agent_tools.tool_trade.get_db_connection',
|
||||||
|
lambda x: mock_conn)
|
||||||
|
|
||||||
|
# Mock get_current_position_from_db to return starting position
|
||||||
|
monkeypatch.setattr('agent_tools.tool_trade.get_current_position_from_db',
|
||||||
|
lambda job_id, sig, date: ({'CASH': 10000.0}, 0))
|
||||||
|
|
||||||
|
# Mock runtime config
|
||||||
|
monkeypatch.setenv('RUNTIME_ENV_PATH', '/tmp/test_runtime.json')
|
||||||
|
|
||||||
|
# Create mock runtime config file
|
||||||
|
import json
|
||||||
|
with open('/tmp/test_runtime.json', 'w') as f:
|
||||||
|
json.dump({
|
||||||
|
'TODAY_DATE': '2025-01-15',
|
||||||
|
'SIGNATURE': 'test-model',
|
||||||
|
'JOB_ID': 'test-job-123',
|
||||||
|
'TRADING_DAY_ID': trading_day_id
|
||||||
|
}, f)
|
||||||
|
|
||||||
|
# Mock price data
|
||||||
|
monkeypatch.setattr('agent_tools.tool_trade.get_open_prices',
|
||||||
|
lambda date, symbols: {'AAPL_price': 150.0})
|
||||||
|
|
||||||
|
# Execute buy
|
||||||
|
result = _buy_impl(
|
||||||
|
symbol='AAPL',
|
||||||
|
amount=10,
|
||||||
|
signature='test-model',
|
||||||
|
today_date='2025-01-15',
|
||||||
|
job_id='test-job-123',
|
||||||
|
trading_day_id=trading_day_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if there was an error
|
||||||
|
if 'error' in result:
|
||||||
|
print(f"Buy failed with error: {result}")
|
||||||
|
|
||||||
|
# Verify action record created
|
||||||
|
cursor = db.connection.execute("""
|
||||||
|
SELECT action_type, symbol, quantity, price, trading_day_id
|
||||||
|
FROM actions
|
||||||
|
WHERE trading_day_id = ?
|
||||||
|
""", (trading_day_id,))
|
||||||
|
|
||||||
|
row = cursor.fetchone()
|
||||||
|
assert row is not None, "Action record should exist"
|
||||||
|
assert row[0] == 'buy'
|
||||||
|
assert row[1] == 'AAPL'
|
||||||
|
assert row[2] == 10
|
||||||
|
assert row[3] == 150.0
|
||||||
|
assert row[4] == trading_day_id
|
||||||
|
|
||||||
|
# Verify NO write to old positions table
|
||||||
|
cursor = db.connection.execute("""
|
||||||
|
SELECT name FROM sqlite_master
|
||||||
|
WHERE type='table' AND name='positions'
|
||||||
|
""")
|
||||||
|
assert cursor.fetchone() is None, "Old positions table should not exist"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sell_writes_to_actions_table(test_db, monkeypatch):
|
||||||
|
"""Test sell() writes action record to actions table."""
|
||||||
|
db, trading_day_id = test_db
|
||||||
|
|
||||||
|
# Setup: Create starting holdings
|
||||||
|
db.create_holding(trading_day_id, 'AAPL', 10)
|
||||||
|
db.connection.commit()
|
||||||
|
|
||||||
|
# Create a mock connection wrapper that doesn't actually close
|
||||||
|
class MockConnection:
|
||||||
|
def __init__(self, real_conn):
|
||||||
|
self.real_conn = real_conn
|
||||||
|
|
||||||
|
def cursor(self):
|
||||||
|
return self.real_conn.cursor()
|
||||||
|
|
||||||
|
def execute(self, *args, **kwargs):
|
||||||
|
return self.real_conn.execute(*args, **kwargs)
|
||||||
|
|
||||||
|
def commit(self):
|
||||||
|
return self.real_conn.commit()
|
||||||
|
|
||||||
|
def rollback(self):
|
||||||
|
return self.real_conn.rollback()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
pass # Don't actually close the connection
|
||||||
|
|
||||||
|
mock_conn = MockConnection(db.connection)
|
||||||
|
|
||||||
|
# Mock dependencies
|
||||||
|
monkeypatch.setattr('agent_tools.tool_trade.get_db_connection',
|
||||||
|
lambda x: mock_conn)
|
||||||
|
|
||||||
|
# Mock get_current_position_from_db to return position with AAPL shares
|
||||||
|
monkeypatch.setattr('agent_tools.tool_trade.get_current_position_from_db',
|
||||||
|
lambda job_id, sig, date: ({'CASH': 10000.0, 'AAPL': 10}, 0))
|
||||||
|
|
||||||
|
monkeypatch.setenv('RUNTIME_ENV_PATH', '/tmp/test_runtime.json')
|
||||||
|
|
||||||
|
import json
|
||||||
|
with open('/tmp/test_runtime.json', 'w') as f:
|
||||||
|
json.dump({
|
||||||
|
'TODAY_DATE': '2025-01-15',
|
||||||
|
'SIGNATURE': 'test-model',
|
||||||
|
'JOB_ID': 'test-job-123',
|
||||||
|
'TRADING_DAY_ID': trading_day_id
|
||||||
|
}, f)
|
||||||
|
|
||||||
|
monkeypatch.setattr('agent_tools.tool_trade.get_open_prices',
|
||||||
|
lambda date, symbols: {'AAPL_price': 160.0})
|
||||||
|
|
||||||
|
# Execute sell
|
||||||
|
result = _sell_impl(
|
||||||
|
symbol='AAPL',
|
||||||
|
amount=5,
|
||||||
|
signature='test-model',
|
||||||
|
today_date='2025-01-15',
|
||||||
|
job_id='test-job-123',
|
||||||
|
trading_day_id=trading_day_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify action record created
|
||||||
|
cursor = db.connection.execute("""
|
||||||
|
SELECT action_type, symbol, quantity, price
|
||||||
|
FROM actions
|
||||||
|
WHERE trading_day_id = ? AND action_type = 'sell'
|
||||||
|
""", (trading_day_id,))
|
||||||
|
|
||||||
|
row = cursor.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert row[0] == 'sell'
|
||||||
|
assert row[1] == 'AAPL'
|
||||||
|
assert row[2] == 5
|
||||||
|
assert row[3] == 160.0
|
||||||
Reference in New Issue
Block a user