"""
BaseAgent class - Base class for trading agents

Encapsulates core functionality including MCP tool management, AI agent
creation, and trading execution.
"""

import os
import json
import asyncio
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Any
from pathlib import Path

from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_openai import ChatOpenAI
from langchain.agents import create_agent
from dotenv import load_dotenv

# Import project tools
import sys

# Three dirname() calls: this file's directory -> its parent -> its grandparent.
# That grandparent directory is put first on sys.path so the project packages
# below (tools, prompts, agent) resolve.
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, project_root)

from tools.general_tools import extract_conversation, extract_tool_messages, get_config_value, write_config_value
from tools.price_tools import add_no_trade_record
from prompts.agent_prompt import get_agent_system_prompt, STOP_SIGNAL
from tools.deployment_config import (
    is_dev_mode,
    get_data_path,
    log_api_key_warning,
    get_deployment_mode,
)
from agent.context_injector import ContextInjector

# Load environment variables
load_dotenv()


class BaseAgent:
    """
    Base class for trading agents

    Main functionalities:
    1. MCP tool management and connection
    2. AI agent creation and configuration
    3. Trading execution and decision loops
    4. Logging and management
    5. Position and configuration management
    """

    # Default NASDAQ 100 stock symbols
    DEFAULT_STOCK_SYMBOLS = [
        "NVDA", "MSFT", "AAPL", "GOOG", "GOOGL", "AMZN", "META", "AVGO", "TSLA", "NFLX",
        "PLTR", "COST", "ASML", "AMD", "CSCO", "AZN", "TMUS", "MU", "LIN", "PEP",
        "SHOP", "APP", "INTU", "AMAT", "LRCX", "PDD", "QCOM", "ARM", "INTC", "BKNG",
        "AMGN", "TXN", "ISRG", "GILD", "KLAC", "PANW", "ADBE", "HON", "CRWD", "CEG",
        "ADI", "ADP", "DASH", "CMCSA", "VRTX", "MELI", "SBUX", "CDNS", "ORLY", "SNPS",
        "MSTR", "MDLZ", "ABNB", "MRVL", "CTAS", "TRI", "MAR", "MNST", "CSX", "ADSK",
        "PYPL", "FTNT", "AEP", "WDAY", "REGN", "ROP", "NXPI", "DDOG", "AXON", "ROST",
        "IDXX", "EA", "PCAR", "FAST", "EXC", "TTWO", "XEL", "ZS", "PAYX", "WBD",
        "BKR", "CPRT", "CCEP", "FANG", "TEAM", "CHTR", "KDP", "MCHP", "GEHC", "VRSK",
        "CTSH", "CSGP", "KHC", "ODFL", "DXCM", "TTD", "ON", "BIIB", "LULU", "CDW",
        "GFS",
    ]

    def __init__(
        self,
        signature: str,
        basemodel: str,
        stock_symbols: Optional[List[str]] = None,
        mcp_config: Optional[Dict[str, Dict[str, Any]]] = None,
        log_path: Optional[str] = None,
        max_steps: int = 10,
        max_retries: int = 3,
        base_delay: float = 0.5,
        openai_base_url: Optional[str] = None,
        openai_api_key: Optional[str] = None,
        initial_cash: float = 10000.0,
        init_date: str = "2025-10-13"
    ):
        """
        Initialize BaseAgent

        Args:
            signature: Agent signature/name
            basemodel: Base model name
            stock_symbols: List of stock symbols, defaults to NASDAQ 100
            mcp_config: MCP tool configuration, including port and URL information
            log_path: Data path for position files (JSONL logging removed, kept for backward compatibility)
            max_steps: Maximum reasoning steps
            max_retries: Maximum retry attempts (at least one attempt is always made)
            base_delay: Base delay time for retries
            openai_base_url: OpenAI API base URL (falls back to OPENAI_API_BASE)
            openai_api_key: OpenAI API key (falls back to OPENAI_API_KEY)
            initial_cash: Initial cash amount
            init_date: Initialization date
        """
        self.signature = signature
        self.basemodel = basemodel
        # Copy the class-level default so mutating one agent's symbol list can
        # never leak into DEFAULT_STOCK_SYMBOLS (and therefore other agents).
        self.stock_symbols = stock_symbols or list(self.DEFAULT_STOCK_SYMBOLS)
        self.max_steps = max_steps
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.initial_cash = initial_cash
        self.init_date = init_date

        # Set MCP configuration
        self.mcp_config = mcp_config or self._get_default_mcp_config()

        # Set data path (apply deployment mode path resolution)
        # Note: Used for position files only; JSONL logging has been removed
        self.base_log_path = get_data_path(log_path or "./data/agent_data")

        # Set OpenAI configuration; explicit arguments win over environment variables
        self.openai_base_url = (
            openai_base_url if openai_base_url is not None else os.getenv("OPENAI_API_BASE")
        )
        self.openai_api_key = (
            openai_api_key if openai_api_key is not None else os.getenv("OPENAI_API_KEY")
        )

        # Initialize components
        self.client: Optional[MultiServerMCPClient] = None
        self.tools: Optional[List] = None
        self.model: Optional[ChatOpenAI] = None
        self.agent: Optional[Any] = None

        # Context injector for MCP tools
        self.context_injector: Optional[ContextInjector] = None
        # True when self.client was rebuilt (see set_context) after self.tools
        # was loaded, i.e. self.tools no longer reflects the current client.
        self._tools_stale: bool = False

        # Data paths
        self.data_path = os.path.join(self.base_log_path, self.signature)
        self.position_file = os.path.join(self.data_path, "position", "position.jsonl")

        # Conversation history for reasoning logs
        self.conversation_history: List[Dict[str, Any]] = []

    def _get_default_mcp_config(self) -> Dict[str, Dict[str, Any]]:
        """Get default MCP configuration (localhost services, ports overridable via env)."""
        return {
            "math": {
                "transport": "streamable_http",
                "url": f"http://localhost:{os.getenv('MATH_HTTP_PORT', '8000')}/mcp",
            },
            "stock_local": {
                "transport": "streamable_http",
                "url": f"http://localhost:{os.getenv('GETPRICE_HTTP_PORT', '8003')}/mcp",
            },
            "search": {
                "transport": "streamable_http",
                "url": f"http://localhost:{os.getenv('SEARCH_HTTP_PORT', '8001')}/mcp",
            },
            "trade": {
                "transport": "streamable_http",
                "url": f"http://localhost:{os.getenv('TRADE_HTTP_PORT', '8002')}/mcp",
            },
        }

    async def initialize(self) -> None:
        """
        Initialize MCP client and AI model.

        Raises:
            ValueError: In PROD mode when no OpenAI API key is configured.
            RuntimeError: When the MCP client/tools or the chat model cannot be created.
        """
        print(f"🚀 Initializing agent: {self.signature}")
        print(f"🔧 Deployment mode: {get_deployment_mode()}")

        # Log API key warning if in dev mode
        log_api_key_warning()

        # Validate OpenAI configuration (only in PROD mode)
        if not is_dev_mode():
            if not self.openai_api_key:
                raise ValueError("❌ OpenAI API key not set. Please configure OPENAI_API_KEY in environment or config file.")
            if not self.openai_base_url:
                print("⚠️ OpenAI base URL not set, using default")

        try:
            # Context injector will be set later via set_context() method
            self.context_injector = None

            # Create MCP client without interceptors initially
            self.client = MultiServerMCPClient(
                self.mcp_config,
                tool_interceptors=[]
            )

            # Get tools
            raw_tools = await self.client.get_tools()
            if not raw_tools:
                print("⚠️ Warning: No MCP tools loaded. MCP services may not be running.")
                print(f" MCP configuration: {self.mcp_config}")
                self.tools = []
            else:
                print(f"✅ Loaded {len(raw_tools)} MCP tools")
                self.tools = raw_tools
            # Tools now match the client that was just created.
            self._tools_stale = False
        except Exception as e:
            raise RuntimeError(
                f"❌ Failed to initialize MCP client: {e}\n"
                f" Please ensure MCP services are running at the configured ports.\n"
                f" Run: python agent_tools/start_mcp_services.py"
            ) from e

        try:
            # Create AI model (mock in DEV mode, real in PROD mode)
            if is_dev_mode():
                from agent.mock_provider import MockChatModel
                self.model = MockChatModel(date="2025-01-01")  # Date will be updated per session
                print("🤖 Using MockChatModel (DEV mode)")
            else:
                self.model = ChatOpenAI(
                    model=self.basemodel,
                    base_url=self.openai_base_url,
                    api_key=self.openai_api_key,
                    max_retries=3,
                    timeout=30
                )
                print(f"🤖 Using {self.basemodel} (PROD mode)")
        except Exception as e:
            raise RuntimeError(f"❌ Failed to initialize AI model: {e}") from e

        # Note: agent will be created in run_trading_session() based on specific date
        # because system_prompt needs the current date and price information

        print(f"✅ Agent {self.signature} initialization completed")

    def set_context(self, context_injector: "ContextInjector") -> None:
        """
        Inject ContextInjector after initialization.

        This allows the ContextInjector to be created with the correct
        trading day date and session_id after the agent is initialized.

        Args:
            context_injector: Configured ContextInjector instance with correct
                signature, today_date, job_id, session_id
        """
        self.context_injector = context_injector

        # Recreate MCP client with the interceptor
        # Note: We need to recreate because MultiServerMCPClient doesn't have add_interceptor()
        self.client = MultiServerMCPClient(
            self.mcp_config,
            tool_interceptors=[context_injector]
        )
        # self.tools was loaded from the previous client (created without this
        # interceptor). This method is synchronous and cannot await get_tools(),
        # so flag the tools for reloading at the start of the next session.
        self._tools_stale = True

        print(f"✅ Context injected: signature={context_injector.signature}, "
              f"date={context_injector.today_date}, job_id={context_injector.job_id}, "
              f"session_id={context_injector.session_id}")

    def _capture_message(
        self,
        role: str,
        content: str,
        tool_name: Optional[str] = None,
        tool_input: Optional[str] = None,
    ) -> None:
        """
        Capture a message in conversation history.

        Args:
            role: Message role ('user', 'assistant', 'tool')
            content: Message content
            tool_name: Tool name for tool messages (stored only when truthy)
            tool_input: Tool input for tool messages (stored only when truthy)
        """
        message = {
            "role": role,
            "content": content,
            # UTC ISO-8601 timestamp with a trailing 'Z' instead of '+00:00'
            "timestamp": datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z')
        }

        if tool_name:
            message["tool_name"] = tool_name
        if tool_input:
            message["tool_input"] = tool_input

        self.conversation_history.append(message)

    def get_conversation_history(self) -> List[Dict[str, Any]]:
        """
        Get the complete conversation history for this trading session.

        Returns:
            Shallow copy of the list of message dictionaries (role, content, timestamp, ...)
        """
        return self.conversation_history.copy()

    def clear_conversation_history(self) -> None:
        """Clear conversation history (called at start of each trading day)."""
        self.conversation_history = []

    @staticmethod
    def _truncate_text(text: str, max_length: int) -> str:
        """Clip *text* to at most *max_length* characters, ending in '...' only when clipped."""
        if len(text) <= max_length:
            return text
        if max_length <= 3:
            # No room for an ellipsis; a hard cut is the only way to honour the limit.
            return text[:max(max_length, 0)]
        return text[:max_length - 3] + "..."

    async def generate_summary(self, content: str, max_length: int = 200) -> str:
        """
        Generate a concise summary of reasoning content.

        Uses the same AI model to summarize its own reasoning. If the model call
        fails (or the model is not initialized), falls back to the truncated
        original content.

        Args:
            content: Full reasoning content to summarize
            max_length: Character limit for the returned summary

        Returns:
            1-2 sentence summary of key decisions and reasoning, at most
            max_length characters
        """
        # Truncate content to avoid token limits (keep first 2000 chars)
        truncated = content[:2000]

        prompt = f"""Summarize the following trading decision in 1-2 sentences (max {max_length} characters), focusing on the key reasoning and actions taken:

{truncated}

Summary:"""

        try:
            # Use ainvoke for async call
            response = await self.model.ainvoke(prompt)

            # Extract content from response
            if hasattr(response, 'content'):
                summary = response.content.strip()
            elif isinstance(response, dict) and 'content' in response:
                summary = response['content'].strip()
            else:
                summary = str(response).strip()

            return self._truncate_text(summary, max_length)

        except Exception as e:
            # Best-effort: report the failure, then return truncated original
            print(f"⚠️ Summary generation failed, using truncated content: {e}")
            return self._truncate_text(truncated, max_length)

    def generate_summary_sync(self, content: str, max_length: int = 200) -> str:
        """
        Synchronous wrapper for generate_summary.

        Args:
            content: Full reasoning content to summarize
            max_length: Character limit for the returned summary

        Returns:
            Summary string

        Raises:
            RuntimeError: If called from a thread that already runs an event loop
                (blocking there would deadlock; await generate_summary instead).
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # No loop running in this thread: safe to block below.
        else:
            raise RuntimeError(
                "generate_summary_sync() cannot be called while an event loop is "
                "running in this thread; use 'await generate_summary(...)' instead."
            )

        # Reuse the thread's current loop when usable so async clients bound to
        # it keep working; otherwise install a fresh one.
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        return loop.run_until_complete(self.generate_summary(content, max_length))

    async def _ainvoke_with_retry(self, message: List[Dict[str, str]]) -> Any:
        """
        Invoke the agent, retrying with a linear backoff (base_delay * attempt).

        At least one attempt is always made, even if max_retries < 1.
        The last failure is re-raised.
        """
        attempts = max(1, self.max_retries)
        for attempt in range(1, attempts + 1):
            try:
                return await self.agent.ainvoke(
                    {"messages": message},
                    {"recursion_limit": 100}
                )
            except Exception as e:
                if attempt == attempts:
                    raise
                delay = self.base_delay * attempt
                print(f"⚠️ Attempt {attempt} failed, retrying after {delay} seconds...")
                print(f"Error details: {e}")
                await asyncio.sleep(delay)

    async def run_trading_session(self, today_date: str) -> None:
        """
        Run single day trading session

        Args:
            today_date: Trading date
        """
        print(f"📈 Starting trading session: {today_date}")

        # Update context injector with current trading date
        if self.context_injector:
            self.context_injector.today_date = today_date

        # Clear conversation history for new trading day
        self.clear_conversation_history()

        # Update mock model date if in dev mode
        if is_dev_mode():
            self.model.date = today_date

        # If set_context() rebuilt the client, reload tools from it so tool
        # calls actually go through the context injector.
        if getattr(self, "_tools_stale", False) and self.client is not None:
            self.tools = (await self.client.get_tools()) or []
            self._tools_stale = False

        # Get system prompt
        system_prompt = get_agent_system_prompt(today_date, self.signature)

        # Update agent with system prompt
        self.agent = create_agent(
            self.model,
            tools=self.tools,
            system_prompt=system_prompt,
        )

        # Capture user prompt
        user_prompt = f"Please analyze and update today's ({today_date}) positions."
        self._capture_message("user", user_prompt)

        # Initial user query
        user_query = [{"role": "user", "content": user_prompt}]
        message = user_query.copy()

        # Trading loop: bounded by max_steps, or ends early on STOP_SIGNAL
        current_step = 0
        while current_step < self.max_steps:
            current_step += 1
            print(f"🔄 Step {current_step}/{self.max_steps}")

            try:
                # Call agent
                response = await self._ainvoke_with_retry(message)

                # Extract agent response
                agent_response = extract_conversation(response, "final")

                # Capture assistant response
                self._capture_message("assistant", agent_response)

                # Check stop signal
                if STOP_SIGNAL in agent_response:
                    print("✅ Received stop signal, trading session ended")
                    print(agent_response)
                    break

                # Extract tool messages
                tool_msgs = extract_tool_messages(response)
                tool_response = '\n'.join([msg.content for msg in tool_msgs])

                # Feed the assistant turn and tool output back for the next step
                message.extend([
                    {"role": "assistant", "content": agent_response},
                    {"role": "user", "content": f'Tool results: {tool_response}'}
                ])

            except Exception as e:
                print(f"❌ Trading session error: {str(e)}")
                print(f"Error details: {e}")
                raise

        # Handle trading results
        await self._handle_trading_result(today_date)

    async def _handle_trading_result(self, today_date: str) -> None:
        """
        Handle trading results with database writes.

        Reads the IF_TRADE runtime config flag. If no trade happened, records a
        no-trade entry for the day. In both cases IF_TRADE is reset to False.

        Raises:
            ValueError: If no trade happened and JOB_ID or session_id is missing.
        """
        from tools.price_tools import add_no_trade_record_to_db

        if_trade = get_config_value("IF_TRADE")

        if if_trade:
            write_config_value("IF_TRADE", False)
            print("✅ Trading completed")
        else:
            print("📊 No trading, maintaining positions")

            # Get context from runtime config
            job_id = get_config_value("JOB_ID")
            session_id = self.context_injector.session_id if self.context_injector else None

            if not job_id or not session_id:
                raise ValueError("Missing JOB_ID or session_id for no-trade record")

            # Write no-trade record to database
            add_no_trade_record_to_db(
                today_date,
                self.signature,
                job_id,
                session_id
            )

            write_config_value("IF_TRADE", False)

    def register_agent(self) -> None:
        """Register new agent, create initial positions (no-op if the position file exists)."""
        # Check if position.jsonl file already exists
        if os.path.exists(self.position_file):
            print(f"⚠️ Position file {self.position_file} already exists, skipping registration")
            return

        # Ensure directory structure exists
        position_dir = os.path.join(self.data_path, "position")
        if not os.path.exists(position_dir):
            os.makedirs(position_dir, exist_ok=True)
            print(f"📁 Created position directory: {position_dir}")

        # Create initial positions: zero shares of every symbol plus starting cash
        init_position = {symbol: 0 for symbol in self.stock_symbols}
        init_position['CASH'] = self.initial_cash

        with open(self.position_file, "w", encoding="utf-8") as f:  # Use "w" mode to ensure creating new file
            f.write(json.dumps({
                "date": self.init_date,
                "id": 0,
                "positions": init_position
            }) + "\n")

        print(f"✅ Agent {self.signature} registration completed")
        print(f"📁 Position file: {self.position_file}")
        print(f"💰 Initial cash: ${self.initial_cash}")
        print(f"📊 Number of stocks: {len(self.stock_symbols)}")

    def _latest_position_date(self) -> Optional[str]:
        """Return the latest "date" (YYYY-MM-DD) in the position file, or None if it has no records."""
        latest: Optional[str] = None
        latest_obj: Optional[datetime] = None
        with open(self.position_file, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines (e.g. trailing newline)
                current_date = json.loads(line)['date']
                current_obj = datetime.strptime(current_date, "%Y-%m-%d")
                if latest_obj is None or current_obj > latest_obj:
                    latest, latest_obj = current_date, current_obj
        return latest

    def get_trading_dates(self, init_date: str, end_date: str) -> List[str]:
        """
        Get trading date list

        Dates run from the day after the latest recorded position date (or after
        init_date when there is no position history) through end_date inclusive,
        keeping Monday-Friday only. Exchange holidays are not excluded.

        Args:
            init_date: Start date
            end_date: End date

        Returns:
            List of trading dates ("YYYY-MM-DD")
        """
        if not os.path.exists(self.position_file):
            # NOTE(review): register_agent() writes self.init_date into the file,
            # while this call resumes from the init_date argument - confirm the
            # two are meant to match.
            self.register_agent()
            max_date = init_date
        else:
            # Resume after the latest date found in the existing position file;
            # an empty file behaves like a fresh start.
            max_date = self._latest_position_date() or init_date

        # Check if new dates need to be processed
        max_date_obj = datetime.strptime(max_date, "%Y-%m-%d")
        end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")

        if end_date_obj <= max_date_obj:
            return []

        # Generate trading date list
        trading_dates = []
        current_date = max_date_obj + timedelta(days=1)

        while current_date <= end_date_obj:
            if current_date.weekday() < 5:  # Weekdays
                trading_dates.append(current_date.strftime("%Y-%m-%d"))
            current_date += timedelta(days=1)

        return trading_dates

    async def run_with_retry(self, today_date: str) -> None:
        """
        Run method with retry

        Runs one trading session, retrying with a linear backoff
        (base_delay * attempt). At least one attempt is always made, even if
        max_retries < 1. The final failure is re-raised.
        """
        attempts = max(1, self.max_retries)
        for attempt in range(1, attempts + 1):
            try:
                print(f"🔄 Attempting to run {self.signature} - {today_date} (Attempt {attempt})")
                await self.run_trading_session(today_date)
                print(f"✅ {self.signature} - {today_date} run successful")
                return
            except Exception as e:
                print(f"❌ Attempt {attempt} failed: {str(e)}")
                if attempt == attempts:
                    print(f"💥 {self.signature} - {today_date} all retries failed")
                    raise
                wait_time = self.base_delay * attempt
                print(f"⏳ Waiting {wait_time} seconds before retry...")
                await asyncio.sleep(wait_time)

    async def run_date_range(self, init_date: str, end_date: str) -> None:
        """
        Run all trading days in date range

        Args:
            init_date: Start date
            end_date: End date
        """
        print(f"📅 Running date range: {init_date} to {end_date}")

        # Get trading date list
        trading_dates = self.get_trading_dates(init_date, end_date)

        if not trading_dates:
            print("ℹ️ No trading days to process")
            return

        print(f"📊 Trading days to process: {trading_dates}")

        # Process each trading day
        for date in trading_dates:
            print(f"🔄 Processing {self.signature} - Date: {date}")

            # Set configuration
            write_config_value("TODAY_DATE", date)
            write_config_value("SIGNATURE", self.signature)

            try:
                await self.run_with_retry(date)
            except Exception as e:
                print(f"❌ Error processing {self.signature} - Date: {date}")
                print(e)
                raise

        print(f"✅ {self.signature} processing completed")

    def get_position_summary(self) -> Dict[str, Any]:
        """Get position summary (the last record in the position file), or an {"error": ...} dict."""
        if not os.path.exists(self.position_file):
            return {"error": "Position file does not exist"}

        with open(self.position_file, "r", encoding="utf-8") as f:
            positions = [json.loads(line) for line in f if line.strip()]

        if not positions:
            return {"error": "No position records"}

        latest_position = positions[-1]
        return {
            "signature": self.signature,
            "latest_date": latest_position.get("date"),
            "positions": latest_position.get("positions", {}),
            "total_records": len(positions)
        }

    def __str__(self) -> str:
        return f"BaseAgent(signature='{self.signature}', basemodel='{self.basemodel}', stocks={len(self.stock_symbols)})"

    def __repr__(self) -> str:
        return self.__str__()