mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-02 01:27:24 -04:00
Compare commits
7 Commits
v0.4.0-alp
...
v0.4.1-alp
| Author | SHA1 | Date | |
|---|---|---|---|
| 6d30244fc9 | |||
| 0641ce554a | |||
| 0c6de5b74b | |||
| 0f49977700 | |||
| 27a824f4a6 | |||
| 3e50868a4d | |||
| e20dce7432 |
53
CHANGELOG.md
53
CHANGELOG.md
@@ -7,7 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
## [0.4.0] - 2025-11-04
|
||||
## [0.4.1] - 2025-11-05
|
||||
|
||||
### Fixed
|
||||
- Resolved Pydantic validation errors when using DeepSeek Chat v3.1 through systematic debugging
|
||||
- Root cause: Initial implementation incorrectly converted tool_calls arguments from strings to dictionaries, causing LangChain's `parse_tool_call()` to fail and create invalid_tool_calls with wrong format
|
||||
- Solution: Removed unnecessary conversion logic - DeepSeek already returns arguments in correct format (JSON strings)
|
||||
- `ToolCallArgsParsingWrapper` now acts as a simple passthrough proxy (kept for backward compatibility)
|
||||
|
||||
## [0.4.0] - 2025-11-05
|
||||
|
||||
### BREAKING CHANGES
|
||||
|
||||
@@ -130,6 +138,49 @@ New `/results?reasoning=full` returns:
|
||||
- Test coverage increased with 36+ new comprehensive tests
|
||||
- Documentation updated with complete API reference and database schema details
|
||||
|
||||
### Fixed
|
||||
- **Critical:** Intra-day position tracking for sell-then-buy trades (e20dce7)
|
||||
- Sell proceeds now immediately available for subsequent buy orders within same trading session
|
||||
- ContextInjector maintains in-memory position state during trading sessions
|
||||
- Position updates accumulate after each successful trade
|
||||
- Enables agents to rebalance portfolios (sell + buy) in single session
|
||||
- Added 13 comprehensive tests for position tracking
|
||||
- **Critical:** Tool message extraction in conversation history (462de3a, abb9cd0)
|
||||
- Fixed bug where tool messages (buy/sell trades) were not captured when agent completed in single step
|
||||
- Tool extraction now happens BEFORE finish signal check
|
||||
- Reasoning summaries now accurately reflect actual trades executed
|
||||
- Resolves issue where summarizer saw 0 tools despite multiple trades
|
||||
- Reasoning summary generation improvements (6d126db)
|
||||
- Summaries now explicitly mention specific trades executed (symbols, quantities, actions)
|
||||
- Added TRADES EXECUTED section highlighting tool calls
|
||||
- Example: 'sold 1 GOOGL and 1 AMZN to reduce exposure' instead of 'maintain core holdings'
|
||||
- Final holdings calculation accuracy (a8d912b)
|
||||
- Final positions now calculated from actions instead of querying incomplete database records
|
||||
- Correctly handles first trading day with multiple trades
|
||||
- New `_calculate_final_position_from_actions()` method applies all trades to calculate final state
|
||||
- Holdings now persist correctly across all trading days
|
||||
- Added 3 comprehensive tests for final position calculation
|
||||
- Holdings persistence between trading days (aa16480)
|
||||
- Query now retrieves previous day's ending position as current day's starting position
|
||||
- Changed query from `date <=` to `date <` to prevent returning incomplete current-day records
|
||||
- Fixes empty starting_position/final_position in API responses despite successful trades
|
||||
- Updated tests to verify correct previous-day retrieval
|
||||
- Context injector trading_day_id synchronization (05620fa)
|
||||
- ContextInjector now updated with trading_day_id after record creation
|
||||
- Fixes "Trade failed: trading_day_id not found in runtime config" error
|
||||
- MCP tools now correctly receive trading_day_id via context injection
|
||||
- Schema migration compatibility fixes (7c71a04)
|
||||
- Updated position queries to use new trading_days schema instead of obsolete positions table
|
||||
- Removed obsolete add_no_trade_record_to_db function calls
|
||||
- Fixes "no such table: positions" error
|
||||
- Simplified _handle_trading_result logic
|
||||
- Database referential integrity (9da65c2)
|
||||
- Corrected Database default path from "data/trading.db" to "data/jobs.db"
|
||||
- Ensures all components use same database file
|
||||
- Fixes FOREIGN KEY constraint failures when creating trading_day records
|
||||
- Debug logging cleanup (1e7bdb5)
|
||||
- Removed verbose debug logging from ContextInjector for cleaner output
|
||||
|
||||
## [0.3.1] - 2025-11-03
|
||||
|
||||
### Fixed
|
||||
|
||||
@@ -533,6 +533,8 @@ Summary:"""
|
||||
# Update context injector with current trading date
|
||||
if self.context_injector:
|
||||
self.context_injector.today_date = today_date
|
||||
# Reset position state for new trading day (enables intra-day tracking)
|
||||
self.context_injector.reset_position()
|
||||
|
||||
# Clear conversation history for new trading day
|
||||
self.clear_conversation_history()
|
||||
|
||||
51
agent/chat_model_wrapper.py
Normal file
51
agent/chat_model_wrapper.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Chat model wrapper - Passthrough wrapper for ChatOpenAI models.
|
||||
|
||||
Originally created to fix DeepSeek tool_calls arg parsing issues, but investigation
|
||||
revealed DeepSeek already returns the correct format (arguments as JSON strings).
|
||||
|
||||
This wrapper is now a simple passthrough that proxies all calls to the underlying model.
|
||||
Kept for backward compatibility and potential future use.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ToolCallArgsParsingWrapper:
|
||||
"""
|
||||
Passthrough wrapper around ChatOpenAI models.
|
||||
|
||||
After systematic debugging, determined that DeepSeek returns tool_calls.arguments
|
||||
as JSON strings (correct format), so no parsing/conversion is needed.
|
||||
|
||||
This wrapper simply proxies all calls to the wrapped model.
|
||||
"""
|
||||
|
||||
def __init__(self, model: Any, **kwargs):
|
||||
"""
|
||||
Initialize wrapper around a chat model.
|
||||
|
||||
Args:
|
||||
model: The chat model to wrap
|
||||
**kwargs: Additional parameters (ignored, for compatibility)
|
||||
"""
|
||||
self.wrapped_model = model
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return identifier for this LLM type"""
|
||||
if hasattr(self.wrapped_model, '_llm_type'):
|
||||
return f"wrapped-{self.wrapped_model._llm_type}"
|
||||
return "wrapped-chat-model"
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
"""Proxy all attributes/methods to the wrapped model"""
|
||||
return getattr(self.wrapped_model, name)
|
||||
|
||||
def bind_tools(self, tools: Any, **kwargs):
|
||||
"""Bind tools to the wrapped model"""
|
||||
return self.wrapped_model.bind_tools(tools, **kwargs)
|
||||
|
||||
def bind(self, **kwargs):
|
||||
"""Bind settings to the wrapped model"""
|
||||
return self.wrapped_model.bind(**kwargs)
|
||||
@@ -3,15 +3,22 @@ Tool interceptor for injecting runtime context into MCP tool calls.
|
||||
|
||||
This interceptor automatically injects `signature` and `today_date` parameters
|
||||
into buy/sell tool calls to support concurrent multi-model simulations.
|
||||
|
||||
It also maintains in-memory position state to track cumulative changes within
|
||||
a single trading session, ensuring sell proceeds are immediately available for
|
||||
subsequent buy orders.
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Awaitable
|
||||
from typing import Any, Callable, Awaitable, Dict, Optional
|
||||
|
||||
|
||||
class ContextInjector:
|
||||
"""
|
||||
Intercepts tool calls to inject runtime context (signature, today_date).
|
||||
|
||||
Also maintains cumulative position state during trading session to ensure
|
||||
sell proceeds are immediately available for subsequent buys.
|
||||
|
||||
Usage:
|
||||
interceptor = ContextInjector(signature="gpt-5", today_date="2025-10-01")
|
||||
client = MultiServerMCPClient(config, tool_interceptors=[interceptor])
|
||||
@@ -34,6 +41,13 @@ class ContextInjector:
|
||||
self.job_id = job_id
|
||||
self.session_id = session_id # Deprecated but kept for compatibility
|
||||
self.trading_day_id = trading_day_id
|
||||
self._current_position: Optional[Dict[str, float]] = None
|
||||
|
||||
def reset_position(self) -> None:
|
||||
"""
|
||||
Reset position state (call at start of each trading day).
|
||||
"""
|
||||
self._current_position = None
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
@@ -43,6 +57,9 @@ class ContextInjector:
|
||||
"""
|
||||
Intercept tool call and inject context parameters.
|
||||
|
||||
For buy/sell operations, maintains cumulative position state to ensure
|
||||
sell proceeds are immediately available for subsequent buys.
|
||||
|
||||
Args:
|
||||
request: Tool call request containing name and arguments
|
||||
handler: Async callable to execute the actual tool
|
||||
@@ -62,5 +79,18 @@ class ContextInjector:
|
||||
if self.trading_day_id:
|
||||
request.args["trading_day_id"] = self.trading_day_id
|
||||
|
||||
# Inject current position if we're tracking it
|
||||
if self._current_position is not None:
|
||||
request.args["_current_position"] = self._current_position
|
||||
|
||||
# Call the actual tool handler
|
||||
return await handler(request)
|
||||
result = await handler(request)
|
||||
|
||||
# Update position state after successful trade
|
||||
if request.name in ["buy", "sell"]:
|
||||
# Check if result is a valid position dict (not an error)
|
||||
if isinstance(result, dict) and "error" not in result and "CASH" in result:
|
||||
# Update our tracked position with the new state
|
||||
self._current_position = result.copy()
|
||||
|
||||
return result
|
||||
|
||||
@@ -91,7 +91,8 @@ def get_current_position_from_db(
|
||||
|
||||
|
||||
def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
||||
job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]:
|
||||
job_id: str = None, session_id: int = None, trading_day_id: int = None,
|
||||
_current_position: Dict[str, float] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Internal buy implementation - accepts injected context parameters.
|
||||
|
||||
@@ -103,9 +104,13 @@ def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str =
|
||||
job_id: Job ID (injected)
|
||||
session_id: Session ID (injected, DEPRECATED)
|
||||
trading_day_id: Trading day ID (injected)
|
||||
_current_position: Current position state (injected by ContextInjector)
|
||||
|
||||
This function is not exposed to the AI model. It receives runtime context
|
||||
(signature, today_date, job_id, session_id, trading_day_id) from the ContextInjector.
|
||||
|
||||
The _current_position parameter enables intra-day position tracking, ensuring
|
||||
sell proceeds are immediately available for subsequent buys.
|
||||
"""
|
||||
# Validate required parameters
|
||||
if not job_id:
|
||||
@@ -121,7 +126,13 @@ def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str =
|
||||
|
||||
try:
|
||||
# Step 1: Get current position
|
||||
current_position, next_action_id = get_current_position_from_db(job_id, signature, today_date)
|
||||
# Use injected position if available (for intra-day tracking),
|
||||
# otherwise query database for starting position
|
||||
if _current_position is not None:
|
||||
current_position = _current_position
|
||||
next_action_id = 0 # Not used in new schema
|
||||
else:
|
||||
current_position, next_action_id = get_current_position_from_db(job_id, signature, today_date)
|
||||
|
||||
# Step 2: Get stock price
|
||||
try:
|
||||
@@ -186,7 +197,8 @@ def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str =
|
||||
|
||||
@mcp.tool()
|
||||
def buy(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
||||
job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]:
|
||||
job_id: str = None, session_id: int = None, trading_day_id: int = None,
|
||||
_current_position: Dict[str, float] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Buy stock shares.
|
||||
|
||||
@@ -199,14 +211,15 @@ def buy(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
||||
- Success: {"CASH": remaining_cash, "SYMBOL": shares, ...}
|
||||
- Failure: {"error": error_message, ...}
|
||||
|
||||
Note: signature, today_date, job_id, session_id, trading_day_id are
|
||||
automatically injected by the system. Do not provide these parameters.
|
||||
Note: signature, today_date, job_id, session_id, trading_day_id, _current_position
|
||||
are automatically injected by the system. Do not provide these parameters.
|
||||
"""
|
||||
return _buy_impl(symbol, amount, signature, today_date, job_id, session_id, trading_day_id)
|
||||
return _buy_impl(symbol, amount, signature, today_date, job_id, session_id, trading_day_id, _current_position)
|
||||
|
||||
|
||||
def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
||||
job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]:
|
||||
job_id: str = None, session_id: int = None, trading_day_id: int = None,
|
||||
_current_position: Dict[str, float] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Sell stock function - writes to SQLite database.
|
||||
|
||||
@@ -218,11 +231,15 @@ def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str
|
||||
job_id: Job UUID (injected by ContextInjector)
|
||||
session_id: Trading session ID (injected by ContextInjector, DEPRECATED)
|
||||
trading_day_id: Trading day ID (injected by ContextInjector)
|
||||
_current_position: Current position state (injected by ContextInjector)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]:
|
||||
- Success: {"CASH": amount, symbol: quantity, ...}
|
||||
- Failure: {"error": message, ...}
|
||||
|
||||
The _current_position parameter enables intra-day position tracking, ensuring
|
||||
sell proceeds are immediately available for subsequent buys.
|
||||
"""
|
||||
# Validate required parameters
|
||||
if not job_id:
|
||||
@@ -238,7 +255,13 @@ def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str
|
||||
|
||||
try:
|
||||
# Step 1: Get current position
|
||||
current_position, next_action_id = get_current_position_from_db(job_id, signature, today_date)
|
||||
# Use injected position if available (for intra-day tracking),
|
||||
# otherwise query database for starting position
|
||||
if _current_position is not None:
|
||||
current_position = _current_position
|
||||
next_action_id = 0 # Not used in new schema
|
||||
else:
|
||||
current_position, next_action_id = get_current_position_from_db(job_id, signature, today_date)
|
||||
|
||||
# Step 2: Validate position exists
|
||||
if symbol not in current_position:
|
||||
@@ -298,7 +321,8 @@ def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str
|
||||
|
||||
@mcp.tool()
|
||||
def sell(symbol: str, amount: int, signature: str = None, today_date: str = None,
|
||||
job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]:
|
||||
job_id: str = None, session_id: int = None, trading_day_id: int = None,
|
||||
_current_position: Dict[str, float] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Sell stock shares.
|
||||
|
||||
@@ -311,10 +335,10 @@ def sell(symbol: str, amount: int, signature: str = None, today_date: str = None
|
||||
- Success: {"CASH": remaining_cash, "SYMBOL": shares, ...}
|
||||
- Failure: {"error": error_message, ...}
|
||||
|
||||
Note: signature, today_date, job_id, session_id, trading_day_id are
|
||||
automatically injected by the system. Do not provide these parameters.
|
||||
Note: signature, today_date, job_id, session_id, trading_day_id, _current_position
|
||||
are automatically injected by the system. Do not provide these parameters.
|
||||
"""
|
||||
return _sell_impl(symbol, amount, signature, today_date, job_id, session_id, trading_day_id)
|
||||
return _sell_impl(symbol, amount, signature, today_date, job_id, session_id, trading_day_id, _current_position)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
216
tests/unit/test_chat_model_wrapper.py
Normal file
216
tests/unit/test_chat_model_wrapper.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""
|
||||
Unit tests for ChatModelWrapper - tool_calls args parsing fix
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import ChatResult, ChatGeneration
|
||||
|
||||
from agent.chat_model_wrapper import ToolCallArgsParsingWrapper
|
||||
|
||||
|
||||
class TestToolCallArgsParsingWrapper:
|
||||
"""Tests for ToolCallArgsParsingWrapper"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model(self):
|
||||
"""Create a mock chat model"""
|
||||
model = Mock()
|
||||
model._llm_type = "mock-model"
|
||||
return model
|
||||
|
||||
@pytest.fixture
|
||||
def wrapper(self, mock_model):
|
||||
"""Create a wrapper around mock model"""
|
||||
return ToolCallArgsParsingWrapper(model=mock_model)
|
||||
|
||||
def test_fix_tool_calls_with_string_args(self, wrapper):
|
||||
"""Test that string args are parsed to dict"""
|
||||
# Create message with tool_calls where args is a JSON string
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "buy",
|
||||
"args": '{"symbol": "AAPL", "amount": 10}', # String, not dict
|
||||
"id": "call_123"
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
fixed_message = wrapper._fix_tool_calls(message)
|
||||
|
||||
# Check that args is now a dict
|
||||
assert isinstance(fixed_message.tool_calls[0]['args'], dict)
|
||||
assert fixed_message.tool_calls[0]['args'] == {"symbol": "AAPL", "amount": 10}
|
||||
|
||||
def test_fix_tool_calls_with_dict_args(self, wrapper):
|
||||
"""Test that dict args are left unchanged"""
|
||||
# Create message with tool_calls where args is already a dict
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "buy",
|
||||
"args": {"symbol": "AAPL", "amount": 10}, # Already a dict
|
||||
"id": "call_123"
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
fixed_message = wrapper._fix_tool_calls(message)
|
||||
|
||||
# Check that args is still a dict
|
||||
assert isinstance(fixed_message.tool_calls[0]['args'], dict)
|
||||
assert fixed_message.tool_calls[0]['args'] == {"symbol": "AAPL", "amount": 10}
|
||||
|
||||
def test_fix_tool_calls_with_invalid_json(self, wrapper):
|
||||
"""Test that invalid JSON string is left unchanged"""
|
||||
# Create message with tool_calls where args is an invalid JSON string
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "buy",
|
||||
"args": 'invalid json {', # Invalid JSON
|
||||
"id": "call_123"
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
fixed_message = wrapper._fix_tool_calls(message)
|
||||
|
||||
# Check that args is still a string (parsing failed)
|
||||
assert isinstance(fixed_message.tool_calls[0]['args'], str)
|
||||
assert fixed_message.tool_calls[0]['args'] == 'invalid json {'
|
||||
|
||||
def test_fix_tool_calls_no_tool_calls(self, wrapper):
|
||||
"""Test that messages without tool_calls are left unchanged"""
|
||||
message = AIMessage(content="Hello, world!")
|
||||
fixed_message = wrapper._fix_tool_calls(message)
|
||||
|
||||
assert fixed_message == message
|
||||
|
||||
def test_generate_with_string_args(self, wrapper, mock_model):
|
||||
"""Test _generate method with string args"""
|
||||
# Create a response with string args
|
||||
original_message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "buy",
|
||||
"args": '{"symbol": "MSFT", "amount": 5}',
|
||||
"id": "call_456"
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
mock_result = ChatResult(
|
||||
generations=[ChatGeneration(message=original_message)]
|
||||
)
|
||||
mock_model._generate.return_value = mock_result
|
||||
|
||||
# Call wrapper's _generate
|
||||
result = wrapper._generate(messages=[], stop=None, run_manager=None)
|
||||
|
||||
# Check that args is now a dict
|
||||
fixed_message = result.generations[0].message
|
||||
assert isinstance(fixed_message.tool_calls[0]['args'], dict)
|
||||
assert fixed_message.tool_calls[0]['args'] == {"symbol": "MSFT", "amount": 5}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agenerate_with_string_args(self, wrapper, mock_model):
|
||||
"""Test _agenerate method with string args"""
|
||||
# Create a response with string args
|
||||
original_message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "sell",
|
||||
"args": '{"symbol": "GOOGL", "amount": 3}',
|
||||
"id": "call_789"
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
mock_result = ChatResult(
|
||||
generations=[ChatGeneration(message=original_message)]
|
||||
)
|
||||
mock_model._agenerate = AsyncMock(return_value=mock_result)
|
||||
|
||||
# Call wrapper's _agenerate
|
||||
result = await wrapper._agenerate(messages=[], stop=None, run_manager=None)
|
||||
|
||||
# Check that args is now a dict
|
||||
fixed_message = result.generations[0].message
|
||||
assert isinstance(fixed_message.tool_calls[0]['args'], dict)
|
||||
assert fixed_message.tool_calls[0]['args'] == {"symbol": "GOOGL", "amount": 3}
|
||||
|
||||
def test_invoke_with_string_args(self, wrapper, mock_model):
|
||||
"""Test invoke method with string args"""
|
||||
original_message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "buy",
|
||||
"args": '{"symbol": "NVDA", "amount": 20}',
|
||||
"id": "call_999"
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
mock_model.invoke.return_value = original_message
|
||||
|
||||
# Call wrapper's invoke
|
||||
result = wrapper.invoke(input=[])
|
||||
|
||||
# Check that args is now a dict
|
||||
assert isinstance(result.tool_calls[0]['args'], dict)
|
||||
assert result.tool_calls[0]['args'] == {"symbol": "NVDA", "amount": 20}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ainvoke_with_string_args(self, wrapper, mock_model):
|
||||
"""Test ainvoke method with string args"""
|
||||
original_message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "sell",
|
||||
"args": '{"symbol": "TSLA", "amount": 15}',
|
||||
"id": "call_111"
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
mock_model.ainvoke = AsyncMock(return_value=original_message)
|
||||
|
||||
# Call wrapper's ainvoke
|
||||
result = await wrapper.ainvoke(input=[])
|
||||
|
||||
# Check that args is now a dict
|
||||
assert isinstance(result.tool_calls[0]['args'], dict)
|
||||
assert result.tool_calls[0]['args'] == {"symbol": "TSLA", "amount": 15}
|
||||
|
||||
def test_bind_tools_returns_wrapper(self, wrapper, mock_model):
|
||||
"""Test that bind_tools returns a new wrapper"""
|
||||
mock_bound = Mock()
|
||||
mock_model.bind_tools.return_value = mock_bound
|
||||
|
||||
result = wrapper.bind_tools(tools=[], strict=True)
|
||||
|
||||
# Check that result is a wrapper around the bound model
|
||||
assert isinstance(result, ToolCallArgsParsingWrapper)
|
||||
assert result.wrapped_model == mock_bound
|
||||
|
||||
def test_bind_returns_wrapper(self, wrapper, mock_model):
|
||||
"""Test that bind returns a new wrapper"""
|
||||
mock_bound = Mock()
|
||||
mock_model.bind.return_value = mock_bound
|
||||
|
||||
result = wrapper.bind(max_tokens=100)
|
||||
|
||||
# Check that result is a wrapper around the bound model
|
||||
assert isinstance(result, ToolCallArgsParsingWrapper)
|
||||
assert result.wrapped_model == mock_bound
|
||||
192
tests/unit/test_context_injector.py
Normal file
192
tests/unit/test_context_injector.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Test ContextInjector position tracking functionality."""
|
||||
|
||||
import pytest
|
||||
from agent.context_injector import ContextInjector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def injector():
|
||||
"""Create a ContextInjector instance for testing."""
|
||||
return ContextInjector(
|
||||
signature="test-model",
|
||||
today_date="2025-01-15",
|
||||
job_id="test-job-123",
|
||||
trading_day_id=1
|
||||
)
|
||||
|
||||
|
||||
class MockRequest:
|
||||
"""Mock MCP tool request."""
|
||||
def __init__(self, name, args=None):
|
||||
self.name = name
|
||||
self.args = args or {}
|
||||
|
||||
|
||||
async def mock_handler_success(request):
|
||||
"""Mock handler that returns a successful position update."""
|
||||
# Simulate a successful trade returning updated position
|
||||
if request.name == "sell":
|
||||
return {
|
||||
"CASH": 1100.0,
|
||||
"AAPL": 7,
|
||||
"MSFT": 5
|
||||
}
|
||||
elif request.name == "buy":
|
||||
return {
|
||||
"CASH": 50.0,
|
||||
"AAPL": 7,
|
||||
"MSFT": 12
|
||||
}
|
||||
return {}
|
||||
|
||||
|
||||
async def mock_handler_error(request):
|
||||
"""Mock handler that returns an error."""
|
||||
return {"error": "Insufficient cash"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_injector_initializes_with_no_position(injector):
|
||||
"""Test that ContextInjector starts with no position state."""
|
||||
assert injector._current_position is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_injector_reset_position(injector):
|
||||
"""Test that reset_position() clears position state."""
|
||||
# Set some position state
|
||||
injector._current_position = {"CASH": 5000.0, "AAPL": 10}
|
||||
|
||||
# Reset
|
||||
injector.reset_position()
|
||||
|
||||
assert injector._current_position is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_injector_injects_parameters(injector):
|
||||
"""Test that context parameters are injected into buy/sell requests."""
|
||||
request = MockRequest("buy", {"symbol": "AAPL", "amount": 10})
|
||||
|
||||
# Mock handler that just returns the request args
|
||||
async def handler(req):
|
||||
return req.args
|
||||
|
||||
result = await injector(request, handler)
|
||||
|
||||
# Verify context was injected
|
||||
assert result["signature"] == "test-model"
|
||||
assert result["today_date"] == "2025-01-15"
|
||||
assert result["job_id"] == "test-job-123"
|
||||
assert result["trading_day_id"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_injector_tracks_position_after_successful_trade(injector):
|
||||
"""Test that position state is updated after successful trades."""
|
||||
assert injector._current_position is None
|
||||
|
||||
# Execute a sell trade
|
||||
request = MockRequest("sell", {"symbol": "AAPL", "amount": 3})
|
||||
result = await injector(request, mock_handler_success)
|
||||
|
||||
# Verify position was updated
|
||||
assert injector._current_position is not None
|
||||
assert injector._current_position["CASH"] == 1100.0
|
||||
assert injector._current_position["AAPL"] == 7
|
||||
assert injector._current_position["MSFT"] == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_injector_injects_current_position_on_subsequent_trades(injector):
|
||||
"""Test that current position is injected into subsequent trade requests."""
|
||||
# First trade - establish position
|
||||
request1 = MockRequest("sell", {"symbol": "AAPL", "amount": 3})
|
||||
await injector(request1, mock_handler_success)
|
||||
|
||||
# Second trade - should receive current position
|
||||
request2 = MockRequest("buy", {"symbol": "MSFT", "amount": 7})
|
||||
|
||||
async def verify_injection_handler(req):
|
||||
# Verify that _current_position was injected
|
||||
assert "_current_position" in req.args
|
||||
assert req.args["_current_position"]["CASH"] == 1100.0
|
||||
assert req.args["_current_position"]["AAPL"] == 7
|
||||
return mock_handler_success(req)
|
||||
|
||||
await injector(request2, verify_injection_handler)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_injector_does_not_update_position_on_error(injector):
|
||||
"""Test that position state is NOT updated when trade fails."""
|
||||
# First successful trade
|
||||
request1 = MockRequest("sell", {"symbol": "AAPL", "amount": 3})
|
||||
await injector(request1, mock_handler_success)
|
||||
|
||||
original_position = injector._current_position.copy()
|
||||
|
||||
# Second trade that fails
|
||||
request2 = MockRequest("buy", {"symbol": "MSFT", "amount": 100})
|
||||
result = await injector(request2, mock_handler_error)
|
||||
|
||||
# Verify position was NOT updated
|
||||
assert injector._current_position == original_position
|
||||
assert "error" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_injector_does_not_inject_position_for_non_trade_tools(injector):
|
||||
"""Test that position is not injected for non-buy/sell tools."""
|
||||
# Set up position state
|
||||
injector._current_position = {"CASH": 5000.0, "AAPL": 10}
|
||||
|
||||
# Call a non-trade tool
|
||||
request = MockRequest("search", {"query": "market news"})
|
||||
|
||||
async def verify_no_injection_handler(req):
|
||||
assert "_current_position" not in req.args
|
||||
return {"results": []}
|
||||
|
||||
await injector(request, verify_no_injection_handler)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_injector_full_trading_session_simulation(injector):
|
||||
"""Test full trading session with multiple trades and position tracking."""
|
||||
# Reset position at start of day
|
||||
injector.reset_position()
|
||||
assert injector._current_position is None
|
||||
|
||||
# Trade 1: Sell AAPL
|
||||
request1 = MockRequest("sell", {"symbol": "AAPL", "amount": 3})
|
||||
|
||||
async def handler1(req):
|
||||
# First trade should NOT have injected position
|
||||
assert req.args.get("_current_position") is None
|
||||
return {"CASH": 1100.0, "AAPL": 7}
|
||||
|
||||
result1 = await injector(request1, handler1)
|
||||
assert injector._current_position == {"CASH": 1100.0, "AAPL": 7}
|
||||
|
||||
# Trade 2: Buy MSFT (should use position from trade 1)
|
||||
request2 = MockRequest("buy", {"symbol": "MSFT", "amount": 7})
|
||||
|
||||
async def handler2(req):
|
||||
# Second trade SHOULD have injected position from trade 1
|
||||
assert req.args["_current_position"]["CASH"] == 1100.0
|
||||
assert req.args["_current_position"]["AAPL"] == 7
|
||||
return {"CASH": 50.0, "AAPL": 7, "MSFT": 7}
|
||||
|
||||
result2 = await injector(request2, handler2)
|
||||
assert injector._current_position == {"CASH": 50.0, "AAPL": 7, "MSFT": 7}
|
||||
|
||||
# Trade 3: Failed trade (should not update position)
|
||||
request3 = MockRequest("buy", {"symbol": "GOOGL", "amount": 100})
|
||||
|
||||
async def handler3(req):
|
||||
return {"error": "Insufficient cash", "cash_available": 50.0}
|
||||
|
||||
result3 = await injector(request3, handler3)
|
||||
# Position should remain unchanged after failed trade
|
||||
assert injector._current_position == {"CASH": 50.0, "AAPL": 7, "MSFT": 7}
|
||||
@@ -295,3 +295,190 @@ def test_sell_writes_to_actions_table(test_db, monkeypatch):
|
||||
assert row[1] == 'AAPL'
|
||||
assert row[2] == 5
|
||||
assert row[3] == 160.0
|
||||
|
||||
|
||||
def test_intraday_position_tracking_sell_then_buy(test_db, monkeypatch):
|
||||
"""Test that sell proceeds are immediately available for subsequent buys."""
|
||||
db, trading_day_id = test_db
|
||||
|
||||
# Setup: Create starting position with AAPL shares and limited cash
|
||||
db.create_holding(trading_day_id, 'AAPL', 10)
|
||||
db.connection.commit()
|
||||
|
||||
# Create a mock connection wrapper
|
||||
class MockConnection:
|
||||
def __init__(self, real_conn):
|
||||
self.real_conn = real_conn
|
||||
|
||||
def cursor(self):
|
||||
return self.real_conn.cursor()
|
||||
|
||||
def commit(self):
|
||||
return self.real_conn.commit()
|
||||
|
||||
def rollback(self):
|
||||
return self.real_conn.rollback()
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
mock_conn = MockConnection(db.connection)
|
||||
monkeypatch.setattr('agent_tools.tool_trade.get_db_connection',
|
||||
lambda x: mock_conn)
|
||||
|
||||
# Mock get_current_position_from_db to return starting position
|
||||
monkeypatch.setattr('agent_tools.tool_trade.get_current_position_from_db',
|
||||
lambda job_id, sig, date: ({'CASH': 500.0, 'AAPL': 10}, 0))
|
||||
|
||||
monkeypatch.setenv('RUNTIME_ENV_PATH', '/tmp/test_runtime_intraday.json')
|
||||
|
||||
import json
|
||||
with open('/tmp/test_runtime_intraday.json', 'w') as f:
|
||||
json.dump({
|
||||
'TODAY_DATE': '2025-01-15',
|
||||
'SIGNATURE': 'test-model',
|
||||
'JOB_ID': 'test-job-123',
|
||||
'TRADING_DAY_ID': trading_day_id
|
||||
}, f)
|
||||
|
||||
# Mock prices: AAPL sells for 200, MSFT costs 150
|
||||
def mock_get_prices(date, symbols):
|
||||
if 'AAPL' in symbols:
|
||||
return {'AAPL_price': 200.0}
|
||||
elif 'MSFT' in symbols:
|
||||
return {'MSFT_price': 150.0}
|
||||
return {}
|
||||
|
||||
monkeypatch.setattr('agent_tools.tool_trade.get_open_prices', mock_get_prices)
|
||||
|
||||
# Step 1: Sell 3 shares of AAPL for 600.0
|
||||
# Starting cash: 500.0, proceeds: 600.0, new cash: 1100.0
|
||||
result_sell = _sell_impl(
|
||||
symbol='AAPL',
|
||||
amount=3,
|
||||
signature='test-model',
|
||||
today_date='2025-01-15',
|
||||
job_id='test-job-123',
|
||||
trading_day_id=trading_day_id,
|
||||
_current_position=None # Use database position (starting position)
|
||||
)
|
||||
|
||||
assert 'error' not in result_sell, f"Sell should succeed: {result_sell}"
|
||||
assert result_sell['CASH'] == 1100.0, "Cash should be 500 + (3 * 200) = 1100"
|
||||
assert result_sell['AAPL'] == 7, "AAPL shares should be 10 - 3 = 7"
|
||||
|
||||
# Step 2: Buy 7 shares of MSFT for 1050.0 using the position from the sell
|
||||
# This should work because we pass the updated position from step 1
|
||||
result_buy = _buy_impl(
|
||||
symbol='MSFT',
|
||||
amount=7,
|
||||
signature='test-model',
|
||||
today_date='2025-01-15',
|
||||
job_id='test-job-123',
|
||||
trading_day_id=trading_day_id,
|
||||
_current_position=result_sell # Use position from sell
|
||||
)
|
||||
|
||||
assert 'error' not in result_buy, f"Buy should succeed with sell proceeds: {result_buy}"
|
||||
assert result_buy['CASH'] == 50.0, "Cash should be 1100 - (7 * 150) = 50"
|
||||
assert result_buy['MSFT'] == 7, "MSFT shares should be 7"
|
||||
assert result_buy['AAPL'] == 7, "AAPL shares should still be 7"
|
||||
|
||||
# Verify both actions were recorded
|
||||
cursor = db.connection.execute("""
|
||||
SELECT action_type, symbol, quantity, price
|
||||
FROM actions
|
||||
WHERE trading_day_id = ?
|
||||
ORDER BY created_at
|
||||
""", (trading_day_id,))
|
||||
|
||||
actions = cursor.fetchall()
|
||||
assert len(actions) == 2, "Should have 2 actions (sell + buy)"
|
||||
assert actions[0][0] == 'sell' and actions[0][1] == 'AAPL'
|
||||
assert actions[1][0] == 'buy' and actions[1][1] == 'MSFT'
|
||||
|
||||
|
||||
def test_intraday_tracking_without_position_injection_fails(test_db, monkeypatch):
|
||||
"""Test that without position injection, sell proceeds are NOT available for subsequent buys."""
|
||||
db, trading_day_id = test_db
|
||||
|
||||
# Setup: Create starting position with AAPL shares and limited cash
|
||||
db.create_holding(trading_day_id, 'AAPL', 10)
|
||||
db.connection.commit()
|
||||
|
||||
# Create a mock connection wrapper
|
||||
class MockConnection:
|
||||
def __init__(self, real_conn):
|
||||
self.real_conn = real_conn
|
||||
|
||||
def cursor(self):
|
||||
return self.real_conn.cursor()
|
||||
|
||||
def commit(self):
|
||||
return self.real_conn.commit()
|
||||
|
||||
def rollback(self):
|
||||
return self.real_conn.rollback()
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
mock_conn = MockConnection(db.connection)
|
||||
monkeypatch.setattr('agent_tools.tool_trade.get_db_connection',
|
||||
lambda x: mock_conn)
|
||||
|
||||
# Mock get_current_position_from_db to ALWAYS return starting position
|
||||
# (simulating the old buggy behavior)
|
||||
monkeypatch.setattr('agent_tools.tool_trade.get_current_position_from_db',
|
||||
lambda job_id, sig, date: ({'CASH': 500.0, 'AAPL': 10}, 0))
|
||||
|
||||
monkeypatch.setenv('RUNTIME_ENV_PATH', '/tmp/test_runtime_no_injection.json')
|
||||
|
||||
import json
|
||||
with open('/tmp/test_runtime_no_injection.json', 'w') as f:
|
||||
json.dump({
|
||||
'TODAY_DATE': '2025-01-15',
|
||||
'SIGNATURE': 'test-model',
|
||||
'JOB_ID': 'test-job-123',
|
||||
'TRADING_DAY_ID': trading_day_id
|
||||
}, f)
|
||||
|
||||
# Mock prices
|
||||
def mock_get_prices(date, symbols):
|
||||
if 'AAPL' in symbols:
|
||||
return {'AAPL_price': 200.0}
|
||||
elif 'MSFT' in symbols:
|
||||
return {'MSFT_price': 150.0}
|
||||
return {}
|
||||
|
||||
monkeypatch.setattr('agent_tools.tool_trade.get_open_prices', mock_get_prices)
|
||||
|
||||
# Step 1: Sell 3 shares of AAPL
|
||||
result_sell = _sell_impl(
|
||||
symbol='AAPL',
|
||||
amount=3,
|
||||
signature='test-model',
|
||||
today_date='2025-01-15',
|
||||
job_id='test-job-123',
|
||||
trading_day_id=trading_day_id,
|
||||
_current_position=None # Don't inject position (old behavior)
|
||||
)
|
||||
|
||||
assert 'error' not in result_sell, "Sell should succeed"
|
||||
|
||||
# Step 2: Try to buy 7 shares of MSFT WITHOUT passing updated position
|
||||
# This should FAIL because it will query the database and get the original 500.0 cash
|
||||
result_buy = _buy_impl(
|
||||
symbol='MSFT',
|
||||
amount=7,
|
||||
signature='test-model',
|
||||
today_date='2025-01-15',
|
||||
job_id='test-job-123',
|
||||
trading_day_id=trading_day_id,
|
||||
_current_position=None # Don't inject position (old behavior)
|
||||
)
|
||||
|
||||
# This should fail with insufficient cash
|
||||
assert 'error' in result_buy, "Buy should fail without position injection"
|
||||
assert result_buy['error'] == 'Insufficient cash', f"Expected insufficient cash error, got: {result_buy}"
|
||||
assert result_buy['cash_available'] == 500.0, "Should see original cash, not updated cash"
|
||||
|
||||
Reference in New Issue
Block a user