mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-02 01:27:24 -04:00
Major improvements:
- Fixed all 42 broken tests (database connection leaks)
- Added db_connection() context manager for proper cleanup
- Created comprehensive test suites for undertested modules

New test coverage:
- tools/general_tools.py: 26 tests (97% coverage)
- tools/price_tools.py: 11 tests (validates NASDAQ symbols, date handling)
- api/price_data_manager.py: 12 tests (85% coverage)
- api/routes/results_v2.py: 3 tests (98% coverage)
- agent/reasoning_summarizer.py: 2 tests (87% coverage)
- api/routes/period_metrics.py: 2 edge case tests (100% coverage)
- agent/mock_provider: 1 test (100% coverage)

Database fixes:
- Added db_connection() context manager to prevent leaks
- Updated 16+ test files to use context managers
- Fixed drop_all_tables() to match new schema
- Added CHECK constraint for action_type
- Added ON DELETE CASCADE to trading_days foreign key

Test improvements:
- Updated SQL INSERT statements with all required fields
- Fixed date parameter handling in API integration tests
- Added edge case tests for validation functions
- Fixed import errors across test suite

Results:
- Total coverage: 84.81% (was 61%)
- Tests passing: 406 (was 364 with 42 failures)
- Total lines covered: 6364 of 7504

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
242 lines
8.2 KiB
Python
242 lines
8.2 KiB
Python
"""Test ContextInjector position tracking functionality."""
|
|
|
|
import pytest
|
|
from agent.context_injector import ContextInjector
|
|
from unittest.mock import Mock
|
|
|
|
|
|
@pytest.fixture
def injector():
    """Build a ContextInjector for a single test.

    The fixture has no explicit scope, so each test that requests it gets a
    new instance. The instance is bound to signature "test-model", date
    2025-01-15, job "test-job-123" and trading day id 1.
    """
    context = dict(
        signature="test-model",
        today_date="2025-01-15",
        job_id="test-job-123",
        trading_day_id=1,
    )
    return ContextInjector(**context)
|
|
|
|
|
|
class MockRequest:
    """Lightweight stand-in for an MCP tool request.

    It carries only the tool ``name`` and its ``args`` mapping, which are the
    two attributes these tests read.
    """

    def __init__(self, name, args=None):
        # A falsy args value (None or an empty mapping) is swapped for a
        # new empty dict, so instances never share a mutable default.
        self.args = args if args else {}
        self.name = name
|
|
|
|
|
|
def create_mcp_result(position_dict):
    """Wrap *position_dict* in a Mock shaped like an MCP CallToolResult.

    The payload is exposed as ``structuredContent``. The dict is attached
    as-is rather than copied, so the caller and the mock share one object.
    """
    return Mock(structuredContent=position_dict)
|
|
|
|
|
|
async def mock_handler_success(request):
    """Fake MCP handler that reports a successful trade.

    A "sell" request yields a position with 1100.0 cash, 7 AAPL and 5 MSFT.
    A "buy" request yields 50.0 cash, 7 AAPL and 12 MSFT. Any other tool name
    yields an empty payload. The result is wrapped with create_mcp_result to
    mimic a production CallToolResult.
    """
    tool = request.name
    # A new dict literal is built on every call, so one caller mutating the
    # returned position cannot leak into later calls.
    if tool == "sell":
        position = {"CASH": 1100.0, "AAPL": 7, "MSFT": 5}
    elif tool == "buy":
        position = {"CASH": 50.0, "AAPL": 7, "MSFT": 12}
    else:
        position = {}
    return create_mcp_result(position)
