mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-06 18:57:25 -04:00
chore: remove obsolete chat model wrapper
The ToolCallArgsParsingWrapper was replaced by the model factory approach which uses native ChatDeepSeek for DeepSeek models. Removed: - agent/chat_model_wrapper.py (no longer used) - tests/unit/test_chat_model_wrapper.py (tests for removed functionality) The model factory provides a cleaner solution by using provider-specific implementations rather than wrapping ChatOpenAI.
This commit is contained in:
@@ -1,216 +0,0 @@
|
||||
"""
|
||||
Unit tests for ChatModelWrapper - tool_calls args parsing fix
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import ChatResult, ChatGeneration
|
||||
|
||||
from agent.chat_model_wrapper import ToolCallArgsParsingWrapper
|
||||
|
||||
|
||||
class TestToolCallArgsParsingWrapper:
|
||||
"""Tests for ToolCallArgsParsingWrapper"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model(self):
|
||||
"""Create a mock chat model"""
|
||||
model = Mock()
|
||||
model._llm_type = "mock-model"
|
||||
return model
|
||||
|
||||
@pytest.fixture
|
||||
def wrapper(self, mock_model):
|
||||
"""Create a wrapper around mock model"""
|
||||
return ToolCallArgsParsingWrapper(model=mock_model)
|
||||
|
||||
def test_fix_tool_calls_with_string_args(self, wrapper):
|
||||
"""Test that string args are parsed to dict"""
|
||||
# Create message with tool_calls where args is a JSON string
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "buy",
|
||||
"args": '{"symbol": "AAPL", "amount": 10}', # String, not dict
|
||||
"id": "call_123"
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
fixed_message = wrapper._fix_tool_calls(message)
|
||||
|
||||
# Check that args is now a dict
|
||||
assert isinstance(fixed_message.tool_calls[0]['args'], dict)
|
||||
assert fixed_message.tool_calls[0]['args'] == {"symbol": "AAPL", "amount": 10}
|
||||
|
||||
def test_fix_tool_calls_with_dict_args(self, wrapper):
|
||||
"""Test that dict args are left unchanged"""
|
||||
# Create message with tool_calls where args is already a dict
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "buy",
|
||||
"args": {"symbol": "AAPL", "amount": 10}, # Already a dict
|
||||
"id": "call_123"
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
fixed_message = wrapper._fix_tool_calls(message)
|
||||
|
||||
# Check that args is still a dict
|
||||
assert isinstance(fixed_message.tool_calls[0]['args'], dict)
|
||||
assert fixed_message.tool_calls[0]['args'] == {"symbol": "AAPL", "amount": 10}
|
||||
|
||||
def test_fix_tool_calls_with_invalid_json(self, wrapper):
|
||||
"""Test that invalid JSON string is left unchanged"""
|
||||
# Create message with tool_calls where args is an invalid JSON string
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "buy",
|
||||
"args": 'invalid json {', # Invalid JSON
|
||||
"id": "call_123"
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
fixed_message = wrapper._fix_tool_calls(message)
|
||||
|
||||
# Check that args is still a string (parsing failed)
|
||||
assert isinstance(fixed_message.tool_calls[0]['args'], str)
|
||||
assert fixed_message.tool_calls[0]['args'] == 'invalid json {'
|
||||
|
||||
def test_fix_tool_calls_no_tool_calls(self, wrapper):
|
||||
"""Test that messages without tool_calls are left unchanged"""
|
||||
message = AIMessage(content="Hello, world!")
|
||||
fixed_message = wrapper._fix_tool_calls(message)
|
||||
|
||||
assert fixed_message == message
|
||||
|
||||
def test_generate_with_string_args(self, wrapper, mock_model):
|
||||
"""Test _generate method with string args"""
|
||||
# Create a response with string args
|
||||
original_message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "buy",
|
||||
"args": '{"symbol": "MSFT", "amount": 5}',
|
||||
"id": "call_456"
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
mock_result = ChatResult(
|
||||
generations=[ChatGeneration(message=original_message)]
|
||||
)
|
||||
mock_model._generate.return_value = mock_result
|
||||
|
||||
# Call wrapper's _generate
|
||||
result = wrapper._generate(messages=[], stop=None, run_manager=None)
|
||||
|
||||
# Check that args is now a dict
|
||||
fixed_message = result.generations[0].message
|
||||
assert isinstance(fixed_message.tool_calls[0]['args'], dict)
|
||||
assert fixed_message.tool_calls[0]['args'] == {"symbol": "MSFT", "amount": 5}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agenerate_with_string_args(self, wrapper, mock_model):
|
||||
"""Test _agenerate method with string args"""
|
||||
# Create a response with string args
|
||||
original_message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "sell",
|
||||
"args": '{"symbol": "GOOGL", "amount": 3}',
|
||||
"id": "call_789"
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
mock_result = ChatResult(
|
||||
generations=[ChatGeneration(message=original_message)]
|
||||
)
|
||||
mock_model._agenerate = AsyncMock(return_value=mock_result)
|
||||
|
||||
# Call wrapper's _agenerate
|
||||
result = await wrapper._agenerate(messages=[], stop=None, run_manager=None)
|
||||
|
||||
# Check that args is now a dict
|
||||
fixed_message = result.generations[0].message
|
||||
assert isinstance(fixed_message.tool_calls[0]['args'], dict)
|
||||
assert fixed_message.tool_calls[0]['args'] == {"symbol": "GOOGL", "amount": 3}
|
||||
|
||||
def test_invoke_with_string_args(self, wrapper, mock_model):
|
||||
"""Test invoke method with string args"""
|
||||
original_message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "buy",
|
||||
"args": '{"symbol": "NVDA", "amount": 20}',
|
||||
"id": "call_999"
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
mock_model.invoke.return_value = original_message
|
||||
|
||||
# Call wrapper's invoke
|
||||
result = wrapper.invoke(input=[])
|
||||
|
||||
# Check that args is now a dict
|
||||
assert isinstance(result.tool_calls[0]['args'], dict)
|
||||
assert result.tool_calls[0]['args'] == {"symbol": "NVDA", "amount": 20}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ainvoke_with_string_args(self, wrapper, mock_model):
|
||||
"""Test ainvoke method with string args"""
|
||||
original_message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "sell",
|
||||
"args": '{"symbol": "TSLA", "amount": 15}',
|
||||
"id": "call_111"
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
mock_model.ainvoke = AsyncMock(return_value=original_message)
|
||||
|
||||
# Call wrapper's ainvoke
|
||||
result = await wrapper.ainvoke(input=[])
|
||||
|
||||
# Check that args is now a dict
|
||||
assert isinstance(result.tool_calls[0]['args'], dict)
|
||||
assert result.tool_calls[0]['args'] == {"symbol": "TSLA", "amount": 15}
|
||||
|
||||
def test_bind_tools_returns_wrapper(self, wrapper, mock_model):
|
||||
"""Test that bind_tools returns a new wrapper"""
|
||||
mock_bound = Mock()
|
||||
mock_model.bind_tools.return_value = mock_bound
|
||||
|
||||
result = wrapper.bind_tools(tools=[], strict=True)
|
||||
|
||||
# Check that result is a wrapper around the bound model
|
||||
assert isinstance(result, ToolCallArgsParsingWrapper)
|
||||
assert result.wrapped_model == mock_bound
|
||||
|
||||
def test_bind_returns_wrapper(self, wrapper, mock_model):
|
||||
"""Test that bind returns a new wrapper"""
|
||||
mock_bound = Mock()
|
||||
mock_model.bind.return_value = mock_bound
|
||||
|
||||
result = wrapper.bind(max_tokens=100)
|
||||
|
||||
# Check that result is a wrapper around the bound model
|
||||
assert isinstance(result, ToolCallArgsParsingWrapper)
|
||||
assert result.wrapped_model == mock_bound
|
||||
Reference in New Issue
Block a user