From 027b4bd8e499cacfaf9ad508bf5b3f87fd3088ef Mon Sep 17 00:00:00 2001 From: Bill Date: Sun, 2 Nov 2025 22:20:01 -0500 Subject: [PATCH] refactor: implement database-only position tracking with lazy context injection This commit migrates the system to database-only position storage, eliminating file-based position.jsonl dependencies and fixing ContextInjector initialization timing issues. Key Changes: 1. ContextInjector Lifecycle Refactor: - Remove ContextInjector creation from BaseAgent.__init__() - Add BaseAgent.set_context() method for post-initialization injection - Update ModelDayExecutor to create ContextInjector with correct trading day date - Ensures ContextInjector receives actual trading date instead of init_date - Includes session_id injection for proper database linking 2. Database Position Functions: - Implement get_today_init_position_from_db() for querying previous positions - Implement add_no_trade_record_to_db() for no-trade day handling - Both functions query SQLite directly (positions + holdings tables) - Handle first trading day case with initial cash return - Include comprehensive error handling and logging 3. System Integration: - Update get_agent_system_prompt() to use database queries - Update _handle_trading_result() to write no-trade records to database - Remove dependencies on position.jsonl file reading/writing - Use deployment_config for automatic prod/dev database resolution Data Flow: - ModelDayExecutor creates runtime config and trading session - Agent initialized without context - ContextInjector created with (signature, date, job_id, session_id) - Context injected via set_context() - System prompt queries database for yesterday's position - Trade tools write directly to database - No-trade handler creates database records Fixes: - ContextInjector no longer receives None values - No FileNotFoundError for missing position.jsonl files - Database is single source of truth for position tracking - Session linking maintained across all position records Design: docs/plans/2025-02-11-database-position-tracking-design.md --- agent/base_agent/base_agent.py | 66 +++++++++---- api/model_day_executor.py | 14 ++- prompts/agent_prompt.py | 16 +++- tools/price_tools.py | 169 ++++++++++++++++++++++++++++++++- 4 files changed, 239 insertions(+), 26 deletions(-) diff --git a/agent/base_agent/base_agent.py b/agent/base_agent/base_agent.py index 323b295..e73059e 100644 --- a/agent/base_agent/base_agent.py +++ b/agent/base_agent/base_agent.py @@ -173,21 +173,13 @@ class BaseAgent: print("⚠️ OpenAI base URL not set, using default") try: - # Get job_id from runtime config if available (API mode) - from tools.general_tools import get_config_value - job_id = get_config_value("JOB_ID") # Returns None if not in API mode + # Context injector will be set later via set_context() method + self.context_injector = None - # Create context injector for injecting signature and today_date into tool calls - self.context_injector = ContextInjector( - signature=self.signature, - today_date=self.init_date, # Will be updated per trading session - job_id=job_id # Will be None in standalone mode, populated in API mode - ) - - # Create MCP client with interceptor + # Create MCP client without interceptors initially self.client = MultiServerMCPClient( self.mcp_config, - tool_interceptors=[self.context_injector] + tool_interceptors=[] ) # Get tools @@ -229,6 +221,30 @@ class BaseAgent: print(f"✅ Agent {self.signature} initialization completed") + def set_context(self, context_injector: "ContextInjector") -> None: + """ + Inject ContextInjector after initialization. + + This allows the ContextInjector to be created with the correct + trading day date and session_id after the agent is initialized. + + Args: + context_injector: Configured ContextInjector instance with + correct signature, today_date, job_id, session_id + """ + self.context_injector = context_injector + + # Recreate MCP client with the interceptor + # Note: We need to recreate because MultiServerMCPClient doesn't have add_interceptor() + self.client = MultiServerMCPClient( + self.mcp_config, + tool_interceptors=[context_injector] + ) + + print(f"✅ Context injected: signature={context_injector.signature}, " + f"date={context_injector.today_date}, job_id={context_injector.job_id}, " + f"session_id={context_injector.session_id}") + def _capture_message(self, role: str, content: str, tool_name: str = None, tool_input: str = None) -> None: """ Capture a message in conversation history. @@ -429,18 +445,32 @@ Summary:""" await self._handle_trading_result(today_date) async def _handle_trading_result(self, today_date: str) -> None: - """Handle trading results""" + """Handle trading results with database writes.""" + from tools.price_tools import add_no_trade_record_to_db + if_trade = get_config_value("IF_TRADE") + if if_trade: write_config_value("IF_TRADE", False) print("✅ Trading completed") else: print("📊 No trading, maintaining positions") - try: - add_no_trade_record(today_date, self.signature) - except NameError as e: - print(f"❌ NameError: {e}") - raise + + # Get context from runtime config + job_id = get_config_value("JOB_ID") + session_id = self.context_injector.session_id if self.context_injector else None + + if not job_id or not session_id: + raise ValueError("Missing JOB_ID or session_id for no-trade record") + + # Write no-trade record to database + add_no_trade_record_to_db( + today_date, + self.signature, + job_id, + session_id + ) + write_config_value("IF_TRADE", False) def register_agent(self) -> None: diff --git a/api/model_day_executor.py b/api/model_day_executor.py index 3b75117..d3b5af7 100644 --- a/api/model_day_executor.py +++ b/api/model_day_executor.py @@ -129,12 +129,18 @@ class ModelDayExecutor: # Set environment variable for agent to use isolated config os.environ["RUNTIME_ENV_PATH"] = self.runtime_config_path - # Initialize agent + # Initialize agent (without context) agent = await self._initialize_agent() - # Update context injector with session_id - if hasattr(agent, 'context_injector') and agent.context_injector: - agent.context_injector.session_id = session_id + # Create and inject context with correct values + from agent.context_injector import ContextInjector + context_injector = ContextInjector( + signature=self.model_sig, + today_date=self.date, # Current trading day + job_id=self.job_id, + session_id=session_id + ) + agent.set_context(context_injector) # Run trading session logger.info(f"Running trading session for {self.model_sig} on {self.date}") diff --git a/prompts/agent_prompt.py b/prompts/agent_prompt.py index df89b53..2cb818c 100644 --- a/prompts/agent_prompt.py +++ b/prompts/agent_prompt.py @@ -68,14 +68,24 @@ When you think your task is complete, output def get_agent_system_prompt(today_date: str, signature: str) -> str: print(f"signature: {signature}") print(f"today_date: {today_date}") + + # Get job_id from runtime config + job_id = get_config_value("JOB_ID") + if not job_id: + raise ValueError("JOB_ID not found in runtime config") + + # Query database for yesterday's position + from tools.price_tools import get_today_init_position_from_db + today_init_position = get_today_init_position_from_db(today_date, signature, job_id) + # Get yesterday's buy and sell prices yesterday_buy_prices, yesterday_sell_prices = get_yesterday_open_and_close_price(today_date, all_nasdaq_100_symbols) today_buy_price = get_open_prices(today_date, all_nasdaq_100_symbols) - today_init_position = get_today_init_position(today_date, signature) yesterday_profit = get_yesterday_profit(today_date, yesterday_buy_prices, yesterday_sell_prices, today_init_position) + return agent_system_prompt.format( - date=today_date, - positions=today_init_position, + date=today_date, + positions=today_init_position, STOP_SIGNAL=STOP_SIGNAL, yesterday_close_price=yesterday_sell_prices, today_buy_price=today_buy_price, diff --git a/tools/price_tools.py b/tools/price_tools.py index 4aeec81..3386fd0 100644 --- a/tools/price_tools.py +++ b/tools/price_tools.py @@ -299,7 +299,174 @@ def add_no_trade_record(today_date: str, modelname: str): with position_file.open("a", encoding="utf-8") as f: f.write(json.dumps(save_item) + "\n") - return + return + + +def get_today_init_position_from_db( + today_date: str, + modelname: str, + job_id: str +) -> Dict[str, float]: + """ + Query yesterday's position from SQLite database. + + Args: + today_date: Current trading date (YYYY-MM-DD) + modelname: Model signature + job_id: Job UUID + + Returns: + Position dict: {"AAPL": 50, "MSFT": 30, "CASH": 5000.0} + If no position exists: {"CASH": 10000.0} (initial cash) + """ + import logging + from tools.deployment_config import get_db_path + from api.database import get_db_connection + + logger = logging.getLogger(__name__) + + db_path = get_db_path() + conn = get_db_connection(db_path) + cursor = conn.cursor() + + try: + # Get most recent position before today + cursor.execute(""" + SELECT p.id, p.cash + FROM positions p + WHERE p.job_id = ? AND p.model = ? AND p.date < ? + ORDER BY p.date DESC, p.action_id DESC + LIMIT 1 + """, (job_id, modelname, today_date)) + + row = cursor.fetchone() + + if not row: + # First day - return initial cash + logger.info(f"No previous position found for {modelname}, returning initial cash") + return {"CASH": 10000.0} + + position_id, cash = row + position_dict = {"CASH": cash} + + # Get holdings for this position + cursor.execute(""" + SELECT symbol, quantity + FROM holdings + WHERE position_id = ? + """, (position_id,)) + + for symbol, quantity in cursor.fetchall(): + position_dict[symbol] = quantity + + logger.debug(f"Loaded position for {modelname}: {position_dict}") + return position_dict + + except Exception as e: + logger.error(f"Database error in get_today_init_position_from_db: {e}") + raise + finally: + conn.close() + + +def add_no_trade_record_to_db( + today_date: str, + modelname: str, + job_id: str, + session_id: int +) -> None: + """ + Create no-trade position record in SQLite database. + + Args: + today_date: Current trading date (YYYY-MM-DD) + modelname: Model signature + job_id: Job UUID + session_id: Trading session ID + """ + import logging + from tools.deployment_config import get_db_path + from api.database import get_db_connection + from agent_tools.tool_trade import get_current_position_from_db + from datetime import datetime + + logger = logging.getLogger(__name__) + + db_path = get_db_path() + conn = get_db_connection(db_path) + cursor = conn.cursor() + + try: + # Get current position + current_position, next_action_id = get_current_position_from_db( + job_id, modelname, today_date + ) + + # Calculate portfolio value + cash = current_position.get("CASH", 0.0) + portfolio_value = cash + + # Add stock values + for symbol, qty in current_position.items(): + if symbol != "CASH": + try: + price = get_open_prices(today_date, [symbol])[f'{symbol}_price'] + portfolio_value += qty * price + except KeyError: + logger.warning(f"Price not found for {symbol} on {today_date}") + pass + + # Get previous value for P&L + cursor.execute(""" + SELECT portfolio_value + FROM positions + WHERE job_id = ? AND model = ? AND date < ? + ORDER BY date DESC, action_id DESC + LIMIT 1 + """, (job_id, modelname, today_date)) + + row = cursor.fetchone() + previous_value = row[0] if row else 10000.0 + + daily_profit = portfolio_value - previous_value + daily_return_pct = (daily_profit / previous_value * 100) if previous_value > 0 else 0 + + # Insert position record + created_at = datetime.utcnow().isoformat() + "Z" + + cursor.execute(""" + INSERT INTO positions ( + job_id, date, model, action_id, action_type, + cash, portfolio_value, daily_profit, daily_return_pct, + session_id, created_at + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + job_id, today_date, modelname, next_action_id, "no_trade", + cash, portfolio_value, daily_profit, daily_return_pct, + session_id, created_at + )) + + position_id = cursor.lastrowid + + # Insert holdings (unchanged from previous position) + for symbol, qty in current_position.items(): + if symbol != "CASH": + cursor.execute(""" + INSERT INTO holdings (position_id, symbol, quantity) + VALUES (?, ?, ?) + """, (position_id, symbol, qty)) + + conn.commit() + logger.info(f"Created no-trade record for {modelname} on {today_date}") + + except Exception as e: + conn.rollback() + logger.error(f"Database error in add_no_trade_record_to_db: {e}") + raise + finally: + conn.close() + if __name__ == "__main__": today_date = get_config_value("TODAY_DATE")