Files
AI-Trader/tests/unit/test_general_tools.py
Bill 14cf88f642 test: improve test coverage from 61% to 84.81%
Major improvements:
- Fixed all 42 broken tests (database connection leaks)
- Added db_connection() context manager for proper cleanup
- Created comprehensive test suites for undertested modules

New test coverage:
- tools/general_tools.py: 26 tests (97% coverage)
- tools/price_tools.py: 11 tests (validates NASDAQ symbols, date handling)
- api/price_data_manager.py: 12 tests (85% coverage)
- api/routes/results_v2.py: 3 tests (98% coverage)
- agent/reasoning_summarizer.py: 2 tests (87% coverage)
- api/routes/period_metrics.py: 2 edge case tests (100% coverage)
- agent/mock_provider: 1 test (100% coverage)

Database fixes:
- Added db_connection() context manager to prevent leaks
- Updated 16+ test files to use context managers
- Fixed drop_all_tables() to match new schema
- Added CHECK constraint for action_type
- Added ON DELETE CASCADE to trading_days foreign key

Test improvements:
- Updated SQL INSERT statements with all required fields
- Fixed date parameter handling in API integration tests
- Added edge case tests for validation functions
- Fixed import errors across test suite

Results:
- Total coverage: 84.81% (was 61%)
- Tests passing: 406 (was 364 with 42 failures)
- Total lines covered: 6364 of 7504

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-07 21:02:38 -05:00

329 lines
11 KiB
Python

"""Unit tests for tools/general_tools.py"""
import pytest
import os
import json
import tempfile
from pathlib import Path
from tools.general_tools import (
get_config_value,
write_config_value,
extract_conversation,
extract_tool_messages,
extract_first_tool_message_content
)
@pytest.fixture
def temp_runtime_env(tmp_path, monkeypatch):
    """Point ``RUNTIME_ENV_PATH`` at a ``runtime_env.json`` path inside ``tmp_path``.

    The file itself is not created here; tests either write it themselves or
    let the code under test create it.

    Args:
        tmp_path: pytest's per-test temporary directory.
        monkeypatch: pytest's monkeypatch fixture, used so the previous value
            of ``RUNTIME_ENV_PATH`` (unset, empty, or non-empty) is restored
            exactly on teardown, even when the test fails.

    Returns:
        pathlib.Path: The runtime env file path the variable now points to.
    """
    env_file = tmp_path / "runtime_env.json"
    monkeypatch.setenv("RUNTIME_ENV_PATH", str(env_file))
    return env_file
@pytest.mark.unit
class TestConfigManagement:
    """Test configuration value reading and writing.

    Environment variables are changed only through pytest's ``monkeypatch``
    fixture, so their previous values are restored automatically even when an
    assertion fails partway through a test.
    """

    def test_get_config_value_from_env(self, monkeypatch):
        """Should read from environment variables."""
        monkeypatch.setenv("TEST_KEY", "test_value")
        result = get_config_value("TEST_KEY")
        assert result == "test_value"

    def test_get_config_value_default(self, monkeypatch):
        """Should return default when key not found."""
        # Guarantee the precondition regardless of the developer's shell env.
        monkeypatch.delenv("NONEXISTENT_KEY", raising=False)
        result = get_config_value("NONEXISTENT_KEY", "default_value")
        assert result == "default_value"

    def test_get_config_value_from_runtime_env(self, temp_runtime_env):
        """Should read from runtime env file."""
        temp_runtime_env.write_text('{"RUNTIME_KEY": "runtime_value"}')
        result = get_config_value("RUNTIME_KEY")
        assert result == "runtime_value"

    def test_get_config_value_runtime_overrides_env(self, temp_runtime_env, monkeypatch):
        """Runtime env should override environment variables."""
        monkeypatch.setenv("OVERRIDE_KEY", "env_value")
        temp_runtime_env.write_text('{"OVERRIDE_KEY": "runtime_value"}')
        result = get_config_value("OVERRIDE_KEY")
        assert result == "runtime_value"

    def test_write_config_value_creates_file(self, temp_runtime_env):
        """Should create runtime env file if it doesn't exist."""
        write_config_value("NEW_KEY", "new_value")
        assert temp_runtime_env.exists()
        data = json.loads(temp_runtime_env.read_text())
        assert data["NEW_KEY"] == "new_value"

    def test_write_config_value_updates_existing(self, temp_runtime_env):
        """Should update existing values in runtime env."""
        temp_runtime_env.write_text('{"EXISTING": "old"}')
        write_config_value("EXISTING", "new")
        write_config_value("ANOTHER", "value")
        data = json.loads(temp_runtime_env.read_text())
        assert data["EXISTING"] == "new"
        assert data["ANOTHER"] == "value"

    def test_write_config_value_no_path_set(self, capsys, monkeypatch):
        """Should warn when RUNTIME_ENV_PATH not set."""
        # delenv records any existing value and restores it on teardown, so
        # later tests still see the session's RUNTIME_ENV_PATH.
        monkeypatch.delenv("RUNTIME_ENV_PATH", raising=False)
        write_config_value("TEST", "value")
        captured = capsys.readouterr()
        assert "WARNING" in captured.out
        assert "RUNTIME_ENV_PATH not set" in captured.out
