mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-12 21:47:23 -04:00
refactor: update get_current_position_from_db to query new schema
This commit is contained in:
@@ -20,71 +20,69 @@ from datetime import datetime, timezone
|
|||||||
mcp = FastMCP("TradeTools")
|
mcp = FastMCP("TradeTools")
|
||||||
|
|
||||||
|
|
||||||
def get_current_position_from_db(job_id: str, model: str, date: str) -> Tuple[Dict[str, float], int]:
|
def get_current_position_from_db(
|
||||||
|
job_id: str,
|
||||||
|
model: str,
|
||||||
|
date: str,
|
||||||
|
initial_cash: float = 10000.0
|
||||||
|
) -> Tuple[Dict[str, float], int]:
|
||||||
"""
|
"""
|
||||||
Query current position from SQLite database.
|
Get current position from database (new schema).
|
||||||
|
|
||||||
|
Queries most recent trading_day record for this job+model up to date.
|
||||||
|
Returns ending holdings and cash from that day.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
job_id: Job UUID
|
job_id: Job UUID
|
||||||
model: Model signature
|
model: Model signature
|
||||||
date: Trading date (YYYY-MM-DD)
|
date: Current trading date
|
||||||
|
initial_cash: Initial cash if no prior data
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (position_dict, next_action_id)
|
(position_dict, action_count) where:
|
||||||
- position_dict: {symbol: quantity, "CASH": amount}
|
- position_dict: {"AAPL": 10, "MSFT": 5, "CASH": 8500.0}
|
||||||
- next_action_id: Next available action_id for this job+model
|
- action_count: Number of holdings (for action_id tracking)
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: If database query fails
|
|
||||||
"""
|
"""
|
||||||
db_path = "data/jobs.db"
|
db_path = "data/trading.db"
|
||||||
conn = get_db_connection(db_path)
|
conn = get_db_connection(db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get most recent position on or before this date
|
# Query most recent trading_day up to date
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT p.id, p.cash
|
SELECT id, ending_cash
|
||||||
FROM positions p
|
FROM trading_days
|
||||||
WHERE p.job_id = ? AND p.model = ? AND p.date <= ?
|
WHERE job_id = ? AND model = ? AND date <= ?
|
||||||
ORDER BY p.date DESC, p.action_id DESC
|
ORDER BY date DESC
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
""", (job_id, model, date))
|
""", (job_id, model, date))
|
||||||
|
|
||||||
position_row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
|
|
||||||
if not position_row:
|
if row is None:
|
||||||
# No position found - this shouldn't happen if ModelDayExecutor initializes properly
|
# First day - return initial position
|
||||||
raise Exception(f"No position found for job_id={job_id}, model={model}, date={date}")
|
return {"CASH": initial_cash}, 0
|
||||||
|
|
||||||
position_id = position_row[0]
|
trading_day_id, ending_cash = row
|
||||||
cash = position_row[1]
|
|
||||||
|
|
||||||
# Build position dict starting with CASH
|
# Query holdings for that day
|
||||||
position_dict = {"CASH": cash}
|
|
||||||
|
|
||||||
# Get holdings for this position
|
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT symbol, quantity
|
SELECT symbol, quantity
|
||||||
FROM holdings
|
FROM holdings
|
||||||
WHERE position_id = ?
|
WHERE trading_day_id = ?
|
||||||
""", (position_id,))
|
""", (trading_day_id,))
|
||||||
|
|
||||||
for row in cursor.fetchall():
|
holdings_rows = cursor.fetchall()
|
||||||
symbol = row[0]
|
|
||||||
quantity = row[1]
|
|
||||||
position_dict[symbol] = quantity
|
|
||||||
|
|
||||||
# Get next action_id
|
# Build position dict
|
||||||
cursor.execute("""
|
position = {"CASH": ending_cash}
|
||||||
SELECT COALESCE(MAX(action_id), -1) + 1 as next_action_id
|
for symbol, quantity in holdings_rows:
|
||||||
FROM positions
|
position[symbol] = quantity
|
||||||
WHERE job_id = ? AND model = ?
|
|
||||||
""", (job_id, model))
|
|
||||||
|
|
||||||
next_action_id = cursor.fetchone()[0]
|
# Action count is number of holdings (used for action_id)
|
||||||
|
action_count = len(holdings_rows)
|
||||||
|
|
||||||
return position_dict, next_action_id
|
return position, action_count
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|||||||
97
tests/unit/test_get_position_new_schema.py
Normal file
97
tests/unit/test_get_position_new_schema.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
"""Test get_current_position_from_db queries new schema."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from agent_tools.tool_trade import get_current_position_from_db
|
||||||
|
from api.database import Database
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_position_from_new_schema():
|
||||||
|
"""Test position retrieval from trading_days + holdings."""
|
||||||
|
|
||||||
|
# Create test database
|
||||||
|
db = Database(":memory:")
|
||||||
|
|
||||||
|
# Create prerequisite: jobs record
|
||||||
|
db.connection.execute("""
|
||||||
|
INSERT INTO jobs (job_id, config_path, status, date_range, models, created_at)
|
||||||
|
VALUES ('test-job-123', 'test_config.json', 'running', '2025-01-15 to 2025-01-15', 'test-model', '2025-01-15T10:00:00Z')
|
||||||
|
""")
|
||||||
|
db.connection.commit()
|
||||||
|
|
||||||
|
# Create trading_day with holdings
|
||||||
|
trading_day_id = db.create_trading_day(
|
||||||
|
job_id='test-job-123',
|
||||||
|
model='test-model',
|
||||||
|
date='2025-01-15',
|
||||||
|
starting_cash=10000.0,
|
||||||
|
starting_portfolio_value=10000.0,
|
||||||
|
daily_profit=0.0,
|
||||||
|
daily_return_pct=0.0,
|
||||||
|
ending_cash=8000.0,
|
||||||
|
ending_portfolio_value=9500.0,
|
||||||
|
days_since_last_trading=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add ending holdings
|
||||||
|
db.create_holding(trading_day_id, 'AAPL', 10)
|
||||||
|
db.create_holding(trading_day_id, 'MSFT', 5)
|
||||||
|
|
||||||
|
db.connection.commit()
|
||||||
|
|
||||||
|
# Mock get_db_connection to return our test db
|
||||||
|
import agent_tools.tool_trade as trade_module
|
||||||
|
original_get_db_connection = trade_module.get_db_connection
|
||||||
|
|
||||||
|
def mock_get_db_connection(path):
|
||||||
|
return db.connection
|
||||||
|
|
||||||
|
trade_module.get_db_connection = mock_get_db_connection
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Query position
|
||||||
|
position, action_id = get_current_position_from_db(
|
||||||
|
job_id='test-job-123',
|
||||||
|
model='test-model',
|
||||||
|
date='2025-01-15'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify
|
||||||
|
assert position['AAPL'] == 10
|
||||||
|
assert position['MSFT'] == 5
|
||||||
|
assert position['CASH'] == 8000.0
|
||||||
|
assert action_id == 2 # 2 holdings = 2 actions
|
||||||
|
finally:
|
||||||
|
# Restore original function
|
||||||
|
trade_module.get_db_connection = original_get_db_connection
|
||||||
|
db.connection.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_position_first_day():
|
||||||
|
"""Test position retrieval on first day (no prior data)."""
|
||||||
|
|
||||||
|
db = Database(":memory:")
|
||||||
|
|
||||||
|
# Mock get_db_connection to return our test db
|
||||||
|
import agent_tools.tool_trade as trade_module
|
||||||
|
original_get_db_connection = trade_module.get_db_connection
|
||||||
|
|
||||||
|
def mock_get_db_connection(path):
|
||||||
|
return db.connection
|
||||||
|
|
||||||
|
trade_module.get_db_connection = mock_get_db_connection
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Query position (no data exists)
|
||||||
|
position, action_id = get_current_position_from_db(
|
||||||
|
job_id='test-job-123',
|
||||||
|
model='test-model',
|
||||||
|
date='2025-01-15'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return initial position
|
||||||
|
assert position['CASH'] == 10000.0 # Default initial cash
|
||||||
|
assert action_id == 0
|
||||||
|
finally:
|
||||||
|
# Restore original function
|
||||||
|
trade_module.get_db_connection = original_get_db_connection
|
||||||
|
db.connection.close()
|
||||||
Reference in New Issue
Block a user