diff --git a/agent/base_agent/base_agent.py b/agent/base_agent/base_agent.py index 9bf4386..0112836 100644 --- a/agent/base_agent/base_agent.py +++ b/agent/base_agent/base_agent.py @@ -147,21 +147,41 @@ class BaseAgent: """Initialize MCP client and AI model""" print(f"🚀 Initializing agent: {self.signature}") - # Create MCP client - self.client = MultiServerMCPClient(self.mcp_config) + # Validate OpenAI configuration + if not self.openai_api_key: + raise ValueError("❌ OpenAI API key not set. Please configure OPENAI_API_KEY in environment or config file.") + if not self.openai_base_url: + print("⚠️ OpenAI base URL not set, using default") - # Get tools - self.tools = await self.client.get_tools() - print(f"✅ Loaded {len(self.tools)} MCP tools") + try: + # Create MCP client + self.client = MultiServerMCPClient(self.mcp_config) + + # Get tools + self.tools = await self.client.get_tools() + if not self.tools: + print("⚠️ Warning: No MCP tools loaded. MCP services may not be running.") + print(f" MCP configuration: {self.mcp_config}") + else: + print(f"✅ Loaded {len(self.tools)} MCP tools") + except Exception as e: + raise RuntimeError( + f"❌ Failed to initialize MCP client: {e}\n" + f" Please ensure MCP services are running at the configured ports.\n" + f" Run: python agent_tools/start_mcp_services.py" + ) - # Create AI model - self.model = ChatOpenAI( - model=self.basemodel, - base_url=self.openai_base_url, - api_key=self.openai_api_key, - max_retries=3, - timeout=30 - ) + try: + # Create AI model + self.model = ChatOpenAI( + model=self.basemodel, + base_url=self.openai_base_url, + api_key=self.openai_api_key, + max_retries=3, + timeout=30 + ) + except Exception as e: + raise RuntimeError(f"❌ Failed to initialize AI model: {e}") # Note: agent will be created in run_trading_session() based on specific date # because system_prompt needs the current date and price information diff --git a/tools/general_tools.py b/tools/general_tools.py index d7c5f11..f68596f 100644 --- a/tools/general_tools.py +++ b/tools/general_tools.py @@ -2,11 +2,14 @@ import os import json from pathlib import Path +from typing import Any from dotenv import load_dotenv load_dotenv() def _load_runtime_env() -> dict: path = os.environ.get("RUNTIME_ENV_PATH") + if path is None: + return {} try: if os.path.exists(path): with open(path, "r", encoding="utf-8") as f: @@ -25,12 +28,18 @@ def get_config_value(key: str, default=None): return _RUNTIME_ENV[key] return os.getenv(key, default) -def write_config_value(key: str, value: any): +def write_config_value(key: str, value: Any): + path = os.environ.get("RUNTIME_ENV_PATH") + if path is None: + print(f"⚠️ WARNING: RUNTIME_ENV_PATH not set, config value '{key}' not persisted") + return _RUNTIME_ENV = _load_runtime_env() _RUNTIME_ENV[key] = value - path = os.environ.get("RUNTIME_ENV_PATH") - with open(path, "w", encoding="utf-8") as f: - json.dump(_RUNTIME_ENV, f, ensure_ascii=False, indent=4) + try: + with open(path, "w", encoding="utf-8") as f: + json.dump(_RUNTIME_ENV, f, ensure_ascii=False, indent=4) + except Exception as e: + print(f"❌ Error writing config to {path}: {e}") def extract_conversation(conversation: dict, output_type: str): """Extract information from a conversation payload. diff --git a/tools/price_tools.py b/tools/price_tools.py index a14af8f..f70ddda 100644 --- a/tools/price_tools.py +++ b/tools/price_tools.py @@ -4,7 +4,7 @@ load_dotenv() import json from datetime import datetime, timedelta from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import sys # 将项目根目录加入 Python 路径,便于从子目录直接运行本文件 @@ -95,7 +95,7 @@ def get_open_prices(today_date: str, symbols: List[str], merged_path: Optional[s return results -def get_yesterday_open_and_close_price(today_date: str, symbols: List[str], merged_path: Optional[str] = None) -> tuple[Dict[str, Optional[float]], Dict[str, Optional[float]]]: +def get_yesterday_open_and_close_price(today_date: str, symbols: List[str], merged_path: Optional[str] = None) -> Tuple[Dict[str, Optional[float]], Dict[str, Optional[float]]]: """从 data/merged.jsonl 中读取指定日期与股票的昨日买入价和卖出价。 Args: @@ -260,7 +260,7 @@ def get_today_init_position(today_date: str, modelname: str) -> Dict[str, float] return latest_positions -def get_latest_position(today_date: str, modelname: str) -> Dict[str, float]: +def get_latest_position(today_date: str, modelname: str) -> Tuple[Dict[str, float], int]: """ 获取最新持仓。从 ../data/agent_data/{modelname}/position/position.jsonl 中读取。 优先选择当天 (today_date) 中 id 最大的记录; @@ -273,7 +273,7 @@ def get_latest_position(today_date: str, modelname: str) -> Dict[str, float]: Returns: (positions, max_id): - positions: {symbol: weight} 的字典;若未找到任何记录,则为空字典。 - - max_id: 选中记录的最大 id;若未找到任何记录,则为 -1。 + - max_id: 选中记录的最大 id;若未找到任何记录,则为 -1. """ base_dir = Path(__file__).resolve().parents[1] position_file = base_dir / "data" / "agent_data" / modelname / "position" / "position.jsonl"