@pytest.mark.unit
class TestExtractConversation:
    """Exercise ``extract_conversation`` in its "final" and "all" modes."""

    def test_extract_conversation_final_with_stop(self):
        """The last message finishing with 'stop' wins in "final" mode."""
        msgs = [
            dict(content="Hello", response_metadata=dict(finish_reason="stop")),
            dict(content="World", response_metadata=dict(finish_reason="stop")),
        ]
        picked = extract_conversation({"messages": msgs}, "final")
        assert picked == "World"

    def test_extract_conversation_final_fallback(self):
        """Without a 'stop' marker, the last non-tool-call message is used."""
        tool_call_msg = {"content": "", "additional_kwargs": {"tool_calls": [{"name": "tool"}]}}
        msgs = [{"content": "First message"}, {"content": "Second message"}, tool_call_msg]
        picked = extract_conversation({"messages": msgs}, "final")
        assert picked == "Second message"

    def test_extract_conversation_final_no_messages(self):
        """An empty message list yields None in "final" mode."""
        picked = extract_conversation({"messages": []}, "final")
        assert picked is None

    def test_extract_conversation_final_only_tool_calls(self):
        """Tool-result messages alone never count as a final answer."""
        msgs = [{"content": "tool result", "tool_call_id": "123"}]
        picked = extract_conversation({"messages": msgs}, "final")
        assert picked is None

    def test_extract_conversation_all(self):
        """"all" mode hands back the full message list."""
        expected = [{"content": "Message 1"}, {"content": "Message 2"}]
        assert extract_conversation({"messages": expected}, "all") == expected

    def test_extract_conversation_invalid_type(self):
        """Any output_type other than "final"/"all" is rejected."""
        convo = {"messages": []}
        with pytest.raises(ValueError, match="output_type must be 'final' or 'all'"):
            extract_conversation(convo, "invalid")

    def test_extract_conversation_missing_messages(self):
        """A conversation without a "messages" key is treated as empty."""
        convo = {}
        assert extract_conversation(convo, "all") == []
        assert extract_conversation(convo, "final") is None
