mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-02 01:27:24 -04:00
Compare commits
7 Commits
v0.3.0-alp
...
v0.3.0-alp
| Author | SHA1 | Date | |
|---|---|---|---|
| e590cdc13b | |||
| c74747d1d4 | |||
| 96f6b78a93 | |||
| 6c395f740d | |||
| 618943b278 | |||
| 1c19eea29a | |||
| e968434062 |
@@ -221,7 +221,7 @@ class BaseAgent:
|
||||
|
||||
print(f"✅ Agent {self.signature} initialization completed")
|
||||
|
||||
def set_context(self, context_injector: "ContextInjector") -> None:
|
||||
async def set_context(self, context_injector: "ContextInjector") -> None:
|
||||
"""
|
||||
Inject ContextInjector after initialization.
|
||||
|
||||
@@ -232,14 +232,24 @@ class BaseAgent:
|
||||
context_injector: Configured ContextInjector instance with
|
||||
correct signature, today_date, job_id, session_id
|
||||
"""
|
||||
print(f"[DEBUG] set_context() ENTRY: Received context_injector with signature={context_injector.signature}, date={context_injector.today_date}, job_id={context_injector.job_id}, session_id={context_injector.session_id}")
|
||||
|
||||
self.context_injector = context_injector
|
||||
print(f"[DEBUG] set_context(): Set self.context_injector, id={id(self.context_injector)}")
|
||||
|
||||
# Recreate MCP client with the interceptor
|
||||
# Note: We need to recreate because MultiServerMCPClient doesn't have add_interceptor()
|
||||
print(f"[DEBUG] set_context(): Creating new MCP client with interceptor, id={id(context_injector)}")
|
||||
self.client = MultiServerMCPClient(
|
||||
self.mcp_config,
|
||||
tool_interceptors=[context_injector]
|
||||
)
|
||||
print(f"[DEBUG] set_context(): MCP client created")
|
||||
|
||||
# CRITICAL: Reload tools from new client so they use the interceptor
|
||||
print(f"[DEBUG] set_context(): Reloading tools...")
|
||||
self.tools = await self.client.get_tools()
|
||||
print(f"[DEBUG] set_context(): Tools reloaded, count={len(self.tools)}")
|
||||
|
||||
print(f"✅ Context injected: signature={context_injector.signature}, "
|
||||
f"date={context_injector.today_date}, job_id={context_injector.job_id}, "
|
||||
|
||||
@@ -49,14 +49,16 @@ class ContextInjector:
|
||||
"""
|
||||
# Inject context parameters for trade tools
|
||||
if request.name in ["buy", "sell"]:
|
||||
# Add signature and today_date to args if not present
|
||||
if "signature" not in request.args:
|
||||
request.args["signature"] = self.signature
|
||||
if "today_date" not in request.args:
|
||||
request.args["today_date"] = self.today_date
|
||||
if "job_id" not in request.args and self.job_id:
|
||||
# Debug: Log self attributes BEFORE injection
|
||||
print(f"[ContextInjector.__call__] ENTRY: id={id(self)}, self.signature={self.signature}, self.today_date={self.today_date}, self.job_id={self.job_id}, self.session_id={self.session_id}")
|
||||
print(f"[ContextInjector.__call__] Args BEFORE injection: {request.args}")
|
||||
|
||||
# ALWAYS inject/override context parameters (don't trust AI-provided values)
|
||||
request.args["signature"] = self.signature
|
||||
request.args["today_date"] = self.today_date
|
||||
if self.job_id:
|
||||
request.args["job_id"] = self.job_id
|
||||
if "session_id" not in request.args and self.session_id:
|
||||
if self.session_id:
|
||||
request.args["session_id"] = self.session_id
|
||||
|
||||
# Debug logging
|
||||
|
||||
@@ -82,24 +82,13 @@ def get_current_position_from_db(job_id: str, model: str, date: str) -> Tuple[Di
|
||||
conn.close()
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def buy(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
||||
job_id: str = None, session_id: int = None) -> Dict[str, Any]:
|
||||
def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
||||
job_id: str = None, session_id: int = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Buy stock function - writes to SQLite database.
|
||||
Internal buy implementation - accepts injected context parameters.
|
||||
|
||||
Args:
|
||||
symbol: Stock symbol (e.g., "AAPL", "MSFT")
|
||||
amount: Number of shares to buy (positive integer)
|
||||
signature: Model signature (injected by ContextInjector)
|
||||
today_date: Trading date YYYY-MM-DD (injected by ContextInjector)
|
||||
job_id: Job UUID (injected by ContextInjector)
|
||||
session_id: Trading session ID (injected by ContextInjector)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]:
|
||||
- Success: {"CASH": amount, symbol: quantity, ...}
|
||||
- Failure: {"error": message, ...}
|
||||
This function is not exposed to the AI model. It receives runtime context
|
||||
(signature, today_date, job_id, session_id) from the ContextInjector.
|
||||
"""
|
||||
# Validate required parameters
|
||||
if not job_id:
|
||||
@@ -206,8 +195,29 @@ def buy(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def sell(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
||||
job_id: str = None, session_id: int = None) -> Dict[str, Any]:
|
||||
def buy(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
||||
job_id: str = None, session_id: int = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Buy stock shares.
|
||||
|
||||
Args:
|
||||
symbol: Stock symbol (e.g., "AAPL", "MSFT", "GOOGL")
|
||||
amount: Number of shares to buy (positive integer)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]:
|
||||
- Success: {"CASH": remaining_cash, "SYMBOL": shares, ...}
|
||||
- Failure: {"error": error_message, ...}
|
||||
|
||||
Note: signature, today_date, job_id, session_id are automatically injected by the system.
|
||||
Do not provide these parameters - they will be added automatically.
|
||||
"""
|
||||
# Delegate to internal implementation
|
||||
return _buy_impl(symbol, amount, signature, today_date, job_id, session_id)
|
||||
|
||||
|
||||
def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
||||
job_id: str = None, session_id: int = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Sell stock function - writes to SQLite database.
|
||||
|
||||
@@ -327,6 +337,28 @@ def sell(symbol: str, amount: int, signature: str = None, today_date: str = None
|
||||
conn.close()
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def sell(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
||||
job_id: str = None, session_id: int = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Sell stock shares.
|
||||
|
||||
Args:
|
||||
symbol: Stock symbol (e.g., "AAPL", "MSFT", "GOOGL")
|
||||
amount: Number of shares to sell (positive integer)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]:
|
||||
- Success: {"CASH": remaining_cash, "SYMBOL": shares, ...}
|
||||
- Failure: {"error": error_message, ...}
|
||||
|
||||
Note: signature, today_date, job_id, session_id are automatically injected by the system.
|
||||
Do not provide these parameters - they will be added automatically.
|
||||
"""
|
||||
# Delegate to internal implementation
|
||||
return _sell_impl(symbol, amount, signature, today_date, job_id, session_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
port = int(os.getenv("TRADE_HTTP_PORT", "8002"))
|
||||
mcp.run(transport="streamable-http", port=port)
|
||||
|
||||
@@ -140,7 +140,10 @@ class ModelDayExecutor:
|
||||
job_id=self.job_id,
|
||||
session_id=session_id
|
||||
)
|
||||
agent.set_context(context_injector)
|
||||
logger.info(f"[DEBUG] ModelDayExecutor: Created ContextInjector with signature={self.model_sig}, date={self.date}, job_id={self.job_id}, session_id={session_id}")
|
||||
logger.info(f"[DEBUG] ModelDayExecutor: Calling await agent.set_context()")
|
||||
await agent.set_context(context_injector)
|
||||
logger.info(f"[DEBUG] ModelDayExecutor: set_context() completed")
|
||||
|
||||
# Run trading session
|
||||
logger.info(f"Running trading session for {self.model_sig} on {self.date}")
|
||||
@@ -155,10 +158,13 @@ class ModelDayExecutor:
|
||||
# Update session summary
|
||||
await self._update_session_summary(cursor, session_id, conversation, agent)
|
||||
|
||||
# Store positions (pass session_id)
|
||||
self._write_results_to_db(agent, session_id)
|
||||
|
||||
# Commit and close connection before _write_results_to_db opens a new one
|
||||
conn.commit()
|
||||
conn.close()
|
||||
conn = None # Mark as closed
|
||||
|
||||
# Store positions (pass session_id) - this opens its own connection
|
||||
self._write_results_to_db(agent, session_id)
|
||||
|
||||
# Update status to completed
|
||||
self.job_manager.update_job_detail_status(
|
||||
|
||||
@@ -90,7 +90,7 @@ class SimulationWorker:
|
||||
logger.info(f"Starting job {self.job_id}: {len(date_range)} dates, {len(models)} models")
|
||||
|
||||
# NEW: Prepare price data (download if needed)
|
||||
available_dates, warnings = self._prepare_data(date_range, models, config_path)
|
||||
available_dates, warnings, completion_skips = self._prepare_data(date_range, models, config_path)
|
||||
|
||||
if not available_dates:
|
||||
error_msg = "No trading dates available after price data preparation"
|
||||
@@ -100,7 +100,7 @@ class SimulationWorker:
|
||||
# Execute available dates only
|
||||
for date in available_dates:
|
||||
logger.info(f"Processing date {date} with {len(models)} models")
|
||||
self._execute_date(date, models, config_path)
|
||||
self._execute_date(date, models, config_path, completion_skips)
|
||||
|
||||
# Job completed - determine final status
|
||||
progress = self.job_manager.get_job_progress(self.job_id)
|
||||
@@ -145,7 +145,8 @@ class SimulationWorker:
|
||||
"error": error_msg
|
||||
}
|
||||
|
||||
def _execute_date(self, date: str, models: List[str], config_path: str) -> None:
|
||||
def _execute_date(self, date: str, models: List[str], config_path: str,
|
||||
completion_skips: Dict[str, Set[str]] = None) -> None:
|
||||
"""
|
||||
Execute all models for a single date in parallel.
|
||||
|
||||
@@ -153,14 +154,24 @@ class SimulationWorker:
|
||||
date: Trading date (YYYY-MM-DD)
|
||||
models: List of model signatures to execute
|
||||
config_path: Path to configuration file
|
||||
completion_skips: {model: {dates}} of already-completed model-days to skip
|
||||
|
||||
Uses ThreadPoolExecutor to run all models concurrently for this date.
|
||||
Waits for all models to complete before returning.
|
||||
Skips models that have already completed this date.
|
||||
"""
|
||||
if completion_skips is None:
|
||||
completion_skips = {}
|
||||
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
# Submit all model executions for this date
|
||||
futures = []
|
||||
for model in models:
|
||||
# Skip if this model-day was already completed
|
||||
if date in completion_skips.get(model, set()):
|
||||
logger.debug(f"Skipping {model} on {date} (already completed)")
|
||||
continue
|
||||
|
||||
future = executor.submit(
|
||||
self._execute_model_day,
|
||||
date,
|
||||
@@ -397,7 +408,10 @@ class SimulationWorker:
|
||||
config_path: Path to configuration file
|
||||
|
||||
Returns:
|
||||
Tuple of (available_dates, warnings)
|
||||
Tuple of (available_dates, warnings, completion_skips)
|
||||
- available_dates: Dates to process
|
||||
- warnings: Warning messages
|
||||
- completion_skips: {model: {dates}} of already-completed model-days
|
||||
"""
|
||||
from api.price_data_manager import PriceDataManager
|
||||
|
||||
@@ -456,7 +470,7 @@ class SimulationWorker:
|
||||
self.job_manager.update_job_status(self.job_id, "running")
|
||||
logger.info(f"Job {self.job_id}: Starting execution - {len(dates_to_process)} dates, {len(models)} models")
|
||||
|
||||
return dates_to_process, warnings
|
||||
return dates_to_process, warnings, completion_skips
|
||||
|
||||
def get_job_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user