|
|
|
|
|
|
async def mock_handler_error(request):
    """Fake MCP handler that always reports a failed trade.

    The request is ignored. The wrapped payload has only an ``error`` key,
    which is the shape these tests use for a rejected trade.
    """
    failure = {"error": "Insufficient cash"}
    return create_mcp_result(failure)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_context_injector_initializes_with_no_position(injector):
    """A newly constructed ContextInjector does not track a position yet."""
    tracked = injector._current_position
    assert tracked is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_context_injector_reset_position(injector):
    """reset_position() discards whatever position was being tracked."""
    # Seed a non-empty position so the reset has something to clear.
    seeded = {"CASH": 5000.0, "AAPL": 10}
    injector._current_position = seeded

    injector.reset_position()

    assert injector._current_position is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_context_injector_injects_parameters(injector):
    """Trade requests reach the handler with the injector's context merged in."""
    request = MockRequest("buy", {"symbol": "AAPL", "amount": 10})

    async def echo_handler(req):
        # Echo the args, as augmented by the injector, back as the MCP payload.
        return create_mcp_result(req.args)

    result = await injector(request, echo_handler)

    payload = result.structuredContent
    expected_context = {
        "signature": "test-model",
        "today_date": "2025-01-15",
        "job_id": "test-job-123",
        "trading_day_id": 1,
    }
    # Check each key in order; dicts preserve insertion order.
    for key, value in expected_context.items():
        assert payload[key] == value
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_context_injector_tracks_position_after_successful_trade(injector):
    """Test that position state is updated after successful trades.

    The injector starts with no position. After a successful sell, the
    complete position reported by the handler becomes the tracked position.
    """
    assert injector._current_position is None

    # Execute a sell trade. The call's return value is not needed here.
    request = MockRequest("sell", {"symbol": "AAPL", "amount": 3})
    await injector(request, mock_handler_success)

    # Verify the tracked position is exactly what the sell handler reported.
    # Comparing the whole dict also covers MSFT and catches extra keys.
    assert injector._current_position == {"CASH": 1100.0, "AAPL": 7, "MSFT": 5}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_context_injector_injects_session_id():
    """Test that session_id is injected when provided.

    The handler only records the session_id it received; the assertion runs
    after the call. This way the test fails if the handler never ran, or if
    the injector caught an AssertionError raised inside the handler.
    """
    injector = ContextInjector(
        signature="test-sig",
        today_date="2025-01-15",
        session_id="test-session-123"
    )

    request = MockRequest("buy", {"symbol": "AAPL", "amount": 5})
    seen_session_ids = []

    async def capturing_handler(req):
        # None stands for "key missing", which the final check rejects.
        seen_session_ids.append(req.args.get("session_id"))
        return create_mcp_result({"CASH": 100.0})

    await injector(request, capturing_handler)

    # Exactly one handler call, and it carried the injected session_id.
    assert seen_session_ids == ["test-session-123"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_context_injector_handles_dict_result():
    """Test handling when handler returns a plain dict instead of CallToolResult.

    The position must still be tracked when the payload arrives as a bare
    dict rather than through ``structuredContent``.
    """
    injector = ContextInjector(
        signature="test-sig",
        today_date="2025-01-15"
    )

    request = MockRequest("buy", {"symbol": "AAPL", "amount": 5})

    async def dict_handler(req):
        # Return plain dict instead of CallToolResult
        return {"CASH": 500.0, "AAPL": 10}

    # Only the side effect on the tracked position matters here.
    await injector(request, dict_handler)

    # Verify position was still updated
    assert injector._current_position is not None
    assert injector._current_position["CASH"] == 500.0
    assert injector._current_position["AAPL"] == 10
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_context_injector_injects_current_position_on_subsequent_trades(injector):
    """Test that current position is injected into subsequent trade requests.

    After a successful sell, the next buy request must carry the position the
    sell produced under ``_current_position``. The injected position is
    snapshotted inside the handler and asserted after the call, so a skipped
    handler, or an AssertionError caught by the injector, cannot let the test
    pass silently.
    """
    # First trade - establish position
    request1 = MockRequest("sell", {"symbol": "AAPL", "amount": 3})
    await injector(request1, mock_handler_success)

    # Second trade - should receive current position
    request2 = MockRequest("buy", {"symbol": "MSFT", "amount": 7})
    injected_positions = []

    async def verify_injection_handler(req):
        position = req.args.get("_current_position")
        # Snapshot at call time in case the injector later updates its
        # position in place.
        injected_positions.append(None if position is None else dict(position))
        # mock_handler_success is a coroutine function. Await it so the
        # injector receives a CallToolResult rather than a bare coroutine.
        return await mock_handler_success(req)

    await injector(request2, verify_injection_handler)

    assert len(injected_positions) == 1
    injected = injected_positions[0]
    assert injected is not None
    assert injected["CASH"] == 1100.0
    assert injected["AAPL"] == 7
    # With the handler awaited, the buy result becomes the tracked position.
    assert injector._current_position == {"CASH": 50.0, "AAPL": 7, "MSFT": 12}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_context_injector_does_not_update_position_on_error(injector):
    """A failed trade leaves the tracked position exactly as it was."""
    # Establish a known position with a successful sell first.
    await injector(
        MockRequest("sell", {"symbol": "AAPL", "amount": 3}),
        mock_handler_success,
    )
    snapshot = injector._current_position.copy()

    # Then attempt a buy that the handler rejects.
    failed = await injector(
        MockRequest("buy", {"symbol": "MSFT", "amount": 100}),
        mock_handler_error,
    )

    assert injector._current_position == snapshot
    assert "error" in failed.structuredContent
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_context_injector_does_not_inject_position_for_non_trade_tools(injector):
    """Test that position is not injected for non-buy/sell tools.

    The handler records whether ``_current_position`` was present, and the
    check runs after the call. A skipped handler, or an AssertionError caught
    by the injector, therefore fails the test instead of passing it.
    """
    # Set up position state
    injector._current_position = {"CASH": 5000.0, "AAPL": 10}

    # Call a non-trade tool
    request = MockRequest("search", {"query": "market news"})
    saw_position = []

    async def verify_no_injection_handler(req):
        saw_position.append("_current_position" in req.args)
        return create_mcp_result({"results": []})

    await injector(request, verify_no_injection_handler)

    # Exactly one handler call, and it did not carry a position.
    assert saw_position == [False]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_context_injector_full_trading_session_simulation(injector):
    """Test full trading session with multiple trades and position tracking.

    The session runs as follows:
    - Reset the position.
    - First trade: must see no injected position.
    - Second trade: must see the position produced by the first trade.
    - Third trade: fails and must leave the tracked position unchanged.

    Injected positions are snapshotted inside the handlers and asserted after
    each call. A skipped handler, or an AssertionError caught by the injector,
    therefore cannot let the test pass silently.
    """
    # One entry per recorded handler call: a copy of the injected position,
    # or None when no position was injected.
    injected_positions = []

    def record_injected(req):
        position = req.args.get("_current_position")
        injected_positions.append(None if position is None else dict(position))

    # Reset position at start of day
    injector.reset_position()
    assert injector._current_position is None

    # Trade 1: Sell AAPL
    request1 = MockRequest("sell", {"symbol": "AAPL", "amount": 3})

    async def handler1(req):
        record_injected(req)
        return create_mcp_result({"CASH": 1100.0, "AAPL": 7})

    await injector(request1, handler1)
    # First trade should NOT have injected position
    assert injected_positions == [None]
    assert injector._current_position == {"CASH": 1100.0, "AAPL": 7}

    # Trade 2: Buy MSFT (should use position from trade 1)
    request2 = MockRequest("buy", {"symbol": "MSFT", "amount": 7})

    async def handler2(req):
        record_injected(req)
        return create_mcp_result({"CASH": 50.0, "AAPL": 7, "MSFT": 7})

    await injector(request2, handler2)
    # Second trade SHOULD have injected position from trade 1
    assert len(injected_positions) == 2
    second = injected_positions[1]
    assert second is not None
    assert second["CASH"] == 1100.0
    assert second["AAPL"] == 7
    assert injector._current_position == {"CASH": 50.0, "AAPL": 7, "MSFT": 7}

    # Trade 3: Failed trade (should not update position)
    request3 = MockRequest("buy", {"symbol": "GOOGL", "amount": 100})

    async def handler3(req):
        return create_mcp_result({"error": "Insufficient cash", "cash_available": 50.0})

    await injector(request3, handler3)
    # Position should remain unchanged after failed trade
    assert injector._current_position == {"CASH": 50.0, "AAPL": 7, "MSFT": 7}
|