Files
AI-Trader/api/main.py
Bill db1341e204 feat: implement replace_existing parameter to allow re-running completed simulations
Add skip_completed parameter to JobManager.create_job() to control duplicate detection:
- When skip_completed=True (default), skips already-completed simulations (existing behavior)
- When skip_completed=False, includes ALL requested simulations regardless of completion status

API endpoint now uses request.replace_existing to control skip_completed parameter:
- replace_existing=false (default): skip_completed=True (skip duplicates)
- replace_existing=true: skip_completed=False (force re-run all simulations)

This allows users to force re-running completed simulations when needed.
2025-11-07 13:39:51 -05:00

490 lines
17 KiB
Python

"""
FastAPI REST API for AI-Trader simulation service.
Provides endpoints for:
- Triggering simulation jobs
- Checking job status
- Querying results
- Health checks
"""
import logging
import os
from typing import Optional, List, Dict, Any
from datetime import datetime
from pathlib import Path
from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, field_validator
from contextlib import asynccontextmanager
from api.job_manager import JobManager
from api.simulation_worker import SimulationWorker
from api.database import get_db_connection
from api.date_utils import validate_date_range, expand_date_range, get_max_simulation_days
from tools.deployment_config import get_deployment_mode_dict, log_dev_mode_startup_warning
from api.routes import results_v2
import threading
import time
logger = logging.getLogger(__name__)
# Pydantic models for request/response validation
class SimulateTriggerRequest(BaseModel):
"""Request body for POST /simulate/trigger."""
start_date: Optional[str] = Field(None, description="Start date for simulation (YYYY-MM-DD). If null/omitted, resumes from last completed date per model.")
end_date: str = Field(..., description="End date for simulation (YYYY-MM-DD). Required.")
models: Optional[List[str]] = Field(
None,
description="Optional: List of model signatures to simulate. If not provided, uses enabled models from config."
)
replace_existing: bool = Field(
False,
description="If true, replaces existing simulation data. If false (default), skips dates that already have data (idempotent)."
)
@field_validator("start_date", "end_date")
@classmethod
def validate_date_format(cls, v):
"""Validate date format."""
if v is None or v == "":
return None
try:
datetime.strptime(v, "%Y-%m-%d")
except ValueError:
raise ValueError(f"Invalid date format: {v}. Expected YYYY-MM-DD")
return v
@field_validator("end_date")
@classmethod
def validate_end_date_required(cls, v):
"""Ensure end_date is not null or empty."""
if v is None or v == "":
raise ValueError("end_date is required and cannot be null or empty")
return v
class SimulateTriggerResponse(BaseModel):
"""Response body for POST /simulate/trigger."""
job_id: str
status: str
total_model_days: int
message: str
deployment_mode: str
is_dev_mode: bool
preserve_dev_data: Optional[bool] = None
warnings: Optional[List[str]] = None
class JobProgress(BaseModel):
"""Job progress information."""
total_model_days: int
completed: int
failed: int
pending: int
class JobStatusResponse(BaseModel):
"""Response body for GET /simulate/status/{job_id}."""
job_id: str
status: str
progress: JobProgress
date_range: List[str]
models: List[str]
created_at: str
started_at: Optional[str] = None
completed_at: Optional[str] = None
total_duration_seconds: Optional[float] = None
error: Optional[str] = None
details: List[Dict[str, Any]]
deployment_mode: str
is_dev_mode: bool
preserve_dev_data: Optional[bool] = None
warnings: Optional[List[str]] = None
class HealthResponse(BaseModel):
"""Response body for GET /health."""
status: str
database: str
timestamp: str
deployment_mode: str
is_dev_mode: bool
preserve_dev_data: Optional[bool] = None
def create_app(
db_path: str = "data/jobs.db",
config_path: str = "/tmp/runtime_config.json" if Path("/tmp/runtime_config.json").exists() else "configs/default_config.json"
) -> FastAPI:
"""
Create FastAPI application instance.
Args:
db_path: Path to SQLite database
config_path: Path to default configuration file
Returns:
Configured FastAPI app
"""
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Initialize database on startup, cleanup on shutdown if needed"""
from tools.deployment_config import is_dev_mode, get_db_path, should_preserve_dev_data
from api.database import initialize_dev_database, initialize_database
# Startup - use closure to access db_path from create_app scope
logger.info("🚀 FastAPI application starting...")
logger.info("📊 Initializing database...")
should_cleanup_stale_jobs = False
if is_dev_mode():
# Initialize dev database (reset unless PRESERVE_DEV_DATA=true)
logger.info(" 🔧 DEV mode detected - initializing dev database")
dev_db_path = get_db_path(db_path)
initialize_dev_database(dev_db_path)
log_dev_mode_startup_warning()
# Only cleanup stale jobs if preserving dev data (otherwise DB is fresh)
if should_preserve_dev_data():
should_cleanup_stale_jobs = True
else:
# Ensure production database schema exists
logger.info(" 🏭 PROD mode - ensuring database schema exists")
initialize_database(db_path)
should_cleanup_stale_jobs = True
logger.info("✅ Database initialized")
# Clean up stale jobs from previous container session
if should_cleanup_stale_jobs:
logger.info("🧹 Checking for stale jobs from previous session...")
job_manager = JobManager(get_db_path(db_path) if is_dev_mode() else db_path)
job_manager.cleanup_stale_jobs()
logger.info("🌐 API server ready to accept requests")
yield
# Shutdown (if needed in future)
logger.info("🛑 FastAPI application shutting down...")
app = FastAPI(
title="AI-Trader Simulation API",
description="REST API for triggering and monitoring AI trading simulations",
version="1.0.0",
lifespan=lifespan
)
# Store paths in app state
app.state.db_path = db_path
app.state.config_path = config_path
@app.post("/simulate/trigger", response_model=SimulateTriggerResponse, status_code=200)
async def trigger_simulation(request: SimulateTriggerRequest):
"""
Trigger a new simulation job.
Validates date range and creates job. Price data is downloaded
in background by SimulationWorker.
Supports:
- Single date: start_date == end_date
- Date range: start_date < end_date
- Resume: start_date is null (each model resumes from its last completed date)
Raises:
HTTPException 400: Validation errors, running job, or invalid dates
"""
try:
# Use config path from app state
config_path = app.state.config_path
# Validate config path exists
if not Path(config_path).exists():
raise HTTPException(
status_code=500,
detail=f"Server configuration file not found: {config_path}"
)
end_date = request.end_date
# Determine which models to run
import json
with open(config_path, 'r') as f:
config = json.load(f)
if request.models is not None and len(request.models) > 0:
# Use models from request (explicit override)
models_to_run = request.models
else:
# Use enabled models from config (when models is None or empty list)
models_to_run = [
model["signature"]
for model in config.get("models", [])
if model.get("enabled", False)
]
if not models_to_run:
raise HTTPException(
status_code=400,
detail="No enabled models found in config. Either enable models in config or specify them in request."
)
job_manager = JobManager(db_path=app.state.db_path)
# Handle resume logic (start_date is null)
if request.start_date is None:
# Resume mode: determine start date per model
from datetime import timedelta
model_start_dates = {}
for model in models_to_run:
last_date = job_manager.get_last_completed_date_for_model(model)
if last_date is None:
# Cold start: use end_date as single-day simulation
model_start_dates[model] = end_date
else:
# Resume from next day after last completed
last_dt = datetime.strptime(last_date, "%Y-%m-%d")
next_dt = last_dt + timedelta(days=1)
model_start_dates[model] = next_dt.strftime("%Y-%m-%d")
# For validation purposes, use earliest start date
earliest_start = min(model_start_dates.values())
start_date = earliest_start
else:
# Explicit start date provided
start_date = request.start_date
model_start_dates = {model: start_date for model in models_to_run}
# Validate date range
max_days = get_max_simulation_days()
validate_date_range(start_date, end_date, max_days=max_days)
# Check if can start new job
if not job_manager.can_start_new_job():
raise HTTPException(
status_code=400,
detail="Another simulation job is already running or pending. Please wait for it to complete."
)
# Get all weekdays in range (worker will filter based on data availability)
all_dates = expand_date_range(start_date, end_date)
# Create job immediately with all requested dates
# Worker will handle data download and filtering
result = job_manager.create_job(
config_path=config_path,
date_range=all_dates,
models=models_to_run,
model_day_filter=None, # Worker will filter based on available data
skip_completed=(not request.replace_existing) # Skip if replace_existing=False
)
job_id = result["job_id"]
warnings = result.get("warnings", [])
# Log warnings if any simulations were skipped
if warnings:
logger.warning(f"Job {job_id} created with {len(warnings)} skipped simulations: {warnings}")
# Start worker in background thread (only if not in test mode)
if not getattr(app.state, "test_mode", False):
def run_worker():
worker = SimulationWorker(job_id=job_id, db_path=app.state.db_path)
worker.run()
thread = threading.Thread(target=run_worker, daemon=True)
thread.start()
logger.info(f"Triggered simulation job {job_id} for {len(all_dates)} dates, {len(models_to_run)} models")
# Build response message
message = f"Simulation job created for {len(all_dates)} dates, {len(models_to_run)} models"
if request.start_date is None:
message += " (resume mode)"
# Get deployment mode info
deployment_info = get_deployment_mode_dict()
response = SimulateTriggerResponse(
job_id=job_id,
status="pending",
total_model_days=len(all_dates) * len(models_to_run),
message=message,
warnings=warnings if warnings else None,
**deployment_info
)
return response
except HTTPException:
raise
except ValueError as e:
logger.error(f"Validation error: {e}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Failed to trigger simulation: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@app.get("/simulate/status/{job_id}", response_model=JobStatusResponse)
async def get_job_status(job_id: str):
"""
Get status and progress of a simulation job.
Args:
job_id: Job UUID
Returns:
Job status, progress, model-day details, and warnings
Raises:
HTTPException 404: If job not found
"""
try:
job_manager = JobManager(db_path=app.state.db_path)
# Get job info
job = job_manager.get_job(job_id)
if not job:
raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
# Get progress
progress = job_manager.get_job_progress(job_id)
# Get model-day details
details = job_manager.get_job_details(job_id)
# Calculate pending (total - completed - failed)
pending = progress["total_model_days"] - progress["completed"] - progress["failed"]
# Parse warnings from JSON if present
import json
warnings = None
if job.get("warnings"):
try:
warnings = json.loads(job["warnings"])
except (json.JSONDecodeError, TypeError):
logger.warning(f"Failed to parse warnings for job {job_id}")
# Get deployment mode info
deployment_info = get_deployment_mode_dict()
return JobStatusResponse(
job_id=job["job_id"],
status=job["status"],
progress=JobProgress(
total_model_days=progress["total_model_days"],
completed=progress["completed"],
failed=progress["failed"],
pending=pending
),
date_range=job["date_range"],
models=job["models"],
created_at=job["created_at"],
started_at=job.get("started_at"),
completed_at=job.get("completed_at"),
total_duration_seconds=job.get("total_duration_seconds"),
error=job.get("error"),
details=details,
warnings=warnings,
**deployment_info
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get job status: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
# OLD /results endpoint - REPLACED by results_v2.py
# This endpoint used the old positions table schema and is no longer needed
# The new endpoint is defined in api/routes/results_v2.py
@app.get("/health", response_model=HealthResponse)
async def health_check():
"""
Health check endpoint.
Verifies database connectivity and service status.
Returns:
Health status and timestamp
"""
from tools.deployment_config import is_dev_mode
# Log at DEBUG in dev mode, INFO in prod mode
if is_dev_mode():
logger.debug("Health check")
else:
logger.info("Health check")
try:
# Test database connection
conn = get_db_connection(app.state.db_path)
cursor = conn.cursor()
cursor.execute("SELECT 1")
cursor.fetchone()
conn.close()
database_status = "connected"
except Exception as e:
logger.error(f"Database health check failed: {e}")
database_status = "disconnected"
# Get deployment mode info
deployment_info = get_deployment_mode_dict()
return HealthResponse(
status="healthy" if database_status == "connected" else "unhealthy",
database=database_status,
timestamp=datetime.utcnow().isoformat() + "Z",
**deployment_info
)
# Include routers
app.include_router(results_v2.router)
return app
# Create default app instance
app = create_app()
# Ensure database is initialized when module is loaded
# This handles cases where lifespan might not be triggered properly
logger.info("🔧 Module-level database initialization check...")
from tools.deployment_config import is_dev_mode, get_db_path
from api.database import initialize_dev_database, initialize_database
_db_path = app.state.db_path
if is_dev_mode():
logger.info(" 🔧 DEV mode - initializing dev database at module load")
_dev_db_path = get_db_path(_db_path)
initialize_dev_database(_dev_db_path)
else:
logger.info(" 🏭 PROD mode - ensuring database exists at module load")
initialize_database(_db_path)
logger.info("✅ Module-level database initialization complete")
if __name__ == "__main__":
import uvicorn
# Note: Database initialization happens in lifespan AND at module load
# for maximum reliability
uvicorn.run(app, host="0.0.0.0", port=8080)