mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-01 17:17:24 -04:00
refactor: implement database-only position tracking with lazy context injection
This commit migrates the system to database-only position storage, eliminating file-based position.jsonl dependencies and fixing ContextInjector initialization timing issues. Key Changes: 1. ContextInjector Lifecycle Refactor: - Remove ContextInjector creation from BaseAgent.__init__() - Add BaseAgent.set_context() method for post-initialization injection - Update ModelDayExecutor to create ContextInjector with correct trading day date - Ensures ContextInjector receives actual trading date instead of init_date - Includes session_id injection for proper database linking 2. Database Position Functions: - Implement get_today_init_position_from_db() for querying previous positions - Implement add_no_trade_record_to_db() for no-trade day handling - Both functions query SQLite directly (positions + holdings tables) - Handle first trading day case with initial cash return - Include comprehensive error handling and logging 3. System Integration: - Update get_agent_system_prompt() to use database queries - Update _handle_trading_result() to write no-trade records to database - Remove dependencies on position.jsonl file reading/writing - Use deployment_config for automatic prod/dev database resolution Data Flow: - ModelDayExecutor creates runtime config and trading session - Agent initialized without context - ContextInjector created with (signature, date, job_id, session_id) - Context injected via set_context() - System prompt queries database for yesterday's position - Trade tools write directly to database - No-trade handler creates database records Fixes: - ContextInjector no longer receives None values - No FileNotFoundError for missing position.jsonl files - Database is single source of truth for position tracking - Session linking maintained across all position records Design: docs/plans/2025-02-11-database-position-tracking-design.md
This commit is contained in:
@@ -173,21 +173,13 @@ class BaseAgent:
|
|||||||
print("⚠️ OpenAI base URL not set, using default")
|
print("⚠️ OpenAI base URL not set, using default")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get job_id from runtime config if available (API mode)
|
# Context injector will be set later via set_context() method
|
||||||
from tools.general_tools import get_config_value
|
self.context_injector = None
|
||||||
job_id = get_config_value("JOB_ID") # Returns None if not in API mode
|
|
||||||
|
|
||||||
# Create context injector for injecting signature and today_date into tool calls
|
# Create MCP client without interceptors initially
|
||||||
self.context_injector = ContextInjector(
|
|
||||||
signature=self.signature,
|
|
||||||
today_date=self.init_date, # Will be updated per trading session
|
|
||||||
job_id=job_id # Will be None in standalone mode, populated in API mode
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create MCP client with interceptor
|
|
||||||
self.client = MultiServerMCPClient(
|
self.client = MultiServerMCPClient(
|
||||||
self.mcp_config,
|
self.mcp_config,
|
||||||
tool_interceptors=[self.context_injector]
|
tool_interceptors=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get tools
|
# Get tools
|
||||||
@@ -229,6 +221,30 @@ class BaseAgent:
|
|||||||
|
|
||||||
print(f"✅ Agent {self.signature} initialization completed")
|
print(f"✅ Agent {self.signature} initialization completed")
|
||||||
|
|
||||||
|
def set_context(self, context_injector: "ContextInjector") -> None:
|
||||||
|
"""
|
||||||
|
Inject ContextInjector after initialization.
|
||||||
|
|
||||||
|
This allows the ContextInjector to be created with the correct
|
||||||
|
trading day date and session_id after the agent is initialized.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_injector: Configured ContextInjector instance with
|
||||||
|
correct signature, today_date, job_id, session_id
|
||||||
|
"""
|
||||||
|
self.context_injector = context_injector
|
||||||
|
|
||||||
|
# Recreate MCP client with the interceptor
|
||||||
|
# Note: We need to recreate because MultiServerMCPClient doesn't have add_interceptor()
|
||||||
|
self.client = MultiServerMCPClient(
|
||||||
|
self.mcp_config,
|
||||||
|
tool_interceptors=[context_injector]
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"✅ Context injected: signature={context_injector.signature}, "
|
||||||
|
f"date={context_injector.today_date}, job_id={context_injector.job_id}, "
|
||||||
|
f"session_id={context_injector.session_id}")
|
||||||
|
|
||||||
def _capture_message(self, role: str, content: str, tool_name: str = None, tool_input: str = None) -> None:
|
def _capture_message(self, role: str, content: str, tool_name: str = None, tool_input: str = None) -> None:
|
||||||
"""
|
"""
|
||||||
Capture a message in conversation history.
|
Capture a message in conversation history.
|
||||||
@@ -429,18 +445,32 @@ Summary:"""
|
|||||||
await self._handle_trading_result(today_date)
|
await self._handle_trading_result(today_date)
|
||||||
|
|
||||||
async def _handle_trading_result(self, today_date: str) -> None:
|
async def _handle_trading_result(self, today_date: str) -> None:
|
||||||
"""Handle trading results"""
|
"""Handle trading results with database writes."""
|
||||||
|
from tools.price_tools import add_no_trade_record_to_db
|
||||||
|
|
||||||
if_trade = get_config_value("IF_TRADE")
|
if_trade = get_config_value("IF_TRADE")
|
||||||
|
|
||||||
if if_trade:
|
if if_trade:
|
||||||
write_config_value("IF_TRADE", False)
|
write_config_value("IF_TRADE", False)
|
||||||
print("✅ Trading completed")
|
print("✅ Trading completed")
|
||||||
else:
|
else:
|
||||||
print("📊 No trading, maintaining positions")
|
print("📊 No trading, maintaining positions")
|
||||||
try:
|
|
||||||
add_no_trade_record(today_date, self.signature)
|
# Get context from runtime config
|
||||||
except NameError as e:
|
job_id = get_config_value("JOB_ID")
|
||||||
print(f"❌ NameError: {e}")
|
session_id = self.context_injector.session_id if self.context_injector else None
|
||||||
raise
|
|
||||||
|
if not job_id or not session_id:
|
||||||
|
raise ValueError("Missing JOB_ID or session_id for no-trade record")
|
||||||
|
|
||||||
|
# Write no-trade record to database
|
||||||
|
add_no_trade_record_to_db(
|
||||||
|
today_date,
|
||||||
|
self.signature,
|
||||||
|
job_id,
|
||||||
|
session_id
|
||||||
|
)
|
||||||
|
|
||||||
write_config_value("IF_TRADE", False)
|
write_config_value("IF_TRADE", False)
|
||||||
|
|
||||||
def register_agent(self) -> None:
|
def register_agent(self) -> None:
|
||||||
|
|||||||
@@ -129,12 +129,18 @@ class ModelDayExecutor:
|
|||||||
# Set environment variable for agent to use isolated config
|
# Set environment variable for agent to use isolated config
|
||||||
os.environ["RUNTIME_ENV_PATH"] = self.runtime_config_path
|
os.environ["RUNTIME_ENV_PATH"] = self.runtime_config_path
|
||||||
|
|
||||||
# Initialize agent
|
# Initialize agent (without context)
|
||||||
agent = await self._initialize_agent()
|
agent = await self._initialize_agent()
|
||||||
|
|
||||||
# Update context injector with session_id
|
# Create and inject context with correct values
|
||||||
if hasattr(agent, 'context_injector') and agent.context_injector:
|
from agent.context_injector import ContextInjector
|
||||||
agent.context_injector.session_id = session_id
|
context_injector = ContextInjector(
|
||||||
|
signature=self.model_sig,
|
||||||
|
today_date=self.date, # Current trading day
|
||||||
|
job_id=self.job_id,
|
||||||
|
session_id=session_id
|
||||||
|
)
|
||||||
|
agent.set_context(context_injector)
|
||||||
|
|
||||||
# Run trading session
|
# Run trading session
|
||||||
logger.info(f"Running trading session for {self.model_sig} on {self.date}")
|
logger.info(f"Running trading session for {self.model_sig} on {self.date}")
|
||||||
|
|||||||
@@ -68,11 +68,21 @@ When you think your task is complete, output
|
|||||||
def get_agent_system_prompt(today_date: str, signature: str) -> str:
|
def get_agent_system_prompt(today_date: str, signature: str) -> str:
|
||||||
print(f"signature: {signature}")
|
print(f"signature: {signature}")
|
||||||
print(f"today_date: {today_date}")
|
print(f"today_date: {today_date}")
|
||||||
|
|
||||||
|
# Get job_id from runtime config
|
||||||
|
job_id = get_config_value("JOB_ID")
|
||||||
|
if not job_id:
|
||||||
|
raise ValueError("JOB_ID not found in runtime config")
|
||||||
|
|
||||||
|
# Query database for yesterday's position
|
||||||
|
from tools.price_tools import get_today_init_position_from_db
|
||||||
|
today_init_position = get_today_init_position_from_db(today_date, signature, job_id)
|
||||||
|
|
||||||
# Get yesterday's buy and sell prices
|
# Get yesterday's buy and sell prices
|
||||||
yesterday_buy_prices, yesterday_sell_prices = get_yesterday_open_and_close_price(today_date, all_nasdaq_100_symbols)
|
yesterday_buy_prices, yesterday_sell_prices = get_yesterday_open_and_close_price(today_date, all_nasdaq_100_symbols)
|
||||||
today_buy_price = get_open_prices(today_date, all_nasdaq_100_symbols)
|
today_buy_price = get_open_prices(today_date, all_nasdaq_100_symbols)
|
||||||
today_init_position = get_today_init_position(today_date, signature)
|
|
||||||
yesterday_profit = get_yesterday_profit(today_date, yesterday_buy_prices, yesterday_sell_prices, today_init_position)
|
yesterday_profit = get_yesterday_profit(today_date, yesterday_buy_prices, yesterday_sell_prices, today_init_position)
|
||||||
|
|
||||||
return agent_system_prompt.format(
|
return agent_system_prompt.format(
|
||||||
date=today_date,
|
date=today_date,
|
||||||
positions=today_init_position,
|
positions=today_init_position,
|
||||||
|
|||||||
@@ -301,6 +301,173 @@ def add_no_trade_record(today_date: str, modelname: str):
|
|||||||
f.write(json.dumps(save_item) + "\n")
|
f.write(json.dumps(save_item) + "\n")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def get_today_init_position_from_db(
|
||||||
|
today_date: str,
|
||||||
|
modelname: str,
|
||||||
|
job_id: str
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
Query yesterday's position from SQLite database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
today_date: Current trading date (YYYY-MM-DD)
|
||||||
|
modelname: Model signature
|
||||||
|
job_id: Job UUID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Position dict: {"AAPL": 50, "MSFT": 30, "CASH": 5000.0}
|
||||||
|
If no position exists: {"CASH": 10000.0} (initial cash)
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from tools.deployment_config import get_db_path
|
||||||
|
from api.database import get_db_connection
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
db_path = get_db_path()
|
||||||
|
conn = get_db_connection(db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get most recent position before today
|
||||||
|
cursor.execute("""
|
||||||
|
SELECT p.id, p.cash
|
||||||
|
FROM positions p
|
||||||
|
WHERE p.job_id = ? AND p.model = ? AND p.date < ?
|
||||||
|
ORDER BY p.date DESC, p.action_id DESC
|
||||||
|
LIMIT 1
|
||||||
|
""", (job_id, modelname, today_date))
|
||||||
|
|
||||||
|
row = cursor.fetchone()
|
||||||
|
|
||||||
|
if not row:
|
||||||
|
# First day - return initial cash
|
||||||
|
logger.info(f"No previous position found for {modelname}, returning initial cash")
|
||||||
|
return {"CASH": 10000.0}
|
||||||
|
|
||||||
|
position_id, cash = row
|
||||||
|
position_dict = {"CASH": cash}
|
||||||
|
|
||||||
|
# Get holdings for this position
|
||||||
|
cursor.execute("""
|
||||||
|
SELECT symbol, quantity
|
||||||
|
FROM holdings
|
||||||
|
WHERE position_id = ?
|
||||||
|
""", (position_id,))
|
||||||
|
|
||||||
|
for symbol, quantity in cursor.fetchall():
|
||||||
|
position_dict[symbol] = quantity
|
||||||
|
|
||||||
|
logger.debug(f"Loaded position for {modelname}: {position_dict}")
|
||||||
|
return position_dict
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Database error in get_today_init_position_from_db: {e}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def add_no_trade_record_to_db(
|
||||||
|
today_date: str,
|
||||||
|
modelname: str,
|
||||||
|
job_id: str,
|
||||||
|
session_id: int
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Create no-trade position record in SQLite database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
today_date: Current trading date (YYYY-MM-DD)
|
||||||
|
modelname: Model signature
|
||||||
|
job_id: Job UUID
|
||||||
|
session_id: Trading session ID
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from tools.deployment_config import get_db_path
|
||||||
|
from api.database import get_db_connection
|
||||||
|
from agent_tools.tool_trade import get_current_position_from_db
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
db_path = get_db_path()
|
||||||
|
conn = get_db_connection(db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get current position
|
||||||
|
current_position, next_action_id = get_current_position_from_db(
|
||||||
|
job_id, modelname, today_date
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate portfolio value
|
||||||
|
cash = current_position.get("CASH", 0.0)
|
||||||
|
portfolio_value = cash
|
||||||
|
|
||||||
|
# Add stock values
|
||||||
|
for symbol, qty in current_position.items():
|
||||||
|
if symbol != "CASH":
|
||||||
|
try:
|
||||||
|
price = get_open_prices(today_date, [symbol])[f'{symbol}_price']
|
||||||
|
portfolio_value += qty * price
|
||||||
|
except KeyError:
|
||||||
|
logger.warning(f"Price not found for {symbol} on {today_date}")
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Get previous value for P&L
|
||||||
|
cursor.execute("""
|
||||||
|
SELECT portfolio_value
|
||||||
|
FROM positions
|
||||||
|
WHERE job_id = ? AND model = ? AND date < ?
|
||||||
|
ORDER BY date DESC, action_id DESC
|
||||||
|
LIMIT 1
|
||||||
|
""", (job_id, modelname, today_date))
|
||||||
|
|
||||||
|
row = cursor.fetchone()
|
||||||
|
previous_value = row[0] if row else 10000.0
|
||||||
|
|
||||||
|
daily_profit = portfolio_value - previous_value
|
||||||
|
daily_return_pct = (daily_profit / previous_value * 100) if previous_value > 0 else 0
|
||||||
|
|
||||||
|
# Insert position record
|
||||||
|
created_at = datetime.utcnow().isoformat() + "Z"
|
||||||
|
|
||||||
|
cursor.execute("""
|
||||||
|
INSERT INTO positions (
|
||||||
|
job_id, date, model, action_id, action_type,
|
||||||
|
cash, portfolio_value, daily_profit, daily_return_pct,
|
||||||
|
session_id, created_at
|
||||||
|
)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""", (
|
||||||
|
job_id, today_date, modelname, next_action_id, "no_trade",
|
||||||
|
cash, portfolio_value, daily_profit, daily_return_pct,
|
||||||
|
session_id, created_at
|
||||||
|
))
|
||||||
|
|
||||||
|
position_id = cursor.lastrowid
|
||||||
|
|
||||||
|
# Insert holdings (unchanged from previous position)
|
||||||
|
for symbol, qty in current_position.items():
|
||||||
|
if symbol != "CASH":
|
||||||
|
cursor.execute("""
|
||||||
|
INSERT INTO holdings (position_id, symbol, quantity)
|
||||||
|
VALUES (?, ?, ?)
|
||||||
|
""", (position_id, symbol, qty))
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
logger.info(f"Created no-trade record for {modelname} on {today_date}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
conn.rollback()
|
||||||
|
logger.error(f"Database error in add_no_trade_record_to_db: {e}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
today_date = get_config_value("TODAY_DATE")
|
today_date = get_config_value("TODAY_DATE")
|
||||||
signature = get_config_value("SIGNATURE")
|
signature = get_config_value("SIGNATURE")
|
||||||
|
|||||||
Reference in New Issue
Block a user