mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-09 12:17:24 -04:00
Compare commits
5 Commits
v0.3.0-alp
...
v0.3.0-alp
| Author | SHA1 | Date | |
|---|---|---|---|
| 6c395f740d | |||
| 618943b278 | |||
| 1c19eea29a | |||
| e968434062 | |||
| 4c1d23a7c8 |
@@ -221,7 +221,7 @@ class BaseAgent:
|
|||||||
|
|
||||||
print(f"✅ Agent {self.signature} initialization completed")
|
print(f"✅ Agent {self.signature} initialization completed")
|
||||||
|
|
||||||
def set_context(self, context_injector: "ContextInjector") -> None:
|
async def set_context(self, context_injector: "ContextInjector") -> None:
|
||||||
"""
|
"""
|
||||||
Inject ContextInjector after initialization.
|
Inject ContextInjector after initialization.
|
||||||
|
|
||||||
@@ -232,14 +232,24 @@ class BaseAgent:
|
|||||||
context_injector: Configured ContextInjector instance with
|
context_injector: Configured ContextInjector instance with
|
||||||
correct signature, today_date, job_id, session_id
|
correct signature, today_date, job_id, session_id
|
||||||
"""
|
"""
|
||||||
|
print(f"[DEBUG] set_context() ENTRY: Received context_injector with signature={context_injector.signature}, date={context_injector.today_date}, job_id={context_injector.job_id}, session_id={context_injector.session_id}")
|
||||||
|
|
||||||
self.context_injector = context_injector
|
self.context_injector = context_injector
|
||||||
|
print(f"[DEBUG] set_context(): Set self.context_injector, id={id(self.context_injector)}")
|
||||||
|
|
||||||
# Recreate MCP client with the interceptor
|
# Recreate MCP client with the interceptor
|
||||||
# Note: We need to recreate because MultiServerMCPClient doesn't have add_interceptor()
|
# Note: We need to recreate because MultiServerMCPClient doesn't have add_interceptor()
|
||||||
|
print(f"[DEBUG] set_context(): Creating new MCP client with interceptor, id={id(context_injector)}")
|
||||||
self.client = MultiServerMCPClient(
|
self.client = MultiServerMCPClient(
|
||||||
self.mcp_config,
|
self.mcp_config,
|
||||||
tool_interceptors=[context_injector]
|
tool_interceptors=[context_injector]
|
||||||
)
|
)
|
||||||
|
print(f"[DEBUG] set_context(): MCP client created")
|
||||||
|
|
||||||
|
# CRITICAL: Reload tools from new client so they use the interceptor
|
||||||
|
print(f"[DEBUG] set_context(): Reloading tools...")
|
||||||
|
self.tools = await self.client.get_tools()
|
||||||
|
print(f"[DEBUG] set_context(): Tools reloaded, count={len(self.tools)}")
|
||||||
|
|
||||||
print(f"✅ Context injected: signature={context_injector.signature}, "
|
print(f"✅ Context injected: signature={context_injector.signature}, "
|
||||||
f"date={context_injector.today_date}, job_id={context_injector.job_id}, "
|
f"date={context_injector.today_date}, job_id={context_injector.job_id}, "
|
||||||
|
|||||||
@@ -49,14 +49,16 @@ class ContextInjector:
|
|||||||
"""
|
"""
|
||||||
# Inject context parameters for trade tools
|
# Inject context parameters for trade tools
|
||||||
if request.name in ["buy", "sell"]:
|
if request.name in ["buy", "sell"]:
|
||||||
# Add signature and today_date to args if not present
|
# Debug: Log self attributes BEFORE injection
|
||||||
if "signature" not in request.args:
|
print(f"[ContextInjector.__call__] ENTRY: id={id(self)}, self.signature={self.signature}, self.today_date={self.today_date}, self.job_id={self.job_id}, self.session_id={self.session_id}")
|
||||||
request.args["signature"] = self.signature
|
print(f"[ContextInjector.__call__] Args BEFORE injection: {request.args}")
|
||||||
if "today_date" not in request.args:
|
|
||||||
request.args["today_date"] = self.today_date
|
# ALWAYS inject/override context parameters (don't trust AI-provided values)
|
||||||
if "job_id" not in request.args and self.job_id:
|
request.args["signature"] = self.signature
|
||||||
|
request.args["today_date"] = self.today_date
|
||||||
|
if self.job_id:
|
||||||
request.args["job_id"] = self.job_id
|
request.args["job_id"] = self.job_id
|
||||||
if "session_id" not in request.args and self.session_id:
|
if self.session_id:
|
||||||
request.args["session_id"] = self.session_id
|
request.args["session_id"] = self.session_id
|
||||||
|
|
||||||
# Debug logging
|
# Debug logging
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|||||||
sys.path.insert(0, project_root)
|
sys.path.insert(0, project_root)
|
||||||
from tools.price_tools import get_open_prices
|
from tools.price_tools import get_open_prices
|
||||||
import json
|
import json
|
||||||
from tools.deployment_config import get_db_path
|
|
||||||
from api.database import get_db_connection
|
from api.database import get_db_connection
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
mcp = FastMCP("TradeTools")
|
mcp = FastMCP("TradeTools")
|
||||||
@@ -30,7 +29,7 @@ def get_current_position_from_db(job_id: str, model: str, date: str) -> Tuple[Di
|
|||||||
Raises:
|
Raises:
|
||||||
Exception: If database query fails
|
Exception: If database query fails
|
||||||
"""
|
"""
|
||||||
db_path = get_db_path()
|
db_path = "data/jobs.db"
|
||||||
conn = get_db_connection(db_path)
|
conn = get_db_connection(db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
@@ -110,7 +109,7 @@ def buy(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
|||||||
if not today_date:
|
if not today_date:
|
||||||
return {"error": "Missing required parameter: today_date"}
|
return {"error": "Missing required parameter: today_date"}
|
||||||
|
|
||||||
db_path = get_db_path()
|
db_path = "data/jobs.db"
|
||||||
conn = get_db_connection(db_path)
|
conn = get_db_connection(db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
@@ -233,7 +232,7 @@ def sell(symbol: str, amount: int, signature: str = None, today_date: str = None
|
|||||||
if not today_date:
|
if not today_date:
|
||||||
return {"error": "Missing required parameter: today_date"}
|
return {"error": "Missing required parameter: today_date"}
|
||||||
|
|
||||||
db_path = get_db_path()
|
db_path = "data/jobs.db"
|
||||||
conn = get_db_connection(db_path)
|
conn = get_db_connection(db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
|||||||
@@ -140,7 +140,10 @@ class ModelDayExecutor:
|
|||||||
job_id=self.job_id,
|
job_id=self.job_id,
|
||||||
session_id=session_id
|
session_id=session_id
|
||||||
)
|
)
|
||||||
agent.set_context(context_injector)
|
logger.info(f"[DEBUG] ModelDayExecutor: Created ContextInjector with signature={self.model_sig}, date={self.date}, job_id={self.job_id}, session_id={session_id}")
|
||||||
|
logger.info(f"[DEBUG] ModelDayExecutor: Calling await agent.set_context()")
|
||||||
|
await agent.set_context(context_injector)
|
||||||
|
logger.info(f"[DEBUG] ModelDayExecutor: set_context() completed")
|
||||||
|
|
||||||
# Run trading session
|
# Run trading session
|
||||||
logger.info(f"Running trading session for {self.model_sig} on {self.date}")
|
logger.info(f"Running trading session for {self.model_sig} on {self.date}")
|
||||||
@@ -155,10 +158,13 @@ class ModelDayExecutor:
|
|||||||
# Update session summary
|
# Update session summary
|
||||||
await self._update_session_summary(cursor, session_id, conversation, agent)
|
await self._update_session_summary(cursor, session_id, conversation, agent)
|
||||||
|
|
||||||
# Store positions (pass session_id)
|
# Commit and close connection before _write_results_to_db opens a new one
|
||||||
self._write_results_to_db(agent, session_id)
|
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
conn = None # Mark as closed
|
||||||
|
|
||||||
|
# Store positions (pass session_id) - this opens its own connection
|
||||||
|
self._write_results_to_db(agent, session_id)
|
||||||
|
|
||||||
# Update status to completed
|
# Update status to completed
|
||||||
self.job_manager.update_job_detail_status(
|
self.job_manager.update_job_detail_status(
|
||||||
|
|||||||
@@ -320,12 +320,11 @@ def get_today_init_position_from_db(
|
|||||||
If no position exists: {"CASH": 10000.0} (initial cash)
|
If no position exists: {"CASH": 10000.0} (initial cash)
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from tools.deployment_config import get_db_path
|
|
||||||
from api.database import get_db_connection
|
from api.database import get_db_connection
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
db_path = get_db_path()
|
db_path = "data/jobs.db"
|
||||||
conn = get_db_connection(db_path)
|
conn = get_db_connection(db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
@@ -385,14 +384,13 @@ def add_no_trade_record_to_db(
|
|||||||
session_id: Trading session ID
|
session_id: Trading session ID
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from tools.deployment_config import get_db_path
|
|
||||||
from api.database import get_db_connection
|
from api.database import get_db_connection
|
||||||
from agent_tools.tool_trade import get_current_position_from_db
|
from agent_tools.tool_trade import get_current_position_from_db
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
db_path = get_db_path()
|
db_path = "data/jobs.db"
|
||||||
conn = get_db_connection(db_path)
|
conn = get_db_connection(db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user