mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-02 01:27:24 -04:00
Compare commits
3 Commits
v0.4.0-alp
...
v0.4.0-alp
| Author | SHA1 | Date | |
|---|---|---|---|
| aa16480158 | |||
| 05620facc2 | |||
| 7c71a047bc |
@@ -538,6 +538,10 @@ Summary:"""
|
||||
from tools.general_tools import write_config_value
|
||||
write_config_value('TRADING_DAY_ID', trading_day_id)
|
||||
|
||||
# Update context_injector with trading_day_id for MCP tools
|
||||
if self.context_injector:
|
||||
self.context_injector.trading_day_id = trading_day_id
|
||||
|
||||
# 6. Run AI trading session
|
||||
action_count = 0
|
||||
|
||||
@@ -660,8 +664,6 @@ Summary:"""
|
||||
|
||||
async def _handle_trading_result(self, today_date: str) -> None:
|
||||
"""Handle trading results with database writes."""
|
||||
from tools.price_tools import add_no_trade_record_to_db
|
||||
|
||||
if_trade = get_config_value("IF_TRADE")
|
||||
|
||||
if if_trade:
|
||||
@@ -669,23 +671,10 @@ Summary:"""
|
||||
print("✅ Trading completed")
|
||||
else:
|
||||
print("📊 No trading, maintaining positions")
|
||||
|
||||
# Get context from runtime config
|
||||
job_id = get_config_value("JOB_ID")
|
||||
session_id = self.context_injector.session_id if self.context_injector else None
|
||||
|
||||
if not job_id or not session_id:
|
||||
raise ValueError("Missing JOB_ID or session_id for no-trade record")
|
||||
|
||||
# Write no-trade record to database
|
||||
add_no_trade_record_to_db(
|
||||
today_date,
|
||||
self.signature,
|
||||
job_id,
|
||||
session_id
|
||||
)
|
||||
|
||||
write_config_value("IF_TRADE", False)
|
||||
|
||||
# Note: In new schema, trading_day record is created at session start
|
||||
# and updated at session end, so no separate no-trade record needed
|
||||
|
||||
def register_agent(self) -> None:
|
||||
"""Register new agent, create initial positions"""
|
||||
|
||||
@@ -28,16 +28,17 @@ def get_current_position_from_db(
|
||||
initial_cash: float = 10000.0
|
||||
) -> Tuple[Dict[str, float], int]:
|
||||
"""
|
||||
Get current position from database (new schema).
|
||||
Get starting position for current trading day from database (new schema).
|
||||
|
||||
Queries most recent trading_day record for this job+model up to date.
|
||||
Returns ending holdings and cash from that day.
|
||||
Queries most recent trading_day record BEFORE the given date (previous day's ending).
|
||||
Returns ending holdings and cash from that previous day, which becomes the
|
||||
starting position for the current day.
|
||||
|
||||
Args:
|
||||
job_id: Job UUID
|
||||
model: Model signature
|
||||
date: Current trading date
|
||||
initial_cash: Initial cash if no prior data
|
||||
date: Current trading date (will query for date < this)
|
||||
initial_cash: Initial cash if no prior data (first trading day)
|
||||
|
||||
Returns:
|
||||
(position_dict, action_count) where:
|
||||
@@ -49,11 +50,11 @@ def get_current_position_from_db(
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
# Query most recent trading_day up to date
|
||||
# Query most recent trading_day BEFORE current date (previous day's ending position)
|
||||
cursor.execute("""
|
||||
SELECT id, ending_cash
|
||||
FROM trading_days
|
||||
WHERE job_id = ? AND model = ? AND date <= ?
|
||||
WHERE job_id = ? AND model = ? AND date < ?
|
||||
ORDER BY date DESC
|
||||
LIMIT 1
|
||||
""", (job_id, model, date))
|
||||
|
||||
@@ -6,7 +6,7 @@ from api.database import Database
|
||||
|
||||
|
||||
def test_get_position_from_new_schema():
|
||||
"""Test position retrieval from trading_days + holdings."""
|
||||
"""Test position retrieval from trading_days + holdings (previous day)."""
|
||||
|
||||
# Create test database
|
||||
db = Database(":memory:")
|
||||
@@ -14,11 +14,11 @@ def test_get_position_from_new_schema():
|
||||
# Create prerequisite: jobs record
|
||||
db.connection.execute("""
|
||||
INSERT INTO jobs (job_id, config_path, status, date_range, models, created_at)
|
||||
VALUES ('test-job-123', 'test_config.json', 'running', '2025-01-15 to 2025-01-15', 'test-model', '2025-01-15T10:00:00Z')
|
||||
VALUES ('test-job-123', 'test_config.json', 'running', '2025-01-14 to 2025-01-16', 'test-model', '2025-01-14T10:00:00Z')
|
||||
""")
|
||||
db.connection.commit()
|
||||
|
||||
# Create trading_day with holdings
|
||||
# Create trading_day with holdings for 2025-01-15
|
||||
trading_day_id = db.create_trading_day(
|
||||
job_id='test-job-123',
|
||||
model='test-model',
|
||||
@@ -32,7 +32,7 @@ def test_get_position_from_new_schema():
|
||||
days_since_last_trading=0
|
||||
)
|
||||
|
||||
# Add ending holdings
|
||||
# Add ending holdings for 2025-01-15
|
||||
db.create_holding(trading_day_id, 'AAPL', 10)
|
||||
db.create_holding(trading_day_id, 'MSFT', 5)
|
||||
|
||||
@@ -48,18 +48,19 @@ def test_get_position_from_new_schema():
|
||||
trade_module.get_db_connection = mock_get_db_connection
|
||||
|
||||
try:
|
||||
# Query position
|
||||
# Query position for NEXT day (2025-01-16)
|
||||
# Should retrieve previous day's (2025-01-15) ending position
|
||||
position, action_id = get_current_position_from_db(
|
||||
job_id='test-job-123',
|
||||
model='test-model',
|
||||
date='2025-01-15'
|
||||
date='2025-01-16' # Query for day AFTER the trading_day record
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert position['AAPL'] == 10
|
||||
assert position['MSFT'] == 5
|
||||
assert position['CASH'] == 8000.0
|
||||
assert action_id == 2 # 2 holdings = 2 actions
|
||||
# Verify we got the previous day's ending position
|
||||
assert position['AAPL'] == 10, f"Expected 10 AAPL but got {position.get('AAPL', 0)}"
|
||||
assert position['MSFT'] == 5, f"Expected 5 MSFT but got {position.get('MSFT', 0)}"
|
||||
assert position['CASH'] == 8000.0, f"Expected cash $8000 but got ${position['CASH']}"
|
||||
assert action_id == 2, f"Expected 2 holdings but got {action_id}"
|
||||
finally:
|
||||
# Restore original function
|
||||
trade_module.get_db_connection = original_get_db_connection
|
||||
@@ -95,3 +96,99 @@ def test_get_position_first_day():
|
||||
# Restore original function
|
||||
trade_module.get_db_connection = original_get_db_connection
|
||||
db.connection.close()
|
||||
|
||||
|
||||
def test_get_position_retrieves_previous_day_not_current():
|
||||
"""Test that get_current_position_from_db queries PREVIOUS day's ending, not current day.
|
||||
|
||||
This is the critical fix: when querying for day 2's starting position,
|
||||
it should return day 1's ending position, NOT day 2's (incomplete) position.
|
||||
"""
|
||||
|
||||
db = Database(":memory:")
|
||||
|
||||
# Create prerequisite: jobs record
|
||||
db.connection.execute("""
|
||||
INSERT INTO jobs (job_id, config_path, status, date_range, models, created_at)
|
||||
VALUES ('test-job-123', 'test_config.json', 'running', '2025-10-01 to 2025-10-03', 'gpt-5', '2025-10-01T10:00:00Z')
|
||||
""")
|
||||
db.connection.commit()
|
||||
|
||||
# Day 1: Create complete trading day with holdings
|
||||
day1_id = db.create_trading_day(
|
||||
job_id='test-job-123',
|
||||
model='gpt-5',
|
||||
date='2025-10-02',
|
||||
starting_cash=10000.0,
|
||||
starting_portfolio_value=10000.0,
|
||||
daily_profit=0.0,
|
||||
daily_return_pct=0.0,
|
||||
ending_cash=2500.0, # After buying stocks
|
||||
ending_portfolio_value=10000.0,
|
||||
days_since_last_trading=1
|
||||
)
|
||||
|
||||
# Day 1 ending holdings (7 AMZN, 5 GOOGL, 6 MU, 3 QCOM, 4 MSFT, 1 CRWD, 10 NVDA, 3 AVGO)
|
||||
db.create_holding(day1_id, 'AMZN', 7)
|
||||
db.create_holding(day1_id, 'GOOGL', 5)
|
||||
db.create_holding(day1_id, 'MU', 6)
|
||||
db.create_holding(day1_id, 'QCOM', 3)
|
||||
db.create_holding(day1_id, 'MSFT', 4)
|
||||
db.create_holding(day1_id, 'CRWD', 1)
|
||||
db.create_holding(day1_id, 'NVDA', 10)
|
||||
db.create_holding(day1_id, 'AVGO', 3)
|
||||
|
||||
# Day 2: Create incomplete trading day (just started, no holdings yet)
|
||||
day2_id = db.create_trading_day(
|
||||
job_id='test-job-123',
|
||||
model='gpt-5',
|
||||
date='2025-10-03',
|
||||
starting_cash=2500.0, # From day 1 ending
|
||||
starting_portfolio_value=10000.0,
|
||||
daily_profit=0.0,
|
||||
daily_return_pct=0.0,
|
||||
ending_cash=2500.0, # Not finalized yet
|
||||
ending_portfolio_value=10000.0, # Not finalized yet
|
||||
days_since_last_trading=1
|
||||
)
|
||||
# NOTE: No holdings created for day 2 yet (trading in progress)
|
||||
|
||||
db.connection.commit()
|
||||
|
||||
# Mock get_db_connection to return our test db
|
||||
import agent_tools.tool_trade as trade_module
|
||||
original_get_db_connection = trade_module.get_db_connection
|
||||
|
||||
def mock_get_db_connection(path):
|
||||
return db.connection
|
||||
|
||||
trade_module.get_db_connection = mock_get_db_connection
|
||||
|
||||
try:
|
||||
# Query starting position for day 2 (2025-10-03)
|
||||
# This should return day 1's ending position, NOT day 2's incomplete position
|
||||
position, action_id = get_current_position_from_db(
|
||||
job_id='test-job-123',
|
||||
model='gpt-5',
|
||||
date='2025-10-03'
|
||||
)
|
||||
|
||||
# Verify we got day 1's ending position (8 holdings)
|
||||
assert position['CASH'] == 2500.0, f"Expected cash $2500 but got ${position['CASH']}"
|
||||
assert position['AMZN'] == 7, f"Expected 7 AMZN but got {position.get('AMZN', 0)}"
|
||||
assert position['GOOGL'] == 5, f"Expected 5 GOOGL but got {position.get('GOOGL', 0)}"
|
||||
assert position['MU'] == 6, f"Expected 6 MU but got {position.get('MU', 0)}"
|
||||
assert position['QCOM'] == 3, f"Expected 3 QCOM but got {position.get('QCOM', 0)}"
|
||||
assert position['MSFT'] == 4, f"Expected 4 MSFT but got {position.get('MSFT', 0)}"
|
||||
assert position['CRWD'] == 1, f"Expected 1 CRWD but got {position.get('CRWD', 0)}"
|
||||
assert position['NVDA'] == 10, f"Expected 10 NVDA but got {position.get('NVDA', 0)}"
|
||||
assert position['AVGO'] == 3, f"Expected 3 AVGO but got {position.get('AVGO', 0)}"
|
||||
assert action_id == 8, f"Expected 8 holdings but got {action_id}"
|
||||
|
||||
# Verify total holdings count (should NOT include day 2's empty holdings)
|
||||
assert len(position) == 9, f"Expected 9 items (8 stocks + CASH) but got {len(position)}"
|
||||
|
||||
finally:
|
||||
# Restore original function
|
||||
trade_module.get_db_connection = original_get_db_connection
|
||||
db.connection.close()
|
||||
|
||||
@@ -337,12 +337,12 @@ def get_today_init_position_from_db(
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
# Get most recent position before today
|
||||
# Get most recent trading day before today
|
||||
cursor.execute("""
|
||||
SELECT p.id, p.cash
|
||||
FROM positions p
|
||||
WHERE p.job_id = ? AND p.model = ? AND p.date < ?
|
||||
ORDER BY p.date DESC, p.action_id DESC
|
||||
SELECT id, ending_cash
|
||||
FROM trading_days
|
||||
WHERE job_id = ? AND model = ? AND date < ?
|
||||
ORDER BY date DESC
|
||||
LIMIT 1
|
||||
""", (job_id, modelname, today_date))
|
||||
|
||||
@@ -353,15 +353,15 @@ def get_today_init_position_from_db(
|
||||
logger.info(f"No previous position found for {modelname}, returning initial cash")
|
||||
return {"CASH": 10000.0}
|
||||
|
||||
position_id, cash = row
|
||||
trading_day_id, cash = row
|
||||
position_dict = {"CASH": cash}
|
||||
|
||||
# Get holdings for this position
|
||||
# Get holdings for this trading day
|
||||
cursor.execute("""
|
||||
SELECT symbol, quantity
|
||||
FROM holdings
|
||||
WHERE position_id = ?
|
||||
""", (position_id,))
|
||||
WHERE trading_day_id = ?
|
||||
""", (trading_day_id,))
|
||||
|
||||
for symbol, quantity in cursor.fetchall():
|
||||
position_dict[symbol] = quantity
|
||||
|
||||
Reference in New Issue
Block a user