mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-02 01:27:24 -04:00
Major improvements: - Fixed all 42 broken tests (database connection leaks) - Added db_connection() context manager for proper cleanup - Created comprehensive test suites for undertested modules New test coverage: - tools/general_tools.py: 26 tests (97% coverage) - tools/price_tools.py: 11 tests (validates NASDAQ symbols, date handling) - api/price_data_manager.py: 12 tests (85% coverage) - api/routes/results_v2.py: 3 tests (98% coverage) - agent/reasoning_summarizer.py: 2 tests (87% coverage) - api/routes/period_metrics.py: 2 edge case tests (100% coverage) - agent/mock_provider: 1 test (100% coverage) Database fixes: - Added db_connection() context manager to prevent leaks - Updated 16+ test files to use context managers - Fixed drop_all_tables() to match new schema - Added CHECK constraint for action_type - Added ON DELETE CASCADE to trading_days foreign key Test improvements: - Updated SQL INSERT statements with all required fields - Fixed date parameter handling in API integration tests - Added edge case tests for validation functions - Fixed import errors across test suite Results: - Total coverage: 84.81% (was 61%) - Tests passing: 406 (was 364 with 42 failures) - Total lines covered: 6364 of 7504 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
218 lines
7.5 KiB
Python
218 lines
7.5 KiB
Python
"""
|
|
Unit tests for ChatModelWrapper - tool_calls args parsing fix
|
|
"""
|
|
|
|
import json
|
|
import pytest
|
|
from unittest.mock import Mock, AsyncMock
|
|
from langchain_core.messages import AIMessage
|
|
from langchain_core.outputs import ChatResult, ChatGeneration
|
|
|
|
from agent.chat_model_wrapper import ToolCallArgsParsingWrapper
|
|
|
|
|
|
@pytest.mark.skip(reason="API changed - wrapper now uses internal LangChain patching, tests need redesign")
|
|
class TestToolCallArgsParsingWrapper:
|
|
"""Tests for ToolCallArgsParsingWrapper"""
|
|
|
|
@pytest.fixture
|
|
def mock_model(self):
|
|
"""Create a mock chat model"""
|
|
model = Mock()
|
|
model._llm_type = "mock-model"
|
|
return model
|
|
|
|
@pytest.fixture
|
|
def wrapper(self, mock_model):
|
|
"""Create a wrapper around mock model"""
|
|
return ToolCallArgsParsingWrapper(model=mock_model)
|
|
|
|
def test_fix_tool_calls_with_string_args(self, wrapper):
|
|
"""Test that string args are parsed to dict"""
|
|
# Create message with tool_calls where args is a JSON string
|
|
message = AIMessage(
|
|
content="",
|
|
tool_calls=[
|
|
{
|
|
"name": "buy",
|
|
"args": '{"symbol": "AAPL", "amount": 10}', # String, not dict
|
|
"id": "call_123"
|
|
}
|
|
]
|
|
)
|
|
|
|
fixed_message = wrapper._fix_tool_calls(message)
|
|
|
|
# Check that args is now a dict
|
|
assert isinstance(fixed_message.tool_calls[0]['args'], dict)
|
|
assert fixed_message.tool_calls[0]['args'] == {"symbol": "AAPL", "amount": 10}
|
|
|
|
def test_fix_tool_calls_with_dict_args(self, wrapper):
|
|
"""Test that dict args are left unchanged"""
|
|
# Create message with tool_calls where args is already a dict
|
|
message = AIMessage(
|
|
content="",
|
|
tool_calls=[
|
|
{
|
|
"name": "buy",
|
|
"args": {"symbol": "AAPL", "amount": 10}, # Already a dict
|
|
"id": "call_123"
|
|
}
|
|
]
|
|
)
|
|
|
|
fixed_message = wrapper._fix_tool_calls(message)
|
|
|
|
# Check that args is still a dict
|
|
assert isinstance(fixed_message.tool_calls[0]['args'], dict)
|
|
assert fixed_message.tool_calls[0]['args'] == {"symbol": "AAPL", "amount": 10}
|
|
|
|
def test_fix_tool_calls_with_invalid_json(self, wrapper):
|
|
"""Test that invalid JSON string is left unchanged"""
|
|
# Create message with tool_calls where args is an invalid JSON string
|
|
message = AIMessage(
|
|
content="",
|
|
tool_calls=[
|
|
{
|
|
"name": "buy",
|
|
"args": 'invalid json {', # Invalid JSON
|
|
"id": "call_123"
|
|
}
|
|
]
|
|
)
|
|
|
|
fixed_message = wrapper._fix_tool_calls(message)
|
|
|
|
# Check that args is still a string (parsing failed)
|
|
assert isinstance(fixed_message.tool_calls[0]['args'], str)
|
|
assert fixed_message.tool_calls[0]['args'] == 'invalid json {'
|
|
|
|
def test_fix_tool_calls_no_tool_calls(self, wrapper):
|
|
"""Test that messages without tool_calls are left unchanged"""
|
|
message = AIMessage(content="Hello, world!")
|
|
fixed_message = wrapper._fix_tool_calls(message)
|
|
|
|
assert fixed_message == message
|
|
|
|
def test_generate_with_string_args(self, wrapper, mock_model):
|
|
"""Test _generate method with string args"""
|
|
# Create a response with string args
|
|
original_message = AIMessage(
|
|
content="",
|
|
tool_calls=[
|
|
{
|
|
"name": "buy",
|
|
"args": '{"symbol": "MSFT", "amount": 5}',
|
|
"id": "call_456"
|
|
}
|
|
]
|
|
)
|
|
|
|
mock_result = ChatResult(
|
|
generations=[ChatGeneration(message=original_message)]
|
|
)
|
|
mock_model._generate.return_value = mock_result
|
|
|
|
# Call wrapper's _generate
|
|
result = wrapper._generate(messages=[], stop=None, run_manager=None)
|
|
|
|
# Check that args is now a dict
|
|
fixed_message = result.generations[0].message
|
|
assert isinstance(fixed_message.tool_calls[0]['args'], dict)
|
|
assert fixed_message.tool_calls[0]['args'] == {"symbol": "MSFT", "amount": 5}
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_agenerate_with_string_args(self, wrapper, mock_model):
|
|
"""Test _agenerate method with string args"""
|
|
# Create a response with string args
|
|
original_message = AIMessage(
|
|
content="",
|
|
tool_calls=[
|
|
{
|
|
"name": "sell",
|
|
"args": '{"symbol": "GOOGL", "amount": 3}',
|
|
"id": "call_789"
|
|
}
|
|
]
|
|
)
|
|
|
|
mock_result = ChatResult(
|
|
generations=[ChatGeneration(message=original_message)]
|
|
)
|
|
mock_model._agenerate = AsyncMock(return_value=mock_result)
|
|
|
|
# Call wrapper's _agenerate
|
|
result = await wrapper._agenerate(messages=[], stop=None, run_manager=None)
|
|
|
|
# Check that args is now a dict
|
|
fixed_message = result.generations[0].message
|
|
assert isinstance(fixed_message.tool_calls[0]['args'], dict)
|
|
assert fixed_message.tool_calls[0]['args'] == {"symbol": "GOOGL", "amount": 3}
|
|
|
|
def test_invoke_with_string_args(self, wrapper, mock_model):
|
|
"""Test invoke method with string args"""
|
|
original_message = AIMessage(
|
|
content="",
|
|
tool_calls=[
|
|
{
|
|
"name": "buy",
|
|
"args": '{"symbol": "NVDA", "amount": 20}',
|
|
"id": "call_999"
|
|
}
|
|
]
|
|
)
|
|
|
|
mock_model.invoke.return_value = original_message
|
|
|
|
# Call wrapper's invoke
|
|
result = wrapper.invoke(input=[])
|
|
|
|
# Check that args is now a dict
|
|
assert isinstance(result.tool_calls[0]['args'], dict)
|
|
assert result.tool_calls[0]['args'] == {"symbol": "NVDA", "amount": 20}
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_ainvoke_with_string_args(self, wrapper, mock_model):
|
|
"""Test ainvoke method with string args"""
|
|
original_message = AIMessage(
|
|
content="",
|
|
tool_calls=[
|
|
{
|
|
"name": "sell",
|
|
"args": '{"symbol": "TSLA", "amount": 15}',
|
|
"id": "call_111"
|
|
}
|
|
]
|
|
)
|
|
|
|
mock_model.ainvoke = AsyncMock(return_value=original_message)
|
|
|
|
# Call wrapper's ainvoke
|
|
result = await wrapper.ainvoke(input=[])
|
|
|
|
# Check that args is now a dict
|
|
assert isinstance(result.tool_calls[0]['args'], dict)
|
|
assert result.tool_calls[0]['args'] == {"symbol": "TSLA", "amount": 15}
|
|
|
|
def test_bind_tools_returns_wrapper(self, wrapper, mock_model):
|
|
"""Test that bind_tools returns a new wrapper"""
|
|
mock_bound = Mock()
|
|
mock_model.bind_tools.return_value = mock_bound
|
|
|
|
result = wrapper.bind_tools(tools=[], strict=True)
|
|
|
|
# Check that result is a wrapper around the bound model
|
|
assert isinstance(result, ToolCallArgsParsingWrapper)
|
|
assert result.wrapped_model == mock_bound
|
|
|
|
def test_bind_returns_wrapper(self, wrapper, mock_model):
|
|
"""Test that bind returns a new wrapper"""
|
|
mock_bound = Mock()
|
|
mock_model.bind.return_value = mock_bound
|
|
|
|
result = wrapper.bind(max_tokens=100)
|
|
|
|
# Check that result is a wrapper around the bound model
|
|
assert isinstance(result, ToolCallArgsParsingWrapper)
|
|
assert result.wrapped_model == mock_bound
|