mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-12 13:37:24 -04:00
feat: implement date range API and on-demand downloads (WIP phase 2)
Phase 2 progress - API integration complete. API Changes: - Replace date_range (List[str]) with start_date/end_date (str) - Add automatic end_date defaulting to start_date for single day - Add date format validation - Integrate PriceDataManager for on-demand downloads - Add rate limit handling (trusts provider, no pre-config) - Validate date ranges with configurable max days (MAX_SIMULATION_DAYS) New Modules: - api/date_utils.py - Date validation and expansion utilities - scripts/migrate_price_data.py - Migration script for merged.jsonl API Flow: 1. Validate date range (start <= end, max 30 days, not future) 2. Check missing price data coverage 3. Download missing data if AUTO_DOWNLOAD_PRICE_DATA=true 4. Priority-based download (maximize date completion) 5. Create job with available trading dates 6. Graceful handling of partial data (rate limits) Configuration: - AUTO_DOWNLOAD_PRICE_DATA (default: true) - MAX_SIMULATION_DAYS (default: 30) - No rate limit configuration needed Still TODO: - Update tools/price_tools.py to read from database - Implement simulation run tracking - Update .env.example - Comprehensive testing - Documentation updates Breaking Changes: - API request format changed (date_range -> start_date/end_date) - This completes v0.3.0 preparation
This commit is contained in:
93
api/date_utils.py
Normal file
93
api/date_utils.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"""
|
||||||
|
Date range utilities for simulation date management.
|
||||||
|
|
||||||
|
This module provides:
|
||||||
|
- Date range expansion
|
||||||
|
- Date range validation
|
||||||
|
- Trading day detection
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
def expand_date_range(start_date: str, end_date: str) -> List[str]:
    """
    Expand a date range into the full list of calendar dates it covers (inclusive).

    Args:
        start_date: Start date (YYYY-MM-DD)
        end_date: End date (YYYY-MM-DD)

    Returns:
        Sorted list of dates in range

    Raises:
        ValueError: If dates are invalid or start > end
    """
    first = datetime.strptime(start_date, "%Y-%m-%d")
    last = datetime.strptime(end_date, "%Y-%m-%d")

    if first > last:
        raise ValueError(f"start_date ({start_date}) must be <= end_date ({end_date})")

    # Number of calendar days, counting both endpoints.
    span = (last - first).days + 1
    return [
        (first + timedelta(days=offset)).strftime("%Y-%m-%d")
        for offset in range(span)
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def validate_date_range(
    start_date: str,
    end_date: str,
    max_days: int = 30
) -> None:
    """
    Validate date range for simulation.

    Checks, in order: both dates parse as YYYY-MM-DD, start <= end,
    the inclusive range length does not exceed max_days, and the end
    date is not in the future.

    Args:
        start_date: Start date (YYYY-MM-DD)
        end_date: End date (YYYY-MM-DD)
        max_days: Maximum allowed days in range

    Raises:
        ValueError: If validation fails
    """
    # Parse dates; chain the original error so the offending value is
    # preserved in the traceback (raise ... from e).
    try:
        start = datetime.strptime(start_date, "%Y-%m-%d")
        end = datetime.strptime(end_date, "%Y-%m-%d")
    except ValueError as e:
        raise ValueError(f"Invalid date format: {e}") from e

    # Check order
    if start > end:
        raise ValueError(f"start_date ({start_date}) must be <= end_date ({end_date})")

    # Check range size (inclusive of both endpoints)
    days = (end - start).days + 1
    if days > max_days:
        raise ValueError(
            f"Date range too large: {days} days (max: {max_days}). "
            f"Reduce range or increase MAX_SIMULATION_DAYS."
        )

    # Check not in future. NOTE(review): uses the server's local clock;
    # confirm this matches the market-date timezone expected by callers.
    today = datetime.now().date()
    if end.date() > today:
        raise ValueError(f"end_date ({end_date}) cannot be in the future")
|
||||||
|
|
||||||
|
|
||||||
|
def get_max_simulation_days() -> int:
    """
    Read the maximum simulation-range size from the environment.

    Returns:
        Maximum days allowed in simulation range (MAX_SIMULATION_DAYS,
        defaulting to 30 when unset)
    """
    configured = os.getenv("MAX_SIMULATION_DAYS", "30")
    return int(configured)
|
||||||
126
api/main.py
126
api/main.py
@@ -9,6 +9,7 @@ Provides endpoints for:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, List, Dict, Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -20,6 +21,8 @@ from pydantic import BaseModel, Field, field_validator
|
|||||||
from api.job_manager import JobManager
|
from api.job_manager import JobManager
|
||||||
from api.simulation_worker import SimulationWorker
|
from api.simulation_worker import SimulationWorker
|
||||||
from api.database import get_db_connection
|
from api.database import get_db_connection
|
||||||
|
from api.price_data_manager import PriceDataManager
|
||||||
|
from api.date_utils import validate_date_range, expand_date_range, get_max_simulation_days
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@@ -29,23 +32,29 @@ logger = logging.getLogger(__name__)
|
|||||||
# Pydantic models for request/response validation
|
# Pydantic models for request/response validation
|
||||||
class SimulateTriggerRequest(BaseModel):
    """Request body for POST /simulate/trigger.

    end_date is optional; when omitted, the simulation covers the single
    day given by start_date (see get_end_date()).
    """
    start_date: str = Field(..., description="Start date for simulation (YYYY-MM-DD)")
    end_date: Optional[str] = Field(None, description="End date for simulation (YYYY-MM-DD). If not provided, simulates single day.")
    models: Optional[List[str]] = Field(
        None,
        description="Optional: List of model signatures to simulate. If not provided, uses enabled models from config."
    )

    @field_validator("start_date", "end_date")
    @classmethod
    def validate_date_format(cls, v):
        """Validate date format (strict YYYY-MM-DD; None is allowed for end_date)."""
        if v is None:
            return v
        try:
            datetime.strptime(v, "%Y-%m-%d")
        except ValueError:
            # Suppress the strptime context (from None): the message is
            # self-contained and the inner traceback is just noise to clients.
            raise ValueError(f"Invalid date format: {v}. Expected YYYY-MM-DD") from None
        return v

    def get_end_date(self) -> str:
        """Get end date, defaulting to start_date if not provided."""
        return self.end_date or self.start_date
||||||
class SimulateTriggerResponse(BaseModel):
|
class SimulateTriggerResponse(BaseModel):
|
||||||
"""Response body for POST /simulate/trigger."""
|
"""Response body for POST /simulate/trigger."""
|
||||||
@@ -114,13 +123,12 @@ def create_app(
|
|||||||
"""
|
"""
|
||||||
Trigger a new simulation job.
|
Trigger a new simulation job.
|
||||||
|
|
||||||
Creates a job with dates and models from config file.
|
Validates date range, downloads missing price data if needed,
|
||||||
If models not specified in request, uses enabled models from config.
|
and creates job with available trading dates.
|
||||||
Job runs asynchronously in background thread.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException 400: If another job is already running or config invalid
|
HTTPException 400: Validation errors, running job, or invalid dates
|
||||||
HTTPException 422: If request validation fails
|
HTTPException 503: Price data download failed
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Use config path from app state
|
# Use config path from app state
|
||||||
@@ -133,6 +141,13 @@ def create_app(
|
|||||||
detail=f"Server configuration file not found: {config_path}"
|
detail=f"Server configuration file not found: {config_path}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Get end date (defaults to start_date for single day)
|
||||||
|
end_date = request.get_end_date()
|
||||||
|
|
||||||
|
# Validate date range
|
||||||
|
max_days = get_max_simulation_days()
|
||||||
|
validate_date_range(request.start_date, end_date, max_days=max_days)
|
||||||
|
|
||||||
# Determine which models to run
|
# Determine which models to run
|
||||||
import json
|
import json
|
||||||
with open(config_path, 'r') as f:
|
with open(config_path, 'r') as f:
|
||||||
@@ -155,6 +170,67 @@ def create_app(
|
|||||||
detail="No enabled models found in config. Either enable models in config or specify them in request."
|
detail="No enabled models found in config. Either enable models in config or specify them in request."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check price data and download if needed
|
||||||
|
auto_download = os.getenv("AUTO_DOWNLOAD_PRICE_DATA", "true").lower() == "true"
|
||||||
|
price_manager = PriceDataManager(db_path=app.state.db_path)
|
||||||
|
|
||||||
|
# Check what's missing
|
||||||
|
missing_coverage = price_manager.get_missing_coverage(
|
||||||
|
request.start_date,
|
||||||
|
end_date
|
||||||
|
)
|
||||||
|
|
||||||
|
download_info = None
|
||||||
|
|
||||||
|
# Download missing data if enabled
|
||||||
|
if any(missing_coverage.values()):
|
||||||
|
if not auto_download:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Missing price data for {len(missing_coverage)} symbols and auto-download is disabled. "
|
||||||
|
f"Enable AUTO_DOWNLOAD_PRICE_DATA or pre-populate data."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Downloading missing price data for {len(missing_coverage)} symbols")
|
||||||
|
|
||||||
|
requested_dates = set(expand_date_range(request.start_date, end_date))
|
||||||
|
|
||||||
|
download_result = price_manager.download_missing_data_prioritized(
|
||||||
|
missing_coverage,
|
||||||
|
requested_dates
|
||||||
|
)
|
||||||
|
|
||||||
|
if not download_result["success"]:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503,
|
||||||
|
detail="Failed to download any price data. Check ALPHAADVANTAGE_API_KEY."
|
||||||
|
)
|
||||||
|
|
||||||
|
download_info = {
|
||||||
|
"symbols_downloaded": len(download_result["downloaded"]),
|
||||||
|
"symbols_failed": len(download_result["failed"]),
|
||||||
|
"rate_limited": download_result["rate_limited"]
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Downloaded {len(download_result['downloaded'])} symbols, "
|
||||||
|
f"{len(download_result['failed'])} failed, rate_limited={download_result['rate_limited']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get available trading dates (after potential download)
|
||||||
|
available_dates = price_manager.get_available_trading_dates(
|
||||||
|
request.start_date,
|
||||||
|
end_date
|
||||||
|
)
|
||||||
|
|
||||||
|
if not available_dates:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"No trading dates with complete price data in range "
|
||||||
|
f"{request.start_date} to {end_date}. "
|
||||||
|
f"All symbols must have data for a date to be tradeable."
|
||||||
|
)
|
||||||
|
|
||||||
job_manager = JobManager(db_path=app.state.db_path)
|
job_manager = JobManager(db_path=app.state.db_path)
|
||||||
|
|
||||||
# Check if can start new job
|
# Check if can start new job
|
||||||
@@ -164,10 +240,10 @@ def create_app(
|
|||||||
detail="Another simulation job is already running or pending. Please wait for it to complete."
|
detail="Another simulation job is already running or pending. Please wait for it to complete."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create job
|
# Create job with available dates
|
||||||
job_id = job_manager.create_job(
|
job_id = job_manager.create_job(
|
||||||
config_path=config_path,
|
config_path=config_path,
|
||||||
date_range=request.date_range,
|
date_range=available_dates,
|
||||||
models=models_to_run
|
models=models_to_run
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -180,15 +256,27 @@ def create_app(
|
|||||||
thread = threading.Thread(target=run_worker, daemon=True)
|
thread = threading.Thread(target=run_worker, daemon=True)
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
||||||
logger.info(f"Triggered simulation job {job_id}")
|
logger.info(f"Triggered simulation job {job_id} with {len(available_dates)} dates")
|
||||||
|
|
||||||
return SimulateTriggerResponse(
|
# Build response message
|
||||||
|
message = f"Simulation job created with {len(available_dates)} trading dates"
|
||||||
|
if download_info and download_info["rate_limited"]:
|
||||||
|
message += " (rate limit reached - partial data)"
|
||||||
|
|
||||||
|
response = SimulateTriggerResponse(
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
status="pending",
|
status="pending",
|
||||||
total_model_days=len(request.date_range) * len(models_to_run),
|
total_model_days=len(available_dates) * len(models_to_run),
|
||||||
message=f"Simulation job {job_id} created and started with {len(models_to_run)} models"
|
message=message
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add download info if we downloaded
|
||||||
|
if download_info:
|
||||||
|
# Note: Need to add download_info field to response model
|
||||||
|
logger.info(f"Download info: {download_info}")
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
|||||||
166
scripts/migrate_price_data.py
Executable file
166
scripts/migrate_price_data.py
Executable file
@@ -0,0 +1,166 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Migration script: Import merged.jsonl price data into SQLite database.
|
||||||
|
|
||||||
|
This script:
|
||||||
|
1. Reads existing merged.jsonl file
|
||||||
|
2. Parses OHLCV data for each symbol/date
|
||||||
|
3. Inserts into price_data table
|
||||||
|
4. Tracks coverage in price_data_coverage table
|
||||||
|
|
||||||
|
Run this once to migrate from jsonl to database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
# Add project root to path
|
||||||
|
project_root = Path(__file__).parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
from api.database import get_db_connection, initialize_database
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_merged_jsonl(
    jsonl_path: str = "data/merged.jsonl",
    db_path: str = "data/jobs.db"
):
    """
    Migrate price data from merged.jsonl to SQLite database.

    Reads one JSON record per line (Alpha Vantage daily-series shape),
    inserts each symbol/date OHLCV row into price_data, then records the
    min/max covered date span per symbol in price_data_coverage.
    Idempotent: all inserts use INSERT OR REPLACE, so re-running is safe.

    Args:
        jsonl_path: Path to merged.jsonl file
        db_path: Path to SQLite database
    """
    jsonl_file = Path(jsonl_path)

    if not jsonl_file.exists():
        print(f"⚠️ merged.jsonl not found at {jsonl_path}")
        print(" No price data to migrate. Skipping migration.")
        return

    print(f"📊 Migrating price data from {jsonl_path} to {db_path}")

    # Ensure database is initialized
    initialize_database(db_path)

    conn = get_db_connection(db_path)

    # Track what we're importing
    total_records = 0
    symbols_processed = set()
    symbol_date_ranges = defaultdict(lambda: {"min": None, "max": None})

    created_at = datetime.utcnow().isoformat() + "Z"

    print("Reading merged.jsonl...")

    # BUGFIX: close the connection even if parsing/insertion raises, so a
    # failed partial migration never leaks the SQLite handle.
    try:
        cursor = conn.cursor()

        with open(jsonl_file, 'r') as f:
            for line_num, line in enumerate(f, 1):
                if not line.strip():
                    continue

                try:
                    record = json.loads(line)

                    # Extract metadata
                    meta = record.get("Meta Data", {})
                    symbol = meta.get("2. Symbol")

                    if not symbol:
                        print(f"⚠️ Line {line_num}: No symbol found, skipping")
                        continue

                    symbols_processed.add(symbol)

                    # Extract time series data
                    time_series = record.get("Time Series (Daily)", {})

                    if not time_series:
                        print(f"⚠️ {symbol}: No time series data, skipping")
                        continue

                    # Insert each date's data
                    for date, ohlcv in time_series.items():
                        try:
                            # Parse OHLCV values. Supports both the project's
                            # renamed keys ("1. buy price"/"4. sell price") and
                            # the original Alpha Vantage keys ("1. open"/"4. close").
                            open_price = float(ohlcv.get("1. buy price") or ohlcv.get("1. open", 0))
                            high_price = float(ohlcv.get("2. high", 0))
                            low_price = float(ohlcv.get("3. low", 0))
                            close_price = float(ohlcv.get("4. sell price") or ohlcv.get("4. close", 0))
                            volume = int(ohlcv.get("5. volume", 0))

                            # Insert or replace price data (parameterized SQL)
                            cursor.execute("""
                                INSERT OR REPLACE INTO price_data
                                (symbol, date, open, high, low, close, volume, created_at)
                                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                            """, (symbol, date, open_price, high_price, low_price, close_price, volume, created_at))

                            total_records += 1

                            # Track date range for this symbol (ISO dates sort
                            # lexicographically, so string comparison is correct)
                            if symbol_date_ranges[symbol]["min"] is None or date < symbol_date_ranges[symbol]["min"]:
                                symbol_date_ranges[symbol]["min"] = date
                            if symbol_date_ranges[symbol]["max"] is None or date > symbol_date_ranges[symbol]["max"]:
                                symbol_date_ranges[symbol]["max"] = date

                        except (ValueError, KeyError) as e:
                            print(f"⚠️ {symbol} {date}: Failed to parse OHLCV data: {e}")
                            continue

                        # Commit every 1000 records for progress
                        if total_records % 1000 == 0:
                            conn.commit()
                            print(f" Imported {total_records} records...")

                except json.JSONDecodeError as e:
                    print(f"⚠️ Line {line_num}: JSON decode error: {e}")
                    continue

        # Final commit
        conn.commit()

        print(f"\n✓ Imported {total_records} price records for {len(symbols_processed)} symbols")

        # Update coverage tracking
        print("\nUpdating coverage tracking...")

        for symbol, date_range in symbol_date_ranges.items():
            if date_range["min"] and date_range["max"]:
                cursor.execute("""
                    INSERT OR REPLACE INTO price_data_coverage
                    (symbol, start_date, end_date, downloaded_at, source)
                    VALUES (?, ?, ?, ?, 'migrated_from_jsonl')
                """, (symbol, date_range["min"], date_range["max"], created_at))

        conn.commit()
    finally:
        conn.close()

    print(f"✓ Coverage tracking updated for {len(symbol_date_ranges)} symbols")
    print("\n✅ Migration complete!")
    print(f"\nSymbols migrated: {', '.join(sorted(symbols_processed))}")
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import argparse

    # CLI entry point: both paths fall back to the project defaults.
    cli = argparse.ArgumentParser(description="Migrate merged.jsonl to SQLite database")
    cli.add_argument(
        "--jsonl",
        default="data/merged.jsonl",
        help="Path to merged.jsonl file (default: data/merged.jsonl)",
    )
    cli.add_argument(
        "--db",
        default="data/jobs.db",
        help="Path to SQLite database (default: data/jobs.db)",
    )
    options = cli.parse_args()

    migrate_merged_jsonl(options.jsonl, options.db)
|
||||||
Reference in New Issue
Block a user