mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-03 01:47:23 -04:00
fix: add TRADING_DAY_ID write to runtime config and improve test coverage
Changes: - Write TRADING_DAY_ID to runtime config after creating trading_day record in BaseAgent - Fix datetime deprecation warnings by replacing datetime.utcnow() with datetime.now(timezone.utc) - Add test for trading_day_id=None fallback path to verify runtime config lookup works correctly This ensures trade tools can access trading_day_id from runtime config when not explicitly passed.
This commit is contained in:
@@ -534,6 +534,10 @@ Summary:"""
|
||||
days_since_last_trading=pnl_metrics["days_since_last_trading"]
|
||||
)
|
||||
|
||||
# Write trading_day_id to runtime config for trade tools
|
||||
from tools.general_tools import write_config_value
|
||||
write_config_value('TRADING_DAY_ID', trading_day_id)
|
||||
|
||||
# 6. Run AI trading session
|
||||
action_count = 0
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ sys.path.insert(0, project_root)
|
||||
from tools.price_tools import get_open_prices
|
||||
import json
|
||||
from api.database import get_db_connection
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
mcp = FastMCP("TradeTools")
|
||||
|
||||
|
||||
@@ -158,7 +158,7 @@ def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str =
|
||||
if trading_day_id is None:
|
||||
raise ValueError("trading_day_id not found in runtime config")
|
||||
|
||||
created_at = datetime.utcnow().isoformat() + "Z"
|
||||
created_at = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
cursor.execute("""
|
||||
INSERT INTO actions (
|
||||
@@ -273,7 +273,7 @@ def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str
|
||||
if trading_day_id is None:
|
||||
raise ValueError("trading_day_id not found in runtime config")
|
||||
|
||||
created_at = datetime.utcnow().isoformat() + "Z"
|
||||
created_at = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
cursor.execute("""
|
||||
INSERT INTO actions (
|
||||
|
||||
@@ -139,6 +139,87 @@ def test_buy_writes_to_actions_table(test_db, monkeypatch):
|
||||
assert cursor.fetchone() is None, "Old positions table should not exist"
|
||||
|
||||
|
||||
def test_buy_with_none_trading_day_id_reads_from_config(test_db, monkeypatch):
|
||||
"""Test buy() with trading_day_id=None fallback reads from runtime config."""
|
||||
db, trading_day_id = test_db
|
||||
|
||||
# Create a mock connection wrapper that doesn't actually close
|
||||
class MockConnection:
|
||||
def __init__(self, real_conn):
|
||||
self.real_conn = real_conn
|
||||
|
||||
def cursor(self):
|
||||
return self.real_conn.cursor()
|
||||
|
||||
def execute(self, *args, **kwargs):
|
||||
return self.real_conn.execute(*args, **kwargs)
|
||||
|
||||
def commit(self):
|
||||
return self.real_conn.commit()
|
||||
|
||||
def rollback(self):
|
||||
return self.real_conn.rollback()
|
||||
|
||||
def close(self):
|
||||
pass # Don't actually close the connection
|
||||
|
||||
mock_conn = MockConnection(db.connection)
|
||||
|
||||
# Mock get_db_connection to return our mock connection
|
||||
monkeypatch.setattr('agent_tools.tool_trade.get_db_connection',
|
||||
lambda x: mock_conn)
|
||||
|
||||
# Mock get_current_position_from_db to return starting position
|
||||
monkeypatch.setattr('agent_tools.tool_trade.get_current_position_from_db',
|
||||
lambda job_id, sig, date: ({'CASH': 10000.0}, 0))
|
||||
|
||||
# Mock runtime config
|
||||
monkeypatch.setenv('RUNTIME_ENV_PATH', '/tmp/test_runtime_fallback.json')
|
||||
|
||||
# Create mock runtime config file with TRADING_DAY_ID
|
||||
import json
|
||||
with open('/tmp/test_runtime_fallback.json', 'w') as f:
|
||||
json.dump({
|
||||
'TODAY_DATE': '2025-01-15',
|
||||
'SIGNATURE': 'test-model',
|
||||
'JOB_ID': 'test-job-123',
|
||||
'TRADING_DAY_ID': trading_day_id
|
||||
}, f)
|
||||
|
||||
# Mock price data
|
||||
monkeypatch.setattr('agent_tools.tool_trade.get_open_prices',
|
||||
lambda date, symbols: {'AAPL_price': 150.0})
|
||||
|
||||
# Execute buy with trading_day_id=None to force config lookup
|
||||
result = _buy_impl(
|
||||
symbol='AAPL',
|
||||
amount=10,
|
||||
signature='test-model',
|
||||
today_date='2025-01-15',
|
||||
job_id='test-job-123',
|
||||
trading_day_id=None # Force fallback to runtime config
|
||||
)
|
||||
|
||||
# Check if there was an error
|
||||
if 'error' in result:
|
||||
print(f"Buy failed with error: {result}")
|
||||
|
||||
# Verify action record created with correct trading_day_id from config
|
||||
cursor = db.connection.execute("""
|
||||
SELECT action_type, symbol, quantity, price, trading_day_id
|
||||
FROM actions
|
||||
WHERE trading_day_id = ?
|
||||
""", (trading_day_id,))
|
||||
|
||||
row = cursor.fetchone()
|
||||
assert row is not None, "Action record should exist when reading trading_day_id from config"
|
||||
assert row[0] == 'buy'
|
||||
assert row[1] == 'AAPL'
|
||||
assert row[2] == 10
|
||||
assert row[3] == 150.0
|
||||
assert row[4] == trading_day_id, "trading_day_id should match the value from runtime config"
|
||||
|
||||
|
||||
def test_sell_writes_to_actions_table(test_db, monkeypatch):
|
||||
"""Test sell() writes action record to actions table."""
|
||||
db, trading_day_id = test_db
|
||||
|
||||
Reference in New Issue
Block a user