mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-01 17:17:24 -04:00
refactor: implement database-only position tracking with lazy context injection
This commit migrates the system to database-only position storage, eliminating file-based position.jsonl dependencies and fixing ContextInjector initialization timing issues. Key Changes: 1. ContextInjector Lifecycle Refactor: - Remove ContextInjector creation from BaseAgent.__init__() - Add BaseAgent.set_context() method for post-initialization injection - Update ModelDayExecutor to create ContextInjector with correct trading day date - Ensures ContextInjector receives actual trading date instead of init_date - Includes session_id injection for proper database linking 2. Database Position Functions: - Implement get_today_init_position_from_db() for querying previous positions - Implement add_no_trade_record_to_db() for no-trade day handling - Both functions query SQLite directly (positions + holdings tables) - Handle first trading day case with initial cash return - Include comprehensive error handling and logging 3. System Integration: - Update get_agent_system_prompt() to use database queries - Update _handle_trading_result() to write no-trade records to database - Remove dependencies on position.jsonl file reading/writing - Use deployment_config for automatic prod/dev database resolution Data Flow: - ModelDayExecutor creates runtime config and trading session - Agent initialized without context - ContextInjector created with (signature, date, job_id, session_id) - Context injected via set_context() - System prompt queries database for yesterday's position - Trade tools write directly to database - No-trade handler creates database records Fixes: - ContextInjector no longer receives None values - No FileNotFoundError for missing position.jsonl files - Database is single source of truth for position tracking - Session linking maintained across all position records Design: docs/plans/2025-02-11-database-position-tracking-design.md
This commit is contained in:
@@ -173,21 +173,13 @@ class BaseAgent:
|
||||
print("⚠️ OpenAI base URL not set, using default")
|
||||
|
||||
try:
|
||||
# Get job_id from runtime config if available (API mode)
|
||||
from tools.general_tools import get_config_value
|
||||
job_id = get_config_value("JOB_ID") # Returns None if not in API mode
|
||||
# Context injector will be set later via set_context() method
|
||||
self.context_injector = None
|
||||
|
||||
# Create context injector for injecting signature and today_date into tool calls
|
||||
self.context_injector = ContextInjector(
|
||||
signature=self.signature,
|
||||
today_date=self.init_date, # Will be updated per trading session
|
||||
job_id=job_id # Will be None in standalone mode, populated in API mode
|
||||
)
|
||||
|
||||
# Create MCP client with interceptor
|
||||
# Create MCP client without interceptors initially
|
||||
self.client = MultiServerMCPClient(
|
||||
self.mcp_config,
|
||||
tool_interceptors=[self.context_injector]
|
||||
tool_interceptors=[]
|
||||
)
|
||||
|
||||
# Get tools
|
||||
@@ -229,6 +221,30 @@ class BaseAgent:
|
||||
|
||||
print(f"✅ Agent {self.signature} initialization completed")
|
||||
|
||||
def set_context(self, context_injector: "ContextInjector") -> None:
|
||||
"""
|
||||
Inject ContextInjector after initialization.
|
||||
|
||||
This allows the ContextInjector to be created with the correct
|
||||
trading day date and session_id after the agent is initialized.
|
||||
|
||||
Args:
|
||||
context_injector: Configured ContextInjector instance with
|
||||
correct signature, today_date, job_id, session_id
|
||||
"""
|
||||
self.context_injector = context_injector
|
||||
|
||||
# Recreate MCP client with the interceptor
|
||||
# Note: We need to recreate because MultiServerMCPClient doesn't have add_interceptor()
|
||||
self.client = MultiServerMCPClient(
|
||||
self.mcp_config,
|
||||
tool_interceptors=[context_injector]
|
||||
)
|
||||
|
||||
print(f"✅ Context injected: signature={context_injector.signature}, "
|
||||
f"date={context_injector.today_date}, job_id={context_injector.job_id}, "
|
||||
f"session_id={context_injector.session_id}")
|
||||
|
||||
def _capture_message(self, role: str, content: str, tool_name: str = None, tool_input: str = None) -> None:
|
||||
"""
|
||||
Capture a message in conversation history.
|
||||
@@ -429,18 +445,32 @@ Summary:"""
|
||||
await self._handle_trading_result(today_date)
|
||||
|
||||
async def _handle_trading_result(self, today_date: str) -> None:
|
||||
"""Handle trading results"""
|
||||
"""Handle trading results with database writes."""
|
||||
from tools.price_tools import add_no_trade_record_to_db
|
||||
|
||||
if_trade = get_config_value("IF_TRADE")
|
||||
|
||||
if if_trade:
|
||||
write_config_value("IF_TRADE", False)
|
||||
print("✅ Trading completed")
|
||||
else:
|
||||
print("📊 No trading, maintaining positions")
|
||||
try:
|
||||
add_no_trade_record(today_date, self.signature)
|
||||
except NameError as e:
|
||||
print(f"❌ NameError: {e}")
|
||||
raise
|
||||
|
||||
# Get context from runtime config
|
||||
job_id = get_config_value("JOB_ID")
|
||||
session_id = self.context_injector.session_id if self.context_injector else None
|
||||
|
||||
if not job_id or not session_id:
|
||||
raise ValueError("Missing JOB_ID or session_id for no-trade record")
|
||||
|
||||
# Write no-trade record to database
|
||||
add_no_trade_record_to_db(
|
||||
today_date,
|
||||
self.signature,
|
||||
job_id,
|
||||
session_id
|
||||
)
|
||||
|
||||
write_config_value("IF_TRADE", False)
|
||||
|
||||
def register_agent(self) -> None:
|
||||
|
||||
@@ -129,12 +129,18 @@ class ModelDayExecutor:
|
||||
# Set environment variable for agent to use isolated config
|
||||
os.environ["RUNTIME_ENV_PATH"] = self.runtime_config_path
|
||||
|
||||
# Initialize agent
|
||||
# Initialize agent (without context)
|
||||
agent = await self._initialize_agent()
|
||||
|
||||
# Update context injector with session_id
|
||||
if hasattr(agent, 'context_injector') and agent.context_injector:
|
||||
agent.context_injector.session_id = session_id
|
||||
# Create and inject context with correct values
|
||||
from agent.context_injector import ContextInjector
|
||||
context_injector = ContextInjector(
|
||||
signature=self.model_sig,
|
||||
today_date=self.date, # Current trading day
|
||||
job_id=self.job_id,
|
||||
session_id=session_id
|
||||
)
|
||||
agent.set_context(context_injector)
|
||||
|
||||
# Run trading session
|
||||
logger.info(f"Running trading session for {self.model_sig} on {self.date}")
|
||||
|
||||
@@ -68,14 +68,24 @@ When you think your task is complete, output
|
||||
def get_agent_system_prompt(today_date: str, signature: str) -> str:
|
||||
print(f"signature: {signature}")
|
||||
print(f"today_date: {today_date}")
|
||||
|
||||
# Get job_id from runtime config
|
||||
job_id = get_config_value("JOB_ID")
|
||||
if not job_id:
|
||||
raise ValueError("JOB_ID not found in runtime config")
|
||||
|
||||
# Query database for yesterday's position
|
||||
from tools.price_tools import get_today_init_position_from_db
|
||||
today_init_position = get_today_init_position_from_db(today_date, signature, job_id)
|
||||
|
||||
# Get yesterday's buy and sell prices
|
||||
yesterday_buy_prices, yesterday_sell_prices = get_yesterday_open_and_close_price(today_date, all_nasdaq_100_symbols)
|
||||
today_buy_price = get_open_prices(today_date, all_nasdaq_100_symbols)
|
||||
today_init_position = get_today_init_position(today_date, signature)
|
||||
yesterday_profit = get_yesterday_profit(today_date, yesterday_buy_prices, yesterday_sell_prices, today_init_position)
|
||||
|
||||
return agent_system_prompt.format(
|
||||
date=today_date,
|
||||
positions=today_init_position,
|
||||
date=today_date,
|
||||
positions=today_init_position,
|
||||
STOP_SIGNAL=STOP_SIGNAL,
|
||||
yesterday_close_price=yesterday_sell_prices,
|
||||
today_buy_price=today_buy_price,
|
||||
|
||||
@@ -299,7 +299,174 @@ def add_no_trade_record(today_date: str, modelname: str):
|
||||
|
||||
with position_file.open("a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(save_item) + "\n")
|
||||
return
|
||||
return
|
||||
|
||||
|
||||
def get_today_init_position_from_db(
|
||||
today_date: str,
|
||||
modelname: str,
|
||||
job_id: str
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Query yesterday's position from SQLite database.
|
||||
|
||||
Args:
|
||||
today_date: Current trading date (YYYY-MM-DD)
|
||||
modelname: Model signature
|
||||
job_id: Job UUID
|
||||
|
||||
Returns:
|
||||
Position dict: {"AAPL": 50, "MSFT": 30, "CASH": 5000.0}
|
||||
If no position exists: {"CASH": 10000.0} (initial cash)
|
||||
"""
|
||||
import logging
|
||||
from tools.deployment_config import get_db_path
|
||||
from api.database import get_db_connection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
db_path = get_db_path()
|
||||
conn = get_db_connection(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
# Get most recent position before today
|
||||
cursor.execute("""
|
||||
SELECT p.id, p.cash
|
||||
FROM positions p
|
||||
WHERE p.job_id = ? AND p.model = ? AND p.date < ?
|
||||
ORDER BY p.date DESC, p.action_id DESC
|
||||
LIMIT 1
|
||||
""", (job_id, modelname, today_date))
|
||||
|
||||
row = cursor.fetchone()
|
||||
|
||||
if not row:
|
||||
# First day - return initial cash
|
||||
logger.info(f"No previous position found for {modelname}, returning initial cash")
|
||||
return {"CASH": 10000.0}
|
||||
|
||||
position_id, cash = row
|
||||
position_dict = {"CASH": cash}
|
||||
|
||||
# Get holdings for this position
|
||||
cursor.execute("""
|
||||
SELECT symbol, quantity
|
||||
FROM holdings
|
||||
WHERE position_id = ?
|
||||
""", (position_id,))
|
||||
|
||||
for symbol, quantity in cursor.fetchall():
|
||||
position_dict[symbol] = quantity
|
||||
|
||||
logger.debug(f"Loaded position for {modelname}: {position_dict}")
|
||||
return position_dict
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Database error in get_today_init_position_from_db: {e}")
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def add_no_trade_record_to_db(
|
||||
today_date: str,
|
||||
modelname: str,
|
||||
job_id: str,
|
||||
session_id: int
|
||||
) -> None:
|
||||
"""
|
||||
Create no-trade position record in SQLite database.
|
||||
|
||||
Args:
|
||||
today_date: Current trading date (YYYY-MM-DD)
|
||||
modelname: Model signature
|
||||
job_id: Job UUID
|
||||
session_id: Trading session ID
|
||||
"""
|
||||
import logging
|
||||
from tools.deployment_config import get_db_path
|
||||
from api.database import get_db_connection
|
||||
from agent_tools.tool_trade import get_current_position_from_db
|
||||
from datetime import datetime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
db_path = get_db_path()
|
||||
conn = get_db_connection(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
# Get current position
|
||||
current_position, next_action_id = get_current_position_from_db(
|
||||
job_id, modelname, today_date
|
||||
)
|
||||
|
||||
# Calculate portfolio value
|
||||
cash = current_position.get("CASH", 0.0)
|
||||
portfolio_value = cash
|
||||
|
||||
# Add stock values
|
||||
for symbol, qty in current_position.items():
|
||||
if symbol != "CASH":
|
||||
try:
|
||||
price = get_open_prices(today_date, [symbol])[f'{symbol}_price']
|
||||
portfolio_value += qty * price
|
||||
except KeyError:
|
||||
logger.warning(f"Price not found for {symbol} on {today_date}")
|
||||
pass
|
||||
|
||||
# Get previous value for P&L
|
||||
cursor.execute("""
|
||||
SELECT portfolio_value
|
||||
FROM positions
|
||||
WHERE job_id = ? AND model = ? AND date < ?
|
||||
ORDER BY date DESC, action_id DESC
|
||||
LIMIT 1
|
||||
""", (job_id, modelname, today_date))
|
||||
|
||||
row = cursor.fetchone()
|
||||
previous_value = row[0] if row else 10000.0
|
||||
|
||||
daily_profit = portfolio_value - previous_value
|
||||
daily_return_pct = (daily_profit / previous_value * 100) if previous_value > 0 else 0
|
||||
|
||||
# Insert position record
|
||||
created_at = datetime.utcnow().isoformat() + "Z"
|
||||
|
||||
cursor.execute("""
|
||||
INSERT INTO positions (
|
||||
job_id, date, model, action_id, action_type,
|
||||
cash, portfolio_value, daily_profit, daily_return_pct,
|
||||
session_id, created_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
job_id, today_date, modelname, next_action_id, "no_trade",
|
||||
cash, portfolio_value, daily_profit, daily_return_pct,
|
||||
session_id, created_at
|
||||
))
|
||||
|
||||
position_id = cursor.lastrowid
|
||||
|
||||
# Insert holdings (unchanged from previous position)
|
||||
for symbol, qty in current_position.items():
|
||||
if symbol != "CASH":
|
||||
cursor.execute("""
|
||||
INSERT INTO holdings (position_id, symbol, quantity)
|
||||
VALUES (?, ?, ?)
|
||||
""", (position_id, symbol, qty))
|
||||
|
||||
conn.commit()
|
||||
logger.info(f"Created no-trade record for {modelname} on {today_date}")
|
||||
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
logger.error(f"Database error in add_no_trade_record_to_db: {e}")
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
today_date = get_config_value("TODAY_DATE")
|
||||
|
||||
Reference in New Issue
Block a user