diff --git a/agent/base_agent/base_agent.py b/agent/base_agent/base_agent.py index fc67909..9bf4386 100644 --- a/agent/base_agent/base_agent.py +++ b/agent/base_agent/base_agent.py @@ -66,6 +66,7 @@ class BaseAgent: max_retries: int = 3, base_delay: float = 0.5, openai_base_url: Optional[str] = None, + openai_api_key: Optional[str] = None, initial_cash: float = 10000.0, init_date: str = "2025-10-13" ): @@ -82,6 +83,7 @@ class BaseAgent: max_retries: Maximum retry attempts base_delay: Base delay time for retries openai_base_url: OpenAI API base URL + openai_api_key: OpenAI API key initial_cash: Initial cash amount init_date: Initialization date """ @@ -101,7 +103,14 @@ class BaseAgent: self.base_log_path = log_path or "./data/agent_data" # Set OpenAI configuration - self.openai_base_url = openai_base_url or os.getenv("OPENAI_API_BASE") + if openai_base_url==None: + self.openai_base_url = os.getenv("OPENAI_API_BASE") + else: + self.openai_base_url = openai_base_url + if openai_api_key==None: + self.openai_api_key = os.getenv("OPENAI_API_KEY") + else: + self.openai_api_key = openai_api_key # Initialize components self.client: Optional[MultiServerMCPClient] = None @@ -149,6 +158,7 @@ class BaseAgent: self.model = ChatOpenAI( model=self.basemodel, base_url=self.openai_base_url, + api_key=self.openai_api_key, max_retries=3, timeout=30 ) diff --git a/configs/default_config.json b/configs/default_config.json index bdd3c02..d098c89 100644 --- a/configs/default_config.json +++ b/configs/default_config.json @@ -7,9 +7,11 @@ "models": [ { "name": "claude-3.7-sonnet", - "basemodel": "anthropic/claude-3.7-sonnet", + "basemodel": "anthropic/claude-3.7-sonnet", "signature": "claude-3.7-sonnet", - "enabled": false + "enabled": false, + "openai_base_url": "Optional: YOUR_OPENAI_BASE_URL,you can write them in .env file", + "openai_api_key": "Optional: YOUR_OPENAI_API_KEY,you can write them in .env file" }, { "name": "deepseek-chat-v3.1", diff --git a/main.py b/main.py index 9c9f251..576245c 100644 --- a/main.py +++ b/main.py @@ -155,7 +155,9 @@ async def main(config_path=None): model_name = model_config.get("name", "unknown") basemodel = model_config.get("basemodel") signature = model_config.get("signature") - + openai_base_url = model_config.get("openai_base_url",None) + openai_api_key = model_config.get("openai_api_key",None) + # Validate required fields if not basemodel: print(f"❌ Model {model_name} missing basemodel field") @@ -185,6 +187,8 @@ async def main(config_path=None): basemodel=basemodel, stock_symbols=all_nasdaq_100_symbols, log_path=log_path, + openai_base_url=openai_base_url, + openai_api_key=openai_api_key, max_steps=max_steps, max_retries=max_retries, base_delay=base_delay,