mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-11 05:07:25 -04:00
feat(worker): add _download_price_data helper method
Handle price data download with rate limit detection and warning generation. Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -9,7 +9,7 @@ This module provides:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Any, List
|
from typing import Dict, Any, List, Set
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
from api.job_manager import JobManager
|
from api.job_manager import JobManager
|
||||||
@@ -200,6 +200,42 @@ class SimulationWorker:
|
|||||||
"error": str(e)
|
"error": str(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _download_price_data(
|
||||||
|
self,
|
||||||
|
price_manager,
|
||||||
|
missing_coverage: Dict[str, Set[str]],
|
||||||
|
requested_dates: List[str],
|
||||||
|
warnings: List[str]
|
||||||
|
) -> None:
|
||||||
|
"""Download missing price data with progress logging."""
|
||||||
|
logger.info(f"Job {self.job_id}: Starting prioritized download...")
|
||||||
|
|
||||||
|
requested_dates_set = set(requested_dates)
|
||||||
|
|
||||||
|
download_result = price_manager.download_missing_data_prioritized(
|
||||||
|
missing_coverage,
|
||||||
|
requested_dates_set
|
||||||
|
)
|
||||||
|
|
||||||
|
downloaded = len(download_result["downloaded"])
|
||||||
|
failed = len(download_result["failed"])
|
||||||
|
total = downloaded + failed
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Job {self.job_id}: Download complete - "
|
||||||
|
f"{downloaded}/{total} symbols succeeded"
|
||||||
|
)
|
||||||
|
|
||||||
|
if download_result["rate_limited"]:
|
||||||
|
msg = f"Rate limit reached - downloaded {downloaded}/{total} symbols"
|
||||||
|
warnings.append(msg)
|
||||||
|
logger.warning(f"Job {self.job_id}: {msg}")
|
||||||
|
|
||||||
|
if failed > 0 and not download_result["rate_limited"]:
|
||||||
|
msg = f"{failed} symbols failed to download"
|
||||||
|
warnings.append(msg)
|
||||||
|
logger.warning(f"Job {self.job_id}: {msg}")
|
||||||
|
|
||||||
def get_job_info(self) -> Dict[str, Any]:
|
def get_job_info(self) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Get job information.
|
Get job information.
|
||||||
|
|||||||
@@ -274,4 +274,65 @@ class TestSimulationWorkerJobRetrieval:
|
|||||||
assert job_info["models"] == ["gpt-5"]
|
assert job_info["models"] == ["gpt-5"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestSimulationWorkerHelperMethods:
|
||||||
|
"""Test worker helper methods."""
|
||||||
|
|
||||||
|
def test_download_price_data_success(self, clean_db):
|
||||||
|
"""Test successful price data download."""
|
||||||
|
from api.simulation_worker import SimulationWorker
|
||||||
|
from api.database import initialize_database
|
||||||
|
|
||||||
|
db_path = clean_db
|
||||||
|
initialize_database(db_path)
|
||||||
|
|
||||||
|
worker = SimulationWorker(job_id="test-123", db_path=db_path)
|
||||||
|
|
||||||
|
# Mock price manager
|
||||||
|
mock_price_manager = Mock()
|
||||||
|
mock_price_manager.download_missing_data_prioritized.return_value = {
|
||||||
|
"downloaded": ["AAPL", "MSFT"],
|
||||||
|
"failed": [],
|
||||||
|
"rate_limited": False
|
||||||
|
}
|
||||||
|
|
||||||
|
warnings = []
|
||||||
|
missing_coverage = {"AAPL": {"2025-10-01"}, "MSFT": {"2025-10-01"}}
|
||||||
|
|
||||||
|
worker._download_price_data(mock_price_manager, missing_coverage, ["2025-10-01"], warnings)
|
||||||
|
|
||||||
|
# Verify download was called
|
||||||
|
mock_price_manager.download_missing_data_prioritized.assert_called_once()
|
||||||
|
|
||||||
|
# No warnings for successful download
|
||||||
|
assert len(warnings) == 0
|
||||||
|
|
||||||
|
def test_download_price_data_rate_limited(self, clean_db):
|
||||||
|
"""Test price download with rate limit."""
|
||||||
|
from api.simulation_worker import SimulationWorker
|
||||||
|
from api.database import initialize_database
|
||||||
|
|
||||||
|
db_path = clean_db
|
||||||
|
initialize_database(db_path)
|
||||||
|
|
||||||
|
worker = SimulationWorker(job_id="test-456", db_path=db_path)
|
||||||
|
|
||||||
|
# Mock price manager
|
||||||
|
mock_price_manager = Mock()
|
||||||
|
mock_price_manager.download_missing_data_prioritized.return_value = {
|
||||||
|
"downloaded": ["AAPL"],
|
||||||
|
"failed": ["MSFT"],
|
||||||
|
"rate_limited": True
|
||||||
|
}
|
||||||
|
|
||||||
|
warnings = []
|
||||||
|
missing_coverage = {"AAPL": {"2025-10-01"}, "MSFT": {"2025-10-01"}}
|
||||||
|
|
||||||
|
worker._download_price_data(mock_price_manager, missing_coverage, ["2025-10-01"], warnings)
|
||||||
|
|
||||||
|
# Should add rate limit warning
|
||||||
|
assert len(warnings) == 1
|
||||||
|
assert "Rate limit" in warnings[0]
|
||||||
|
|
||||||
|
|
||||||
# Coverage target: 90%+ for api/simulation_worker.py
|
# Coverage target: 90%+ for api/simulation_worker.py
|
||||||
|
|||||||
Reference in New Issue
Block a user