diff --git a/tests/unit/test_context_injector.py b/tests/unit/test_context_injector.py index f0e7d50..e0e63cd 100644 --- a/tests/unit/test_context_injector.py +++ b/tests/unit/test_context_injector.py @@ -2,6 +2,7 @@ import pytest from agent.context_injector import ContextInjector +from unittest.mock import Mock @pytest.fixture @@ -22,27 +23,34 @@ class MockRequest: self.args = args or {} +def create_mcp_result(position_dict): + """Create a mock MCP CallToolResult object matching production behavior.""" + result = Mock() + result.structuredContent = position_dict + return result + + async def mock_handler_success(request): - """Mock handler that returns a successful position update.""" + """Mock handler that returns a successful position update as MCP CallToolResult.""" # Simulate a successful trade returning updated position if request.name == "sell": - return { + return create_mcp_result({ "CASH": 1100.0, "AAPL": 7, "MSFT": 5 - } + }) elif request.name == "buy": - return { + return create_mcp_result({ "CASH": 50.0, "AAPL": 7, "MSFT": 12 - } - return {} + }) + return create_mcp_result({}) async def mock_handler_error(request): - """Mock handler that returns an error.""" - return {"error": "Insufficient cash"} + """Mock handler that returns an error as MCP CallToolResult.""" + return create_mcp_result({"error": "Insufficient cash"}) @pytest.mark.asyncio @@ -68,17 +76,17 @@ async def test_context_injector_injects_parameters(injector): """Test that context parameters are injected into buy/sell requests.""" request = MockRequest("buy", {"symbol": "AAPL", "amount": 10}) - # Mock handler that just returns the request args + # Mock handler that returns MCP result containing the request args async def handler(req): - return req.args + return create_mcp_result(req.args) result = await injector(request, handler) - # Verify context was injected - assert result["signature"] == "test-model" - assert result["today_date"] == "2025-01-15" - assert result["job_id"] == "test-job-123" - assert result["trading_day_id"] == 1 + # Verify context was injected (result is MCP CallToolResult object) + assert result.structuredContent["signature"] == "test-model" + assert result.structuredContent["today_date"] == "2025-01-15" + assert result.structuredContent["job_id"] == "test-job-123" + assert result.structuredContent["trading_day_id"] == 1 @pytest.mark.asyncio @@ -132,7 +140,7 @@ async def test_context_injector_does_not_update_position_on_error(injector): # Verify position was NOT updated assert injector._current_position == original_position - assert "error" in result + assert "error" in result.structuredContent @pytest.mark.asyncio @@ -146,7 +154,7 @@ async def test_context_injector_does_not_inject_position_for_non_trade_tools(inj async def verify_no_injection_handler(req): assert "_current_position" not in req.args - return {"results": []} + return create_mcp_result({"results": []}) await injector(request, verify_no_injection_handler) @@ -164,7 +172,7 @@ async def test_context_injector_full_trading_session_simulation(injector): async def handler1(req): # First trade should NOT have injected position assert req.args.get("_current_position") is None - return {"CASH": 1100.0, "AAPL": 7} + return create_mcp_result({"CASH": 1100.0, "AAPL": 7}) result1 = await injector(request1, handler1) assert injector._current_position == {"CASH": 1100.0, "AAPL": 7} @@ -176,7 +184,7 @@ async def test_context_injector_full_trading_session_simulation(injector): # Second trade SHOULD have injected position from trade 1 assert req.args["_current_position"]["CASH"] == 1100.0 assert req.args["_current_position"]["AAPL"] == 7 - return {"CASH": 50.0, "AAPL": 7, "MSFT": 7} + return create_mcp_result({"CASH": 50.0, "AAPL": 7, "MSFT": 7}) result2 = await injector(request2, handler2) assert injector._current_position == {"CASH": 50.0, "AAPL": 7, "MSFT": 7} @@ -185,7 +193,7 @@ async def test_context_injector_full_trading_session_simulation(injector): request3 = MockRequest("buy", {"symbol": "GOOGL", "amount": 100}) async def handler3(req): - return {"error": "Insufficient cash", "cash_available": 50.0} + return create_mcp_result({"error": "Insufficient cash", "cash_available": 50.0}) result3 = await injector(request3, handler3) # Position should remain unchanged after failed trade