mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-13 13:47:23 -04:00
feat: integrate mock AI provider in BaseAgent for DEV mode
Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -23,6 +23,12 @@ sys.path.insert(0, project_root)
|
|||||||
from tools.general_tools import extract_conversation, extract_tool_messages, get_config_value, write_config_value
|
from tools.general_tools import extract_conversation, extract_tool_messages, get_config_value, write_config_value
|
||||||
from tools.price_tools import add_no_trade_record
|
from tools.price_tools import add_no_trade_record
|
||||||
from prompts.agent_prompt import get_agent_system_prompt, STOP_SIGNAL
|
from prompts.agent_prompt import get_agent_system_prompt, STOP_SIGNAL
|
||||||
|
from tools.deployment_config import (
|
||||||
|
is_dev_mode,
|
||||||
|
get_data_path,
|
||||||
|
log_api_key_warning,
|
||||||
|
get_deployment_mode
|
||||||
|
)
|
||||||
|
|
||||||
# Load environment variables
|
# Load environment variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
@@ -98,9 +104,9 @@ class BaseAgent:
|
|||||||
|
|
||||||
# Set MCP configuration
|
# Set MCP configuration
|
||||||
self.mcp_config = mcp_config or self._get_default_mcp_config()
|
self.mcp_config = mcp_config or self._get_default_mcp_config()
|
||||||
|
|
||||||
# Set log path
|
# Set log path (apply deployment mode path resolution)
|
||||||
self.base_log_path = log_path or "./data/agent_data"
|
self.base_log_path = get_data_path(log_path or "./data/agent_data")
|
||||||
|
|
||||||
# Set OpenAI configuration
|
# Set OpenAI configuration
|
||||||
if openai_base_url==None:
|
if openai_base_url==None:
|
||||||
@@ -146,17 +152,22 @@ class BaseAgent:
|
|||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
"""Initialize MCP client and AI model"""
|
"""Initialize MCP client and AI model"""
|
||||||
print(f"🚀 Initializing agent: {self.signature}")
|
print(f"🚀 Initializing agent: {self.signature}")
|
||||||
|
print(f"🔧 Deployment mode: {get_deployment_mode()}")
|
||||||
# Validate OpenAI configuration
|
|
||||||
if not self.openai_api_key:
|
# Log API key warning if in dev mode
|
||||||
raise ValueError("❌ OpenAI API key not set. Please configure OPENAI_API_KEY in environment or config file.")
|
log_api_key_warning()
|
||||||
if not self.openai_base_url:
|
|
||||||
print("⚠️ OpenAI base URL not set, using default")
|
# Validate OpenAI configuration (only in PROD mode)
|
||||||
|
if not is_dev_mode():
|
||||||
|
if not self.openai_api_key:
|
||||||
|
raise ValueError("❌ OpenAI API key not set. Please configure OPENAI_API_KEY in environment or config file.")
|
||||||
|
if not self.openai_base_url:
|
||||||
|
print("⚠️ OpenAI base URL not set, using default")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Create MCP client
|
# Create MCP client
|
||||||
self.client = MultiServerMCPClient(self.mcp_config)
|
self.client = MultiServerMCPClient(self.mcp_config)
|
||||||
|
|
||||||
# Get tools
|
# Get tools
|
||||||
self.tools = await self.client.get_tools()
|
self.tools = await self.client.get_tools()
|
||||||
if not self.tools:
|
if not self.tools:
|
||||||
@@ -170,22 +181,28 @@ class BaseAgent:
|
|||||||
f" Please ensure MCP services are running at the configured ports.\n"
|
f" Please ensure MCP services are running at the configured ports.\n"
|
||||||
f" Run: python agent_tools/start_mcp_services.py"
|
f" Run: python agent_tools/start_mcp_services.py"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Create AI model
|
# Create AI model (mock in DEV mode, real in PROD mode)
|
||||||
self.model = ChatOpenAI(
|
if is_dev_mode():
|
||||||
model=self.basemodel,
|
from agent.mock_provider import MockChatModel
|
||||||
base_url=self.openai_base_url,
|
self.model = MockChatModel(date="2025-01-01") # Date will be updated per session
|
||||||
api_key=self.openai_api_key,
|
print(f"🤖 Using MockChatModel (DEV mode)")
|
||||||
max_retries=3,
|
else:
|
||||||
timeout=30
|
self.model = ChatOpenAI(
|
||||||
)
|
model=self.basemodel,
|
||||||
|
base_url=self.openai_base_url,
|
||||||
|
api_key=self.openai_api_key,
|
||||||
|
max_retries=3,
|
||||||
|
timeout=30
|
||||||
|
)
|
||||||
|
print(f"🤖 Using {self.basemodel} (PROD mode)")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"❌ Failed to initialize AI model: {e}")
|
raise RuntimeError(f"❌ Failed to initialize AI model: {e}")
|
||||||
|
|
||||||
# Note: agent will be created in run_trading_session() based on specific date
|
# Note: agent will be created in run_trading_session() based on specific date
|
||||||
# because system_prompt needs the current date and price information
|
# because system_prompt needs the current date and price information
|
||||||
|
|
||||||
print(f"✅ Agent {self.signature} initialization completed")
|
print(f"✅ Agent {self.signature} initialization completed")
|
||||||
|
|
||||||
def _setup_logging(self, today_date: str) -> str:
|
def _setup_logging(self, today_date: str) -> str:
|
||||||
@@ -223,15 +240,19 @@ class BaseAgent:
|
|||||||
async def run_trading_session(self, today_date: str) -> None:
|
async def run_trading_session(self, today_date: str) -> None:
|
||||||
"""
|
"""
|
||||||
Run single day trading session
|
Run single day trading session
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
today_date: Trading date
|
today_date: Trading date
|
||||||
"""
|
"""
|
||||||
print(f"📈 Starting trading session: {today_date}")
|
print(f"📈 Starting trading session: {today_date}")
|
||||||
|
|
||||||
|
# Update mock model date if in dev mode
|
||||||
|
if is_dev_mode():
|
||||||
|
self.model.date = today_date
|
||||||
|
|
||||||
# Set up logging
|
# Set up logging
|
||||||
log_file = self._setup_logging(today_date)
|
log_file = self._setup_logging(today_date)
|
||||||
|
|
||||||
# Update system prompt
|
# Update system prompt
|
||||||
self.agent = create_agent(
|
self.agent = create_agent(
|
||||||
self.model,
|
self.model,
|
||||||
|
|||||||
69
tests/unit/test_base_agent_mock.py
Normal file
69
tests/unit/test_base_agent_mock.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from agent.base_agent.base_agent import BaseAgent
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_agent_uses_mock_in_dev_mode():
|
||||||
|
"""Test BaseAgent uses mock model when DEPLOYMENT_MODE=DEV"""
|
||||||
|
os.environ["DEPLOYMENT_MODE"] = "DEV"
|
||||||
|
|
||||||
|
agent = BaseAgent(
|
||||||
|
signature="test-agent",
|
||||||
|
basemodel="mock/test-trader",
|
||||||
|
log_path="./data/dev_agent_data"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock MCP client to avoid needing running services
|
||||||
|
async def mock_initialize():
|
||||||
|
# Mock the MCP client
|
||||||
|
agent.client = MagicMock()
|
||||||
|
agent.tools = []
|
||||||
|
|
||||||
|
# Create mock model based on deployment mode
|
||||||
|
from tools.deployment_config import is_dev_mode
|
||||||
|
if is_dev_mode():
|
||||||
|
from agent.mock_provider import MockChatModel
|
||||||
|
agent.model = MockChatModel(date="2025-01-01")
|
||||||
|
|
||||||
|
# Run mock initialization
|
||||||
|
asyncio.run(mock_initialize())
|
||||||
|
|
||||||
|
assert agent.model is not None
|
||||||
|
assert "Mock" in str(type(agent.model))
|
||||||
|
|
||||||
|
os.environ["DEPLOYMENT_MODE"] = "PROD"
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_agent_warns_about_api_keys_in_dev(capsys):
|
||||||
|
"""Test BaseAgent logs warning about API keys in DEV mode"""
|
||||||
|
os.environ["DEPLOYMENT_MODE"] = "DEV"
|
||||||
|
os.environ["OPENAI_API_KEY"] = "sk-test123"
|
||||||
|
|
||||||
|
# Test the warning function directly
|
||||||
|
from tools.deployment_config import log_api_key_warning
|
||||||
|
log_api_key_warning()
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "WARNING" in captured.out
|
||||||
|
assert "OPENAI_API_KEY" in captured.out
|
||||||
|
|
||||||
|
os.environ.pop("OPENAI_API_KEY")
|
||||||
|
os.environ["DEPLOYMENT_MODE"] = "PROD"
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_agent_uses_dev_data_path():
|
||||||
|
"""Test BaseAgent uses dev data paths in DEV mode"""
|
||||||
|
os.environ["DEPLOYMENT_MODE"] = "DEV"
|
||||||
|
|
||||||
|
agent = BaseAgent(
|
||||||
|
signature="test-agent",
|
||||||
|
basemodel="mock/test-trader",
|
||||||
|
log_path="./data/agent_data" # Original path
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should be converted to dev path
|
||||||
|
assert "dev_agent_data" in agent.base_log_path
|
||||||
|
|
||||||
|
os.environ["DEPLOYMENT_MODE"] = "PROD"
|
||||||
Reference in New Issue
Block a user