@pytest.mark.unit
class TestExtractToolMessages:
    """Exercise ``extract_tool_messages`` and ``extract_first_tool_message_content``."""

    def test_extract_tool_messages_with_tool_call_id(self):
        """A message carrying ``tool_call_id`` is recognised as a tool message."""
        msgs = [
            {"content": "Regular message"},
            {"content": "Tool result", "tool_call_id": "call_123"},
            {"content": "Another regular"},
        ]
        found = extract_tool_messages({"messages": msgs})
        assert len(found) == 1
        assert found[0]["tool_call_id"] == "call_123"

    def test_extract_tool_messages_with_name(self):
        """A message carrying a tool ``name`` is recognised as a tool message."""
        msgs = [
            {"content": "Tool output", "name": "get_price"},
            {"content": "AI response", "response_metadata": {"finish_reason": "stop"}},
        ]
        found = extract_tool_messages({"messages": msgs})
        assert len(found) == 1
        assert found[0]["name"] == "get_price"

    def test_extract_tool_messages_none_found(self):
        """Plain messages produce an empty result list."""
        msgs = [{"content": "Message 1"}, {"content": "Message 2"}]
        assert extract_tool_messages({"messages": msgs}) == []

    def test_extract_first_tool_message_content(self):
        """Only the earliest tool message's content is returned."""
        msgs = [
            {"content": "Regular"},
            {"content": "First tool", "tool_call_id": "1"},
            {"content": "Second tool", "tool_call_id": "2"},
        ]
        assert extract_first_tool_message_content({"messages": msgs}) == "First tool"

    def test_extract_first_tool_message_content_none(self):
        """No tool messages means no content (None)."""
        convo = {"messages": [{"content": "Regular"}]}
        assert extract_first_tool_message_content(convo) is None

    def test_extract_tool_messages_object_based(self):
        """Attribute-style message objects are handled like dicts."""
        class Message:
            def __init__(self, content, tool_call_id=None):
                self.content = content
                self.tool_call_id = tool_call_id

        regular = Message("Regular")
        tool_hit = Message("Tool result", tool_call_id="abc123")
        found = extract_tool_messages({"messages": [regular, tool_hit]})
        assert len(found) == 1
        assert found[0].tool_call_id == "abc123"
@pytest.mark.unit
class TestEdgeCases:
    """Test edge cases and error handling."""

    def test_get_config_value_none_default(self):
        """Should handle None as default value."""
        result = get_config_value("MISSING_KEY", None)
        assert result is None

    def test_extract_conversation_whitespace_only(self):
        """Should skip whitespace-only content."""
        conversation = {
            "messages": [
                {"content": " ", "response_metadata": {"finish_reason": "stop"}},
                {"content": "Valid content"},
            ]
        }
        result = extract_conversation(conversation, "final")
        assert result == "Valid content"

    def test_write_config_value_with_special_chars(self, temp_runtime_env):
        """Should handle special characters in values."""
        write_config_value("SPECIAL", "value with 日本語 and émojis 🎉")
        # Decode explicitly as UTF-8: the locale default (e.g. cp1252 on
        # Windows) would mis-decode non-ASCII bytes. Escaped-ASCII JSON is
        # also valid UTF-8, so this works either way the file is written.
        data = json.loads(temp_runtime_env.read_text(encoding="utf-8"))
        assert data["SPECIAL"] == "value with 日本語 and émojis 🎉"

    def test_write_config_value_invalid_path(self, capsys, monkeypatch):
        """Should handle write errors gracefully."""
        # monkeypatch restores the caller's RUNTIME_ENV_PATH (if any) on
        # teardown, even if an assertion below fails.
        monkeypatch.setenv("RUNTIME_ENV_PATH", "/invalid/nonexistent/path/config.json")
        write_config_value("TEST", "value")
        captured = capsys.readouterr()
        assert "Error writing config" in captured.out

    def test_extract_conversation_with_object_messages(self):
        """Should work with object-based messages (not just dicts)."""
        class Message:
            def __init__(self, content, response_metadata=None):
                self.content = content
                self.response_metadata = response_metadata or {}

        class ResponseMetadata:
            def __init__(self, finish_reason):
                self.finish_reason = finish_reason

        conversation = {
            "messages": [
                Message("First", ResponseMetadata("stop")),
                Message("Second", ResponseMetadata("stop")),
            ]
        }
        result = extract_conversation(conversation, "final")
        assert result == "Second"

    def test_extract_first_tool_message_content_with_object(self):
        """Should extract content from object-based tool messages."""
        class ToolMessage:
            def __init__(self, content):
                self.content = content
                self.tool_call_id = "test123"

        conversation = {
            "messages": [
                ToolMessage("Tool output"),
            ]
        }
        result = extract_first_tool_message_content(conversation)
        assert result == "Tool output"