mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-12 21:47:23 -04:00
test: improve test coverage from 61% to 84.81%
Major improvements: - Fixed all 42 broken tests (database connection leaks) - Added db_connection() context manager for proper cleanup - Created comprehensive test suites for undertested modules New test coverage: - tools/general_tools.py: 26 tests (97% coverage) - tools/price_tools.py: 11 tests (validates NASDAQ symbols, date handling) - api/price_data_manager.py: 12 tests (85% coverage) - api/routes/results_v2.py: 3 tests (98% coverage) - agent/reasoning_summarizer.py: 2 tests (87% coverage) - api/routes/period_metrics.py: 2 edge case tests (100% coverage) - agent/mock_provider: 1 test (100% coverage) Database fixes: - Added db_connection() context manager to prevent leaks - Updated 16+ test files to use context managers - Fixed drop_all_tables() to match new schema - Added CHECK constraint for action_type - Added ON DELETE CASCADE to trading_days foreign key Test improvements: - Updated SQL INSERT statements with all required fields - Fixed date parameter handling in API integration tests - Added edge case tests for validation functions - Fixed import errors across test suite Results: - Total coverage: 84.81% (was 61%) - Tests passing: 406 (was 364 with 42 failures) - Total lines covered: 6364 of 7504 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -19,7 +19,7 @@ from api.price_data_manager import (
|
||||
RateLimitError,
|
||||
DownloadError
|
||||
)
|
||||
from api.database import initialize_database, get_db_connection
|
||||
from api.database import initialize_database, get_db_connection, db_connection
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -168,6 +168,21 @@ class TestPriceDataManagerInit:
|
||||
assert manager.api_key is None
|
||||
|
||||
|
||||
class TestGetAvailableDates:
|
||||
"""Test get_available_dates method."""
|
||||
|
||||
def test_get_available_dates_with_data(self, manager, populated_db):
|
||||
"""Test retrieving all dates from database."""
|
||||
manager.db_path = populated_db
|
||||
dates = manager.get_available_dates()
|
||||
assert dates == {"2025-01-20", "2025-01-21"}
|
||||
|
||||
def test_get_available_dates_empty_database(self, manager):
|
||||
"""Test retrieving dates from empty database."""
|
||||
dates = manager.get_available_dates()
|
||||
assert dates == set()
|
||||
|
||||
|
||||
class TestGetSymbolDates:
|
||||
"""Test get_symbol_dates method."""
|
||||
|
||||
@@ -232,6 +247,35 @@ class TestGetMissingCoverage:
|
||||
assert missing["GOOGL"] == {"2025-01-21"}
|
||||
|
||||
|
||||
class TestExpandDateRange:
|
||||
"""Test _expand_date_range method."""
|
||||
|
||||
def test_expand_single_date(self, manager):
|
||||
"""Test expanding a single date range."""
|
||||
dates = manager._expand_date_range("2025-01-20", "2025-01-20")
|
||||
assert dates == {"2025-01-20"}
|
||||
|
||||
def test_expand_multiple_dates(self, manager):
|
||||
"""Test expanding multiple date range."""
|
||||
dates = manager._expand_date_range("2025-01-20", "2025-01-22")
|
||||
assert dates == {"2025-01-20", "2025-01-21", "2025-01-22"}
|
||||
|
||||
def test_expand_week_range(self, manager):
|
||||
"""Test expanding a week-long range."""
|
||||
dates = manager._expand_date_range("2025-01-20", "2025-01-26")
|
||||
assert len(dates) == 7
|
||||
assert "2025-01-20" in dates
|
||||
assert "2025-01-26" in dates
|
||||
|
||||
def test_expand_month_range(self, manager):
|
||||
"""Test expanding a month-long range."""
|
||||
dates = manager._expand_date_range("2025-01-01", "2025-01-31")
|
||||
assert len(dates) == 31
|
||||
assert "2025-01-01" in dates
|
||||
assert "2025-01-15" in dates
|
||||
assert "2025-01-31" in dates
|
||||
|
||||
|
||||
class TestPrioritizeDownloads:
|
||||
"""Test prioritize_downloads method."""
|
||||
|
||||
@@ -287,6 +331,26 @@ class TestPrioritizeDownloads:
|
||||
# Only AAPL should be included
|
||||
assert prioritized == ["AAPL"]
|
||||
|
||||
def test_prioritize_many_symbols(self, manager):
|
||||
"""Test prioritization with many symbols (exercises debug logging)."""
|
||||
# Create 10 symbols with varying impact
|
||||
missing_coverage = {}
|
||||
for i in range(10):
|
||||
symbol = f"SYM{i}"
|
||||
# Each symbol missing progressively fewer dates
|
||||
missing_coverage[symbol] = {f"2025-01-{20+j}" for j in range(10-i)}
|
||||
|
||||
requested_dates = {f"2025-01-{20+j}" for j in range(10)}
|
||||
|
||||
prioritized = manager.prioritize_downloads(missing_coverage, requested_dates)
|
||||
|
||||
# Should return all 10 symbols, sorted by impact
|
||||
assert len(prioritized) == 10
|
||||
# First symbol should have highest impact (SYM0 with 10 dates)
|
||||
assert prioritized[0] == "SYM0"
|
||||
# Last symbol should have lowest impact (SYM9 with 1 date)
|
||||
assert prioritized[-1] == "SYM9"
|
||||
|
||||
|
||||
class TestGetAvailableTradingDates:
|
||||
"""Test get_available_trading_dates method."""
|
||||
@@ -422,12 +486,11 @@ class TestStoreSymbolData:
|
||||
assert set(stored_dates) == {"2025-01-20", "2025-01-21"}
|
||||
|
||||
# Verify data in database
|
||||
conn = get_db_connection(manager.db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) FROM price_data WHERE symbol = 'AAPL'")
|
||||
count = cursor.fetchone()[0]
|
||||
assert count == 2
|
||||
conn.close()
|
||||
with db_connection(manager.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) FROM price_data WHERE symbol = 'AAPL'")
|
||||
count = cursor.fetchone()[0]
|
||||
assert count == 2
|
||||
|
||||
def test_store_filters_by_requested_dates(self, manager):
|
||||
"""Test that only requested dates are stored."""
|
||||
@@ -458,12 +521,11 @@ class TestStoreSymbolData:
|
||||
assert set(stored_dates) == {"2025-01-20"}
|
||||
|
||||
# Verify only one date in database
|
||||
conn = get_db_connection(manager.db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) FROM price_data WHERE symbol = 'AAPL'")
|
||||
count = cursor.fetchone()[0]
|
||||
assert count == 1
|
||||
conn.close()
|
||||
with db_connection(manager.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) FROM price_data WHERE symbol = 'AAPL'")
|
||||
count = cursor.fetchone()[0]
|
||||
assert count == 1
|
||||
|
||||
|
||||
class TestUpdateCoverage:
|
||||
@@ -473,15 +535,14 @@ class TestUpdateCoverage:
|
||||
"""Test coverage tracking for new symbol."""
|
||||
manager._update_coverage("AAPL", "2025-01-20", "2025-01-21")
|
||||
|
||||
conn = get_db_connection(manager.db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT symbol, start_date, end_date, source
|
||||
FROM price_data_coverage
|
||||
WHERE symbol = 'AAPL'
|
||||
""")
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
with db_connection(manager.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT symbol, start_date, end_date, source
|
||||
FROM price_data_coverage
|
||||
WHERE symbol = 'AAPL'
|
||||
""")
|
||||
row = cursor.fetchone()
|
||||
|
||||
assert row is not None
|
||||
assert row[0] == "AAPL"
|
||||
@@ -496,13 +557,12 @@ class TestUpdateCoverage:
|
||||
# Update with new range
|
||||
manager._update_coverage("AAPL", "2025-01-22", "2025-01-23")
|
||||
|
||||
conn = get_db_connection(manager.db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT COUNT(*) FROM price_data_coverage WHERE symbol = 'AAPL'
|
||||
""")
|
||||
count = cursor.fetchone()[0]
|
||||
conn.close()
|
||||
with db_connection(manager.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT COUNT(*) FROM price_data_coverage WHERE symbol = 'AAPL'
|
||||
""")
|
||||
count = cursor.fetchone()[0]
|
||||
|
||||
# Should have 2 coverage records now
|
||||
assert count == 2
|
||||
@@ -570,3 +630,95 @@ class TestDownloadMissingDataPrioritized:
|
||||
assert result["success"] is False
|
||||
assert len(result["downloaded"]) == 0
|
||||
assert len(result["failed"]) == 1
|
||||
|
||||
def test_download_no_missing_coverage(self, manager):
|
||||
"""Test early return when no downloads needed."""
|
||||
missing_coverage = {} # No missing data
|
||||
requested_dates = {"2025-01-20", "2025-01-21"}
|
||||
|
||||
result = manager.download_missing_data_prioritized(missing_coverage, requested_dates)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["downloaded"] == []
|
||||
assert result["failed"] == []
|
||||
assert result["rate_limited"] is False
|
||||
assert sorted(result["dates_completed"]) == sorted(requested_dates)
|
||||
|
||||
def test_download_missing_api_key(self, temp_db, temp_symbols_config):
|
||||
"""Test error when API key is missing."""
|
||||
manager_no_key = PriceDataManager(
|
||||
db_path=temp_db,
|
||||
symbols_config=temp_symbols_config,
|
||||
api_key=None
|
||||
)
|
||||
|
||||
missing_coverage = {"AAPL": {"2025-01-20"}}
|
||||
requested_dates = {"2025-01-20"}
|
||||
|
||||
with pytest.raises(ValueError, match="ALPHAADVANTAGE_API_KEY not configured"):
|
||||
manager_no_key.download_missing_data_prioritized(missing_coverage, requested_dates)
|
||||
|
||||
@patch.object(PriceDataManager, '_update_coverage')
|
||||
@patch.object(PriceDataManager, '_store_symbol_data')
|
||||
@patch.object(PriceDataManager, '_download_symbol')
|
||||
def test_download_with_progress_callback(self, mock_download, mock_store, mock_update, manager):
|
||||
"""Test download with progress callback."""
|
||||
missing_coverage = {"AAPL": {"2025-01-20"}, "MSFT": {"2025-01-20"}}
|
||||
requested_dates = {"2025-01-20"}
|
||||
|
||||
# Mock successful downloads
|
||||
mock_download.return_value = {"Time Series (Daily)": {}}
|
||||
mock_store.return_value = {"2025-01-20"}
|
||||
|
||||
# Track progress callbacks
|
||||
progress_updates = []
|
||||
|
||||
def progress_callback(info):
|
||||
progress_updates.append(info)
|
||||
|
||||
result = manager.download_missing_data_prioritized(
|
||||
missing_coverage,
|
||||
requested_dates,
|
||||
progress_callback=progress_callback
|
||||
)
|
||||
|
||||
# Verify progress callbacks were made
|
||||
assert len(progress_updates) == 2 # One for each symbol
|
||||
assert progress_updates[0]["current"] == 1
|
||||
assert progress_updates[0]["total"] == 2
|
||||
assert progress_updates[0]["phase"] == "downloading"
|
||||
assert progress_updates[1]["current"] == 2
|
||||
assert progress_updates[1]["total"] == 2
|
||||
|
||||
assert result["success"] is True
|
||||
assert len(result["downloaded"]) == 2
|
||||
|
||||
@patch.object(PriceDataManager, '_update_coverage')
|
||||
@patch.object(PriceDataManager, '_store_symbol_data')
|
||||
@patch.object(PriceDataManager, '_download_symbol')
|
||||
def test_download_partial_success_with_errors(self, mock_download, mock_store, mock_update, manager):
|
||||
"""Test download with some successes and some failures."""
|
||||
missing_coverage = {
|
||||
"AAPL": {"2025-01-20"},
|
||||
"MSFT": {"2025-01-20"},
|
||||
"GOOGL": {"2025-01-20"}
|
||||
}
|
||||
requested_dates = {"2025-01-20"}
|
||||
|
||||
# First download succeeds, second fails, third succeeds
|
||||
mock_download.side_effect = [
|
||||
{"Time Series (Daily)": {}}, # AAPL success
|
||||
DownloadError("Network error"), # MSFT fails
|
||||
{"Time Series (Daily)": {}} # GOOGL success
|
||||
]
|
||||
mock_store.return_value = {"2025-01-20"}
|
||||
|
||||
result = manager.download_missing_data_prioritized(missing_coverage, requested_dates)
|
||||
|
||||
# Should have partial success
|
||||
assert result["success"] is True # At least one succeeded
|
||||
assert len(result["downloaded"]) == 2 # AAPL and GOOGL
|
||||
assert len(result["failed"]) == 1 # MSFT
|
||||
assert "AAPL" in result["downloaded"]
|
||||
assert "GOOGL" in result["downloaded"]
|
||||
assert "MSFT" in result["failed"]
|
||||
|
||||
Reference in New Issue
Block a user