mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-01 17:17:24 -04:00
test: update ContextInjector tests to match production MCP behavior
Update unit tests to mock CallToolResult objects instead of plain dicts, matching actual MCP tool behavior in production. Changes: - Add create_mcp_result() helper to create mock CallToolResult objects - Update all mock handlers to return MCP result objects - Update assertions to access result.structuredContent field - Maintains test coverage while accurately reflecting production behavior This ensures tests validate the actual code path used in production, where MCP tools return CallToolResult objects with structuredContent field containing the position dict.
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from agent.context_injector import ContextInjector
|
from agent.context_injector import ContextInjector
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -22,27 +23,34 @@ class MockRequest:
|
|||||||
self.args = args or {}
|
self.args = args or {}
|
||||||
|
|
||||||
|
|
||||||
|
def create_mcp_result(position_dict):
|
||||||
|
"""Create a mock MCP CallToolResult object matching production behavior."""
|
||||||
|
result = Mock()
|
||||||
|
result.structuredContent = position_dict
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def mock_handler_success(request):
|
async def mock_handler_success(request):
|
||||||
"""Mock handler that returns a successful position update."""
|
"""Mock handler that returns a successful position update as MCP CallToolResult."""
|
||||||
# Simulate a successful trade returning updated position
|
# Simulate a successful trade returning updated position
|
||||||
if request.name == "sell":
|
if request.name == "sell":
|
||||||
return {
|
return create_mcp_result({
|
||||||
"CASH": 1100.0,
|
"CASH": 1100.0,
|
||||||
"AAPL": 7,
|
"AAPL": 7,
|
||||||
"MSFT": 5
|
"MSFT": 5
|
||||||
}
|
})
|
||||||
elif request.name == "buy":
|
elif request.name == "buy":
|
||||||
return {
|
return create_mcp_result({
|
||||||
"CASH": 50.0,
|
"CASH": 50.0,
|
||||||
"AAPL": 7,
|
"AAPL": 7,
|
||||||
"MSFT": 12
|
"MSFT": 12
|
||||||
}
|
})
|
||||||
return {}
|
return create_mcp_result({})
|
||||||
|
|
||||||
|
|
||||||
async def mock_handler_error(request):
|
async def mock_handler_error(request):
|
||||||
"""Mock handler that returns an error."""
|
"""Mock handler that returns an error as MCP CallToolResult."""
|
||||||
return {"error": "Insufficient cash"}
|
return create_mcp_result({"error": "Insufficient cash"})
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -68,17 +76,17 @@ async def test_context_injector_injects_parameters(injector):
|
|||||||
"""Test that context parameters are injected into buy/sell requests."""
|
"""Test that context parameters are injected into buy/sell requests."""
|
||||||
request = MockRequest("buy", {"symbol": "AAPL", "amount": 10})
|
request = MockRequest("buy", {"symbol": "AAPL", "amount": 10})
|
||||||
|
|
||||||
# Mock handler that just returns the request args
|
# Mock handler that returns MCP result containing the request args
|
||||||
async def handler(req):
|
async def handler(req):
|
||||||
return req.args
|
return create_mcp_result(req.args)
|
||||||
|
|
||||||
result = await injector(request, handler)
|
result = await injector(request, handler)
|
||||||
|
|
||||||
# Verify context was injected
|
# Verify context was injected (result is MCP CallToolResult object)
|
||||||
assert result["signature"] == "test-model"
|
assert result.structuredContent["signature"] == "test-model"
|
||||||
assert result["today_date"] == "2025-01-15"
|
assert result.structuredContent["today_date"] == "2025-01-15"
|
||||||
assert result["job_id"] == "test-job-123"
|
assert result.structuredContent["job_id"] == "test-job-123"
|
||||||
assert result["trading_day_id"] == 1
|
assert result.structuredContent["trading_day_id"] == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -132,7 +140,7 @@ async def test_context_injector_does_not_update_position_on_error(injector):
|
|||||||
|
|
||||||
# Verify position was NOT updated
|
# Verify position was NOT updated
|
||||||
assert injector._current_position == original_position
|
assert injector._current_position == original_position
|
||||||
assert "error" in result
|
assert "error" in result.structuredContent
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -146,7 +154,7 @@ async def test_context_injector_does_not_inject_position_for_non_trade_tools(inj
|
|||||||
|
|
||||||
async def verify_no_injection_handler(req):
|
async def verify_no_injection_handler(req):
|
||||||
assert "_current_position" not in req.args
|
assert "_current_position" not in req.args
|
||||||
return {"results": []}
|
return create_mcp_result({"results": []})
|
||||||
|
|
||||||
await injector(request, verify_no_injection_handler)
|
await injector(request, verify_no_injection_handler)
|
||||||
|
|
||||||
@@ -164,7 +172,7 @@ async def test_context_injector_full_trading_session_simulation(injector):
|
|||||||
async def handler1(req):
|
async def handler1(req):
|
||||||
# First trade should NOT have injected position
|
# First trade should NOT have injected position
|
||||||
assert req.args.get("_current_position") is None
|
assert req.args.get("_current_position") is None
|
||||||
return {"CASH": 1100.0, "AAPL": 7}
|
return create_mcp_result({"CASH": 1100.0, "AAPL": 7})
|
||||||
|
|
||||||
result1 = await injector(request1, handler1)
|
result1 = await injector(request1, handler1)
|
||||||
assert injector._current_position == {"CASH": 1100.0, "AAPL": 7}
|
assert injector._current_position == {"CASH": 1100.0, "AAPL": 7}
|
||||||
@@ -176,7 +184,7 @@ async def test_context_injector_full_trading_session_simulation(injector):
|
|||||||
# Second trade SHOULD have injected position from trade 1
|
# Second trade SHOULD have injected position from trade 1
|
||||||
assert req.args["_current_position"]["CASH"] == 1100.0
|
assert req.args["_current_position"]["CASH"] == 1100.0
|
||||||
assert req.args["_current_position"]["AAPL"] == 7
|
assert req.args["_current_position"]["AAPL"] == 7
|
||||||
return {"CASH": 50.0, "AAPL": 7, "MSFT": 7}
|
return create_mcp_result({"CASH": 50.0, "AAPL": 7, "MSFT": 7})
|
||||||
|
|
||||||
result2 = await injector(request2, handler2)
|
result2 = await injector(request2, handler2)
|
||||||
assert injector._current_position == {"CASH": 50.0, "AAPL": 7, "MSFT": 7}
|
assert injector._current_position == {"CASH": 50.0, "AAPL": 7, "MSFT": 7}
|
||||||
@@ -185,7 +193,7 @@ async def test_context_injector_full_trading_session_simulation(injector):
|
|||||||
request3 = MockRequest("buy", {"symbol": "GOOGL", "amount": 100})
|
request3 = MockRequest("buy", {"symbol": "GOOGL", "amount": 100})
|
||||||
|
|
||||||
async def handler3(req):
|
async def handler3(req):
|
||||||
return {"error": "Insufficient cash", "cash_available": 50.0}
|
return create_mcp_result({"error": "Insufficient cash", "cash_available": 50.0})
|
||||||
|
|
||||||
result3 = await injector(request3, handler3)
|
result3 = await injector(request3, handler3)
|
||||||
# Position should remain unchanged after failed trade
|
# Position should remain unchanged after failed trade
|
||||||
|
|||||||
Reference in New Issue
Block a user