diff --git a/api/simulation_worker.py b/api/simulation_worker.py index 580cbf4..a03a83e 100644 --- a/api/simulation_worker.py +++ b/api/simulation_worker.py @@ -9,7 +9,7 @@ This module provides: """ import logging -from typing import Dict, Any, List +from typing import Dict, Any, List, Set from concurrent.futures import ThreadPoolExecutor, as_completed from api.job_manager import JobManager @@ -200,6 +200,42 @@ class SimulationWorker: "error": str(e) } + def _download_price_data( + self, + price_manager, + missing_coverage: Dict[str, Set[str]], + requested_dates: List[str], + warnings: List[str] + ) -> None: + """Download missing price data with progress logging.""" + logger.info(f"Job {self.job_id}: Starting prioritized download...") + + requested_dates_set = set(requested_dates) + + download_result = price_manager.download_missing_data_prioritized( + missing_coverage, + requested_dates_set + ) + + downloaded = len(download_result["downloaded"]) + failed = len(download_result["failed"]) + total = downloaded + failed + + logger.info( + f"Job {self.job_id}: Download complete - " + f"{downloaded}/{total} symbols succeeded" + ) + + if download_result["rate_limited"]: + msg = f"Rate limit reached - downloaded {downloaded}/{total} symbols" + warnings.append(msg) + logger.warning(f"Job {self.job_id}: {msg}") + + if failed > 0 and not download_result["rate_limited"]: + msg = f"{failed} symbols failed to download" + warnings.append(msg) + logger.warning(f"Job {self.job_id}: {msg}") + def get_job_info(self) -> Dict[str, Any]: """ Get job information. diff --git a/tests/unit/test_simulation_worker.py b/tests/unit/test_simulation_worker.py index 977b343..1deaabd 100644 --- a/tests/unit/test_simulation_worker.py +++ b/tests/unit/test_simulation_worker.py @@ -274,4 +274,65 @@ class TestSimulationWorkerJobRetrieval: assert job_info["models"] == ["gpt-5"] +@pytest.mark.unit +class TestSimulationWorkerHelperMethods: + """Test worker helper methods.""" + + def test_download_price_data_success(self, clean_db): + """Test successful price data download.""" + from api.simulation_worker import SimulationWorker + from api.database import initialize_database + + db_path = clean_db + initialize_database(db_path) + + worker = SimulationWorker(job_id="test-123", db_path=db_path) + + # Mock price manager + mock_price_manager = Mock() + mock_price_manager.download_missing_data_prioritized.return_value = { + "downloaded": ["AAPL", "MSFT"], + "failed": [], + "rate_limited": False + } + + warnings = [] + missing_coverage = {"AAPL": {"2025-10-01"}, "MSFT": {"2025-10-01"}} + + worker._download_price_data(mock_price_manager, missing_coverage, ["2025-10-01"], warnings) + + # Verify download was called + mock_price_manager.download_missing_data_prioritized.assert_called_once() + + # No warnings for successful download + assert len(warnings) == 0 + + def test_download_price_data_rate_limited(self, clean_db): + """Test price download with rate limit.""" + from api.simulation_worker import SimulationWorker + from api.database import initialize_database + + db_path = clean_db + initialize_database(db_path) + + worker = SimulationWorker(job_id="test-456", db_path=db_path) + + # Mock price manager + mock_price_manager = Mock() + mock_price_manager.download_missing_data_prioritized.return_value = { + "downloaded": ["AAPL"], + "failed": ["MSFT"], + "rate_limited": True + } + + warnings = [] + missing_coverage = {"AAPL": {"2025-10-01"}, "MSFT": {"2025-10-01"}} + + worker._download_price_data(mock_price_manager, missing_coverage, ["2025-10-01"], warnings) + + # Should add rate limit warning + assert len(warnings) == 1 + assert "Rate limit" in warnings[0] + + # Coverage target: 90%+ for api/simulation_worker.py