mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-11 13:17:23 -04:00
fix: inject signature and today_date into trade tool calls for concurrent simulations
Resolves issue where MCP trade tools couldn't access SIGNATURE and TODAY_DATE during concurrent API simulations, causing "SIGNATURE environment variable is not set" errors. Problem: - MCP services run as separate HTTP processes - Multiple simulations execute concurrently via ThreadPoolExecutor - Environment variables from executor process not accessible to MCP services Solution: - Add ContextInjector that implements ToolCallInterceptor - Automatically injects signature and today_date into buy/sell tool calls - Trade tools accept optional parameters, falling back to config/env - BaseAgent creates interceptor and updates today_date per session Changes: - agent/context_injector.py: New interceptor for context injection - agent/base_agent/base_agent.py: Create and use ContextInjector - agent_tools/tool_trade.py: Add optional signature/today_date parameters Benefits: - Supports concurrent multi-model simulations - Maintains backward compatibility with CLI mode - AI model unaware of injected parameters
This commit is contained in:
@@ -29,6 +29,7 @@ from tools.deployment_config import (
|
|||||||
log_api_key_warning,
|
log_api_key_warning,
|
||||||
get_deployment_mode
|
get_deployment_mode
|
||||||
)
|
)
|
||||||
|
from agent.context_injector import ContextInjector
|
||||||
|
|
||||||
# Load environment variables
|
# Load environment variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
@@ -125,6 +126,9 @@ class BaseAgent:
|
|||||||
self.model: Optional[ChatOpenAI] = None
|
self.model: Optional[ChatOpenAI] = None
|
||||||
self.agent: Optional[Any] = None
|
self.agent: Optional[Any] = None
|
||||||
|
|
||||||
|
# Context injector for MCP tools
|
||||||
|
self.context_injector: Optional[ContextInjector] = None
|
||||||
|
|
||||||
# Data paths
|
# Data paths
|
||||||
self.data_path = os.path.join(self.base_log_path, self.signature)
|
self.data_path = os.path.join(self.base_log_path, self.signature)
|
||||||
self.position_file = os.path.join(self.data_path, "position", "position.jsonl")
|
self.position_file = os.path.join(self.data_path, "position", "position.jsonl")
|
||||||
@@ -169,16 +173,27 @@ class BaseAgent:
|
|||||||
print("⚠️ OpenAI base URL not set, using default")
|
print("⚠️ OpenAI base URL not set, using default")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Create MCP client
|
# Create context injector for injecting signature and today_date into tool calls
|
||||||
self.client = MultiServerMCPClient(self.mcp_config)
|
self.context_injector = ContextInjector(
|
||||||
|
signature=self.signature,
|
||||||
|
today_date=self.init_date # Will be updated per trading session
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create MCP client with interceptor
|
||||||
|
self.client = MultiServerMCPClient(
|
||||||
|
self.mcp_config,
|
||||||
|
tool_interceptors=[self.context_injector]
|
||||||
|
)
|
||||||
|
|
||||||
# Get tools
|
# Get tools
|
||||||
self.tools = await self.client.get_tools()
|
raw_tools = await self.client.get_tools()
|
||||||
if not self.tools:
|
if not raw_tools:
|
||||||
print("⚠️ Warning: No MCP tools loaded. MCP services may not be running.")
|
print("⚠️ Warning: No MCP tools loaded. MCP services may not be running.")
|
||||||
print(f" MCP configuration: {self.mcp_config}")
|
print(f" MCP configuration: {self.mcp_config}")
|
||||||
|
self.tools = []
|
||||||
else:
|
else:
|
||||||
print(f"✅ Loaded {len(self.tools)} MCP tools")
|
print(f"✅ Loaded {len(raw_tools)} MCP tools")
|
||||||
|
self.tools = raw_tools
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"❌ Failed to initialize MCP client: {e}\n"
|
f"❌ Failed to initialize MCP client: {e}\n"
|
||||||
@@ -336,6 +351,10 @@ Summary:"""
|
|||||||
"""
|
"""
|
||||||
print(f"📈 Starting trading session: {today_date}")
|
print(f"📈 Starting trading session: {today_date}")
|
||||||
|
|
||||||
|
# Update context injector with current trading date
|
||||||
|
if self.context_injector:
|
||||||
|
self.context_injector.today_date = today_date
|
||||||
|
|
||||||
# Clear conversation history for new trading day
|
# Clear conversation history for new trading day
|
||||||
self.clear_conversation_history()
|
self.clear_conversation_history()
|
||||||
|
|
||||||
|
|||||||
50
agent/context_injector.py
Normal file
50
agent/context_injector.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""
|
||||||
|
Tool interceptor for injecting runtime context into MCP tool calls.
|
||||||
|
|
||||||
|
This interceptor automatically injects `signature` and `today_date` parameters
|
||||||
|
into buy/sell tool calls to support concurrent multi-model simulations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
|
||||||
|
class ContextInjector:
|
||||||
|
"""
|
||||||
|
Intercepts tool calls to inject runtime context (signature, today_date).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
interceptor = ContextInjector(signature="gpt-5", today_date="2025-10-01")
|
||||||
|
client = MultiServerMCPClient(config, tool_interceptors=[interceptor])
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, signature: str, today_date: str):
|
||||||
|
"""
|
||||||
|
Initialize context injector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
signature: Model signature to inject
|
||||||
|
today_date: Trading date to inject
|
||||||
|
"""
|
||||||
|
self.signature = signature
|
||||||
|
self.today_date = today_date
|
||||||
|
|
||||||
|
def __call__(self, tool_name: str, tool_input: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Intercept tool call and inject context parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool being called
|
||||||
|
tool_input: Original tool input parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Modified tool input with injected context
|
||||||
|
"""
|
||||||
|
# Only inject for trade tools (buy/sell)
|
||||||
|
if tool_name in ["buy", "sell"]:
|
||||||
|
# Inject signature and today_date if not already provided
|
||||||
|
if "signature" not in tool_input:
|
||||||
|
tool_input["signature"] = self.signature
|
||||||
|
if "today_date" not in tool_input:
|
||||||
|
tool_input["today_date"] = self.today_date
|
||||||
|
|
||||||
|
return tool_input
|
||||||
@@ -13,7 +13,7 @@ mcp = FastMCP("TradeTools")
|
|||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def buy(symbol: str, amount: int) -> Dict[str, Any]:
|
def buy(symbol: str, amount: int, signature: str = None, today_date: str = None) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Buy stock function
|
Buy stock function
|
||||||
|
|
||||||
@@ -27,6 +27,8 @@ def buy(symbol: str, amount: int) -> Dict[str, Any]:
|
|||||||
Args:
|
Args:
|
||||||
symbol: Stock symbol, such as "AAPL", "MSFT", etc.
|
symbol: Stock symbol, such as "AAPL", "MSFT", etc.
|
||||||
amount: Buy quantity, must be a positive integer, indicating how many shares to buy
|
amount: Buy quantity, must be a positive integer, indicating how many shares to buy
|
||||||
|
signature: Model signature (optional, will use config/env if not provided)
|
||||||
|
today_date: Trading date (optional, will use config/env if not provided)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, Any]:
|
Dict[str, Any]:
|
||||||
@@ -41,13 +43,15 @@ def buy(symbol: str, amount: int) -> Dict[str, Any]:
|
|||||||
>>> print(result) # {"AAPL": 110, "MSFT": 5, "CASH": 5000.0, ...}
|
>>> print(result) # {"AAPL": 110, "MSFT": 5, "CASH": 5000.0, ...}
|
||||||
"""
|
"""
|
||||||
# Step 1: Get environment variables and basic information
|
# Step 1: Get environment variables and basic information
|
||||||
# Get signature (model name) from environment variable, used to determine data storage path
|
# Get signature (model name) from parameter or fallback to config/env
|
||||||
signature = get_config_value("SIGNATURE")
|
|
||||||
if signature is None:
|
if signature is None:
|
||||||
raise ValueError("SIGNATURE environment variable is not set")
|
signature = get_config_value("SIGNATURE")
|
||||||
|
if signature is None:
|
||||||
|
raise ValueError("SIGNATURE not provided and environment variable is not set")
|
||||||
|
|
||||||
# Get current trading date from environment variable
|
# Get current trading date from parameter or fallback to config/env
|
||||||
today_date = get_config_value("TODAY_DATE")
|
if today_date is None:
|
||||||
|
today_date = get_config_value("TODAY_DATE")
|
||||||
|
|
||||||
# Step 2: Get current latest position and operation ID
|
# Step 2: Get current latest position and operation ID
|
||||||
# get_latest_position returns two values: position dictionary and current maximum operation ID
|
# get_latest_position returns two values: position dictionary and current maximum operation ID
|
||||||
@@ -104,7 +108,7 @@ def buy(symbol: str, amount: int) -> Dict[str, Any]:
|
|||||||
return new_position
|
return new_position
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def sell(symbol: str, amount: int) -> Dict[str, Any]:
|
def sell(symbol: str, amount: int, signature: str = None, today_date: str = None) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Sell stock function
|
Sell stock function
|
||||||
|
|
||||||
@@ -118,6 +122,8 @@ def sell(symbol: str, amount: int) -> Dict[str, Any]:
|
|||||||
Args:
|
Args:
|
||||||
symbol: Stock symbol, such as "AAPL", "MSFT", etc.
|
symbol: Stock symbol, such as "AAPL", "MSFT", etc.
|
||||||
amount: Sell quantity, must be a positive integer, indicating how many shares to sell
|
amount: Sell quantity, must be a positive integer, indicating how many shares to sell
|
||||||
|
signature: Model signature (optional, will use config/env if not provided)
|
||||||
|
today_date: Trading date (optional, will use config/env if not provided)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, Any]:
|
Dict[str, Any]:
|
||||||
@@ -132,13 +138,15 @@ def sell(symbol: str, amount: int) -> Dict[str, Any]:
|
|||||||
>>> print(result) # {"AAPL": 90, "MSFT": 5, "CASH": 15000.0, ...}
|
>>> print(result) # {"AAPL": 90, "MSFT": 5, "CASH": 15000.0, ...}
|
||||||
"""
|
"""
|
||||||
# Step 1: Get environment variables and basic information
|
# Step 1: Get environment variables and basic information
|
||||||
# Get signature (model name) from environment variable, used to determine data storage path
|
# Get signature (model name) from parameter or fallback to config/env
|
||||||
signature = get_config_value("SIGNATURE")
|
|
||||||
if signature is None:
|
if signature is None:
|
||||||
raise ValueError("SIGNATURE environment variable is not set")
|
signature = get_config_value("SIGNATURE")
|
||||||
|
if signature is None:
|
||||||
|
raise ValueError("SIGNATURE not provided and environment variable is not set")
|
||||||
|
|
||||||
# Get current trading date from environment variable
|
# Get current trading date from parameter or fallback to config/env
|
||||||
today_date = get_config_value("TODAY_DATE")
|
if today_date is None:
|
||||||
|
today_date = get_config_value("TODAY_DATE")
|
||||||
|
|
||||||
# Step 2: Get current latest position and operation ID
|
# Step 2: Get current latest position and operation ID
|
||||||
# get_latest_position returns two values: position dictionary and current maximum operation ID
|
# get_latest_position returns two values: position dictionary and current maximum operation ID
|
||||||
|
|||||||
Reference in New Issue
Block a user