Files
AI-Trader/agent/base_agent/base_agent.py
2025-10-27 21:37:46 +08:00

447 lines
17 KiB
Python
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
BaseAgent class - Base class for trading agents
Encapsulates core functionality including MCP tool management, AI agent creation, and trading execution
"""
import os
import json
import asyncio
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from pathlib import Path
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_openai import ChatOpenAI
from langchain.agents import create_agent
from dotenv import load_dotenv
# Import project tools
import sys
# Project root = three directory levels above this file's absolute path
# (this module lives at <root>/agent/base_agent/base_agent.py). It is put
# first on sys.path so the project-local `tools` and `prompts` packages
# imported below resolve no matter what the current working directory is.
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, project_root)
from tools.general_tools import extract_conversation, extract_tool_messages, get_config_value, write_config_value
from tools.price_tools import add_no_trade_record
from prompts.agent_prompt import get_agent_system_prompt, STOP_SIGNAL
# Load environment variables (python-dotenv reads a .env file into os.environ).
# This runs at import time so that the later os.getenv() lookups in BaseAgent
# (OPENAI_API_BASE, OPENAI_API_KEY, *_HTTP_PORT) can see values defined there.
load_dotenv()
class BaseAgent:
    """
    Base class for trading agents

    Main functionalities:
    1. MCP tool management and connection
    2. AI agent creation and configuration
    3. Trading execution and decision loops
    4. Logging and management
    5. Position and configuration management

    On-disk layout, rooted at ``<log_path>/<signature>/``:

    - ``position/position.jsonl``: one JSON object per line. The record
      written by :meth:`register_agent` has the keys ``date``, ``id`` and
      ``positions`` (symbol -> quantity, plus ``CASH``). Other writers
      (presumably the trade tool) append further records; confirm their
      schema against that tool.
    - ``log/<today_date>/log.jsonl``: one JSON object per line with the keys
      ``timestamp``, ``signature`` and ``new_messages`` (a list of
      ``{"role": ..., "content": ...}`` dicts).
    """

    # Default NASDAQ 100 stock symbols. __init__ copies this list, so
    # mutating an instance's ``stock_symbols`` never alters this shared default.
    DEFAULT_STOCK_SYMBOLS = [
        "NVDA", "MSFT", "AAPL", "GOOG", "GOOGL", "AMZN", "META", "AVGO", "TSLA",
        "NFLX", "PLTR", "COST", "ASML", "AMD", "CSCO", "AZN", "TMUS", "MU", "LIN",
        "PEP", "SHOP", "APP", "INTU", "AMAT", "LRCX", "PDD", "QCOM", "ARM", "INTC",
        "BKNG", "AMGN", "TXN", "ISRG", "GILD", "KLAC", "PANW", "ADBE", "HON",
        "CRWD", "CEG", "ADI", "ADP", "DASH", "CMCSA", "VRTX", "MELI", "SBUX",
        "CDNS", "ORLY", "SNPS", "MSTR", "MDLZ", "ABNB", "MRVL", "CTAS", "TRI",
        "MAR", "MNST", "CSX", "ADSK", "PYPL", "FTNT", "AEP", "WDAY", "REGN", "ROP",
        "NXPI", "DDOG", "AXON", "ROST", "IDXX", "EA", "PCAR", "FAST", "EXC", "TTWO",
        "XEL", "ZS", "PAYX", "WBD", "BKR", "CPRT", "CCEP", "FANG", "TEAM", "CHTR",
        "KDP", "MCHP", "GEHC", "VRSK", "CTSH", "CSGP", "KHC", "ODFL", "DXCM", "TTD",
        "ON", "BIIB", "LULU", "CDW", "GFS"
    ]

    def __init__(
        self,
        signature: str,
        basemodel: str,
        stock_symbols: Optional[List[str]] = None,
        mcp_config: Optional[Dict[str, Dict[str, Any]]] = None,
        log_path: Optional[str] = None,
        max_steps: int = 10,
        max_retries: int = 3,
        base_delay: float = 0.5,
        openai_base_url: Optional[str] = None,
        openai_api_key: Optional[str] = None,
        initial_cash: float = 10000.0,
        init_date: str = "2025-10-13"
    ):
        """
        Initialize BaseAgent. No network or file I/O happens here; call
        :meth:`initialize` before running a trading session.

        Args:
            signature: Agent signature/name; also the per-agent data directory name
            basemodel: Base model name passed to ChatOpenAI
            stock_symbols: List of stock symbols. None or an empty list falls
                back to a copy of DEFAULT_STOCK_SYMBOLS (NASDAQ 100)
            mcp_config: MCP tool configuration (transport/URL per server);
                defaults to local HTTP servers whose ports come from env vars
            log_path: Log/data root, defaults to ./data/agent_data
            max_steps: Maximum reasoning steps per trading session
            max_retries: Maximum attempts for agent invocation and for a whole
                session. Values below 1 mean no attempt is made at all
            base_delay: Base delay in seconds; attempt k waits base_delay * k
            openai_base_url: OpenAI API base URL (env OPENAI_API_BASE if None)
            openai_api_key: OpenAI API key (env OPENAI_API_KEY if None)
            initial_cash: Cash amount written into the initial position record
            init_date: Date of the initial position record (YYYY-MM-DD)
        """
        self.signature = signature
        self.basemodel = basemodel
        # Copy the shared class-level default so per-instance mutation cannot
        # leak into every other agent. An empty list also selects the default.
        self.stock_symbols = stock_symbols or list(self.DEFAULT_STOCK_SYMBOLS)
        self.max_steps = max_steps
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.initial_cash = initial_cash
        self.init_date = init_date

        # Set MCP configuration
        self.mcp_config = mcp_config or self._get_default_mcp_config()

        # Set log path
        self.base_log_path = log_path or "./data/agent_data"

        # Set OpenAI configuration; explicit arguments win over environment.
        self.openai_base_url = (
            os.getenv("OPENAI_API_BASE") if openai_base_url is None else openai_base_url
        )
        self.openai_api_key = (
            os.getenv("OPENAI_API_KEY") if openai_api_key is None else openai_api_key
        )

        # Initialize components (populated by initialize() / run_trading_session())
        self.client: Optional[MultiServerMCPClient] = None
        self.tools: Optional[List] = None
        self.model: Optional[ChatOpenAI] = None
        self.agent: Optional[Any] = None

        # Data paths
        self.data_path = os.path.join(self.base_log_path, self.signature)
        self.position_file = os.path.join(self.data_path, "position", "position.jsonl")

    def _get_default_mcp_config(self) -> Dict[str, Dict[str, Any]]:
        """Get the default MCP configuration.

        Every server is reached over ``streamable_http`` on localhost. Each
        port is read from an environment variable, with a fixed fallback.
        """
        return {
            "math": {
                "transport": "streamable_http",
                "url": f"http://localhost:{os.getenv('MATH_HTTP_PORT', '8000')}/mcp",
            },
            "stock_local": {
                "transport": "streamable_http",
                "url": f"http://localhost:{os.getenv('GETPRICE_HTTP_PORT', '8003')}/mcp",
            },
            "search": {
                "transport": "streamable_http",
                "url": f"http://localhost:{os.getenv('SEARCH_HTTP_PORT', '8001')}/mcp",
            },
            "trade": {
                "transport": "streamable_http",
                "url": f"http://localhost:{os.getenv('TRADE_HTTP_PORT', '8002')}/mcp",
            },
        }

    async def initialize(self) -> None:
        """Initialize the MCP client, load its tools and create the chat model."""
        print(f"🚀 Initializing agent: {self.signature}")

        # Create MCP client
        self.client = MultiServerMCPClient(self.mcp_config)

        # Get tools
        self.tools = await self.client.get_tools()
        print(f"✅ Loaded {len(self.tools)} MCP tools")

        # Create AI model. Its own retry/timeout settings are fixed here and
        # are independent of self.max_retries.
        self.model = ChatOpenAI(
            model=self.basemodel,
            base_url=self.openai_base_url,
            api_key=self.openai_api_key,
            max_retries=3,
            timeout=30
        )

        # Note: agent will be created in run_trading_session() based on specific date
        # because system_prompt needs the current date and price information
        print(f"✅ Agent {self.signature} initialization completed")

    def _setup_logging(self, today_date: str) -> str:
        """Ensure the per-day log directory exists and return its log.jsonl path."""
        log_dir = os.path.join(self.base_log_path, self.signature, 'log', today_date)
        # exist_ok avoids the race between a separate exists() check and makedirs().
        os.makedirs(log_dir, exist_ok=True)
        return os.path.join(log_dir, "log.jsonl")

    def _log_message(self, log_file: str, new_messages: List[Dict[str, str]]) -> None:
        """Append one JSON line containing *new_messages* to *log_file*.

        Args:
            log_file: Path returned by _setup_logging
            new_messages: List of ``{"role": ..., "content": ...}`` dicts
        """
        log_entry = {
            "timestamp": datetime.now().isoformat(),
            "signature": self.signature,
            "new_messages": new_messages
        }
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")

    async def _ainvoke_with_retry(self, message: List[Dict[str, str]]) -> Any:
        """Invoke self.agent on *message*, retrying with linear backoff.

        The last exception is re-raised once max_retries attempts have failed.
        NOTE(review): if max_retries < 1 the loop never runs and None is
        returned.
        """
        for attempt in range(1, self.max_retries + 1):
            try:
                return await self.agent.ainvoke(
                    {"messages": message},
                    {"recursion_limit": 100}
                )
            except Exception as e:
                if attempt == self.max_retries:
                    raise
                print(f"⚠️ Attempt {attempt} failed, retrying after {self.base_delay * attempt} seconds...")
                print(f"Error details: {e}")
                await asyncio.sleep(self.base_delay * attempt)

    async def run_trading_session(self, today_date: str) -> None:
        """
        Run single day trading session

        Rebuilds the agent with a system prompt for *today_date*, then
        repeatedly invokes it, for at most max_steps rounds:

        - If the final response contains STOP_SIGNAL, it is logged and the
          loop ends.
        - Otherwise the response and the concatenated tool outputs are fed
          back as the next round's context.

        Trading results are handled after the loop. On an exception, the
        error is printed and re-raised, and result handling is skipped.

        Args:
            today_date: Trading date
        """
        print(f"📈 Starting trading session: {today_date}")

        # Set up logging
        log_file = self._setup_logging(today_date)

        # Update system prompt
        self.agent = create_agent(
            self.model,
            tools=self.tools,
            system_prompt=get_agent_system_prompt(today_date, self.signature),
        )

        # Initial user query
        user_query = [{"role": "user", "content": f"Please analyze and update today's ({today_date}) positions."}]
        message = user_query.copy()

        # Log initial message
        self._log_message(log_file, user_query)

        # Trading loop
        current_step = 0
        while current_step < self.max_steps:
            current_step += 1
            print(f"🔄 Step {current_step}/{self.max_steps}")

            try:
                # Call agent
                response = await self._ainvoke_with_retry(message)

                # Extract agent response
                agent_response = extract_conversation(response, "final")

                # Check stop signal
                if STOP_SIGNAL in agent_response:
                    print("✅ Received stop signal, trading session ended")
                    print(agent_response)
                    self._log_message(log_file, [{"role": "assistant", "content": agent_response}])
                    break

                # Extract tool messages
                tool_msgs = extract_tool_messages(response)
                tool_response = '\n'.join([msg.content for msg in tool_msgs])

                # Prepare new messages
                new_messages = [
                    {"role": "assistant", "content": agent_response},
                    {"role": "user", "content": f'Tool results: {tool_response}'}
                ]

                # Add new messages
                message.extend(new_messages)

                # Log each message as its own entry, wrapped in a list so that
                # every log line's "new_messages" field has the same shape.
                for entry in new_messages:
                    self._log_message(log_file, [entry])

            except Exception as e:
                print(f"❌ Trading session error: {str(e)}")
                print(f"Error details: {e}")
                raise

        # Handle trading results
        await self._handle_trading_result(today_date)

    async def _handle_trading_result(self, today_date: str) -> None:
        """Handle trading results and reset the IF_TRADE flag.

        If IF_TRADE is falsy, a no-trade record is added for *today_date*
        via add_no_trade_record. IF_TRADE is reset to False afterwards in
        both branches.
        NOTE(review): IF_TRADE is presumably set by the trade tool when an
        order executes; confirm against tools.general_tools.
        """
        if_trade = get_config_value("IF_TRADE")
        if if_trade:
            write_config_value("IF_TRADE", False)
            print("✅ Trading completed")
        else:
            print("📊 No trading, maintaining positions")
            try:
                add_no_trade_record(today_date, self.signature)
            except NameError as e:
                print(f"❌ NameError: {e}")
                raise
            write_config_value("IF_TRADE", False)

    def register_agent(self) -> None:
        """Register new agent and create the initial position record.

        Does nothing if the position file already exists. Otherwise writes
        a single record dated self.init_date, with quantity 0 for every
        symbol and CASH = self.initial_cash.
        """
        # Check if position.jsonl file already exists
        if os.path.exists(self.position_file):
            print(f"⚠️ Position file {self.position_file} already exists, skipping registration")
            return

        # Ensure directory structure exists
        position_dir = os.path.join(self.data_path, "position")
        if not os.path.exists(position_dir):
            os.makedirs(position_dir)
            print(f"📁 Created position directory: {position_dir}")

        # Create initial positions
        init_position = {symbol: 0 for symbol in self.stock_symbols}
        init_position['CASH'] = self.initial_cash

        with open(self.position_file, "w") as f:  # Use "w" mode to ensure creating new file
            f.write(json.dumps({
                "date": self.init_date,
                "id": 0,
                "positions": init_position
            }) + "\n")

        print(f"✅ Agent {self.signature} registration completed")
        print(f"📁 Position file: {self.position_file}")
        print(f"💰 Initial cash: ${self.initial_cash}")
        print(f"📊 Number of stocks: {len(self.stock_symbols)}")

    def _latest_position_date(self) -> Optional[str]:
        """Return the latest ``date`` (YYYY-MM-DD) in the position file, or None if it has no records."""
        latest: Optional[str] = None
        latest_obj: Optional[datetime] = None
        with open(self.position_file, "r") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines instead of failing in json.loads
                current = json.loads(line)['date']
                current_obj = datetime.strptime(current, "%Y-%m-%d")
                if latest_obj is None or current_obj > latest_obj:
                    latest, latest_obj = current, current_obj
        return latest

    def get_trading_dates(self, init_date: str, end_date: str) -> List[str]:
        """
        Get trading date list

        Returns the weekdays strictly after the latest recorded position date,
        up to and including *end_date*. Exchange holidays are not excluded.
        If no position file exists, the agent is registered first and
        *init_date* is used as the latest date.

        Args:
            init_date: Start date (YYYY-MM-DD), used only when no position file exists
            end_date: End date (YYYY-MM-DD), inclusive

        Returns:
            List of trading dates (YYYY-MM-DD); empty if already up to date

        Raises:
            ValueError: The position file exists but contains no records.
        """
        if not os.path.exists(self.position_file):
            self.register_agent()
            max_date = init_date
        else:
            # Read existing position file, find latest date
            max_date = self._latest_position_date()
            if max_date is None:
                raise ValueError(
                    f"Position file {self.position_file} exists but contains no records"
                )

        # Check if new dates need to be processed
        max_date_obj = datetime.strptime(max_date, "%Y-%m-%d")
        end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")
        if end_date_obj <= max_date_obj:
            return []

        # Generate trading date list
        trading_dates = []
        current_date = max_date_obj + timedelta(days=1)
        while current_date <= end_date_obj:
            if current_date.weekday() < 5:  # Weekdays (Mon=0 .. Fri=4)
                trading_dates.append(current_date.strftime("%Y-%m-%d"))
            current_date += timedelta(days=1)

        return trading_dates

    async def run_with_retry(self, today_date: str) -> None:
        """Run one trading session, retrying the whole session with linear backoff.

        The last exception is re-raised once max_retries attempts have failed.
        """
        for attempt in range(1, self.max_retries + 1):
            try:
                print(f"🔄 Attempting to run {self.signature} - {today_date} (Attempt {attempt})")
                await self.run_trading_session(today_date)
                print(f"{self.signature} - {today_date} run successful")
                return
            except Exception as e:
                print(f"❌ Attempt {attempt} failed: {str(e)}")
                if attempt == self.max_retries:
                    print(f"💥 {self.signature} - {today_date} all retries failed")
                    raise
                else:
                    wait_time = self.base_delay * attempt
                    print(f"⏳ Waiting {wait_time} seconds before retry...")
                    await asyncio.sleep(wait_time)

    async def run_date_range(self, init_date: str, end_date: str) -> None:
        """
        Run all trading days in date range

        For each pending date, TODAY_DATE and SIGNATURE are written to the
        shared config before the session runs. The first date that fails
        after all retries aborts the whole range.

        Args:
            init_date: Start date
            end_date: End date
        """
        print(f"📅 Running date range: {init_date} to {end_date}")

        # Get trading date list
        trading_dates = self.get_trading_dates(init_date, end_date)

        if not trading_dates:
            print(" No trading days to process")
            return

        print(f"📊 Trading days to process: {trading_dates}")

        # Process each trading day
        for date in trading_dates:
            print(f"🔄 Processing {self.signature} - Date: {date}")

            # Set configuration
            write_config_value("TODAY_DATE", date)
            write_config_value("SIGNATURE", self.signature)

            try:
                await self.run_with_retry(date)
            except Exception as e:
                print(f"❌ Error processing {self.signature} - Date: {date}")
                print(e)
                raise

        print(f"{self.signature} processing completed")

    def get_position_summary(self) -> Dict[str, Any]:
        """Get position summary.

        Returns the latest record's date and positions plus the total record
        count. If the file is missing or holds no records, returns a dict
        with an "error" key instead.
        """
        if not os.path.exists(self.position_file):
            return {"error": "Position file does not exist"}

        # Stream the file: only the last record and the count are needed.
        latest_position = None
        total_records = 0
        with open(self.position_file, "r") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines
                latest_position = json.loads(line)
                total_records += 1

        if latest_position is None:
            return {"error": "No position records"}

        return {
            "signature": self.signature,
            "latest_date": latest_position.get("date"),
            "positions": latest_position.get("positions", {}),
            "total_records": total_records
        }

    def __str__(self) -> str:
        return f"BaseAgent(signature='{self.signature}', basemodel='{self.basemodel}', stocks={len(self.stock_symbols)})"

    def __repr__(self) -> str:
        return self.__str__()