mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-04 18:07:24 -04:00
feat: transform to REST API service with SQLite persistence (v0.3.0)
Major architecture transformation from batch-only to API service with
database persistence for Windmill integration.
## REST API Implementation
- POST /simulate/trigger - Start simulation jobs
- GET /simulate/status/{job_id} - Monitor job progress
- GET /results - Query results with filters (job_id, date, model)
- GET /health - Service health checks
## Database Layer
- SQLite persistence with 6 tables (jobs, job_details, positions,
holdings, reasoning_logs, tool_usage)
- Foreign key constraints with cascade deletes
- Replaces JSONL file storage
## Backend Components
- JobManager: Job lifecycle management with concurrency control
- RuntimeConfigManager: Thread-safe isolated runtime configs
- ModelDayExecutor: Single model-day execution engine
- SimulationWorker: Date-sequential, model-parallel orchestration
## Testing
- 102 unit and integration tests (85% coverage)
- Database: 98% coverage
- Job manager: 98% coverage
- API endpoints: 81% coverage
- Pydantic models: 100% coverage
- TDD approach throughout
## Docker Deployment
- Dual-mode: API server (persistent) + batch (one-time)
- Health checks with 30s interval
- Volume persistence for database and logs
- Separate entrypoints for each mode
## Validation Tools
- scripts/validate_docker_build.sh - Build validation
- scripts/test_api_endpoints.sh - Complete API testing
- scripts/test_batch_mode.sh - Batch mode validation
- DOCKER_API.md - Deployment guide
- TESTING_GUIDE.md - Testing procedures
## Configuration
- API_PORT environment variable (default: 8080)
- Backwards compatible with existing configs
- FastAPI, uvicorn, pydantic>=2.0 dependencies
Co-Authored-By: AI Assistant <noreply@example.com>
This commit is contained in:
366
api/main.py
Normal file
366
api/main.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""
|
||||
FastAPI REST API for AI-Trader simulation service.
|
||||
|
||||
Provides endpoints for:
|
||||
- Triggering simulation jobs
|
||||
- Checking job status
|
||||
- Querying results
|
||||
- Health checks
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from api.job_manager import JobManager
|
||||
from api.simulation_worker import SimulationWorker
|
||||
from api.database import get_db_connection
|
||||
import threading
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Pydantic models for request/response validation
|
||||
class SimulateTriggerRequest(BaseModel):
|
||||
"""Request body for POST /simulate/trigger."""
|
||||
config_path: str = Field(..., description="Path to configuration file")
|
||||
date_range: List[str] = Field(..., min_length=1, description="List of trading dates (YYYY-MM-DD)")
|
||||
models: List[str] = Field(..., min_length=1, description="List of model signatures to simulate")
|
||||
|
||||
@field_validator("date_range")
|
||||
@classmethod
|
||||
def validate_date_range(cls, v):
|
||||
"""Validate date format."""
|
||||
for date in v:
|
||||
try:
|
||||
datetime.strptime(date, "%Y-%m-%d")
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid date format: {date}. Expected YYYY-MM-DD")
|
||||
return v
|
||||
|
||||
|
||||
class SimulateTriggerResponse(BaseModel):
|
||||
"""Response body for POST /simulate/trigger."""
|
||||
job_id: str
|
||||
status: str
|
||||
total_model_days: int
|
||||
message: str
|
||||
|
||||
|
||||
class JobProgress(BaseModel):
|
||||
"""Job progress information."""
|
||||
total_model_days: int
|
||||
completed: int
|
||||
failed: int
|
||||
pending: int
|
||||
|
||||
|
||||
class JobStatusResponse(BaseModel):
|
||||
"""Response body for GET /simulate/status/{job_id}."""
|
||||
job_id: str
|
||||
status: str
|
||||
progress: JobProgress
|
||||
date_range: List[str]
|
||||
models: List[str]
|
||||
created_at: str
|
||||
started_at: Optional[str] = None
|
||||
completed_at: Optional[str] = None
|
||||
total_duration_seconds: Optional[float] = None
|
||||
error: Optional[str] = None
|
||||
details: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
"""Response body for GET /health."""
|
||||
status: str
|
||||
database: str
|
||||
timestamp: str
|
||||
|
||||
|
||||
def create_app(db_path: str = "data/jobs.db") -> FastAPI:
|
||||
"""
|
||||
Create FastAPI application instance.
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database
|
||||
|
||||
Returns:
|
||||
Configured FastAPI app
|
||||
"""
|
||||
app = FastAPI(
|
||||
title="AI-Trader Simulation API",
|
||||
description="REST API for triggering and monitoring AI trading simulations",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Store db_path in app state
|
||||
app.state.db_path = db_path
|
||||
|
||||
@app.post("/simulate/trigger", response_model=SimulateTriggerResponse, status_code=200)
|
||||
async def trigger_simulation(request: SimulateTriggerRequest):
|
||||
"""
|
||||
Trigger a new simulation job.
|
||||
|
||||
Creates a job with specified config, dates, and models.
|
||||
Job runs asynchronously in background thread.
|
||||
|
||||
Raises:
|
||||
HTTPException 400: If another job is already running or config invalid
|
||||
HTTPException 422: If request validation fails
|
||||
"""
|
||||
try:
|
||||
# Validate config path exists
|
||||
if not Path(request.config_path).exists():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Config path does not exist: {request.config_path}"
|
||||
)
|
||||
|
||||
job_manager = JobManager(db_path=app.state.db_path)
|
||||
|
||||
# Check if can start new job
|
||||
if not job_manager.can_start_new_job():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Another simulation job is already running or pending. Please wait for it to complete."
|
||||
)
|
||||
|
||||
# Create job
|
||||
job_id = job_manager.create_job(
|
||||
config_path=request.config_path,
|
||||
date_range=request.date_range,
|
||||
models=request.models
|
||||
)
|
||||
|
||||
# Start worker in background thread (only if not in test mode)
|
||||
if not getattr(app.state, "test_mode", False):
|
||||
def run_worker():
|
||||
worker = SimulationWorker(job_id=job_id, db_path=app.state.db_path)
|
||||
worker.run()
|
||||
|
||||
thread = threading.Thread(target=run_worker, daemon=True)
|
||||
thread.start()
|
||||
|
||||
logger.info(f"Triggered simulation job {job_id}")
|
||||
|
||||
return SimulateTriggerResponse(
|
||||
job_id=job_id,
|
||||
status="pending",
|
||||
total_model_days=len(request.date_range) * len(request.models),
|
||||
message=f"Simulation job {job_id} created and started"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to trigger simulation: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
@app.get("/simulate/status/{job_id}", response_model=JobStatusResponse)
|
||||
async def get_job_status(job_id: str):
|
||||
"""
|
||||
Get status and progress of a simulation job.
|
||||
|
||||
Args:
|
||||
job_id: Job UUID
|
||||
|
||||
Returns:
|
||||
Job status, progress, and model-day details
|
||||
|
||||
Raises:
|
||||
HTTPException 404: If job not found
|
||||
"""
|
||||
try:
|
||||
job_manager = JobManager(db_path=app.state.db_path)
|
||||
|
||||
# Get job info
|
||||
job = job_manager.get_job(job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
|
||||
|
||||
# Get progress
|
||||
progress = job_manager.get_job_progress(job_id)
|
||||
|
||||
# Get model-day details
|
||||
details = job_manager.get_job_details(job_id)
|
||||
|
||||
# Calculate pending (total - completed - failed)
|
||||
pending = progress["total_model_days"] - progress["completed"] - progress["failed"]
|
||||
|
||||
return JobStatusResponse(
|
||||
job_id=job["job_id"],
|
||||
status=job["status"],
|
||||
progress=JobProgress(
|
||||
total_model_days=progress["total_model_days"],
|
||||
completed=progress["completed"],
|
||||
failed=progress["failed"],
|
||||
pending=pending
|
||||
),
|
||||
date_range=job["date_range"],
|
||||
models=job["models"],
|
||||
created_at=job["created_at"],
|
||||
started_at=job.get("started_at"),
|
||||
completed_at=job.get("completed_at"),
|
||||
total_duration_seconds=job.get("total_duration_seconds"),
|
||||
error=job.get("error"),
|
||||
details=details
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get job status: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
@app.get("/results")
|
||||
async def get_results(
|
||||
job_id: Optional[str] = Query(None, description="Filter by job ID"),
|
||||
date: Optional[str] = Query(None, description="Filter by date (YYYY-MM-DD)"),
|
||||
model: Optional[str] = Query(None, description="Filter by model signature")
|
||||
):
|
||||
"""
|
||||
Query simulation results.
|
||||
|
||||
Supports filtering by job_id, date, and/or model.
|
||||
Returns position data with holdings.
|
||||
|
||||
Args:
|
||||
job_id: Optional job UUID filter
|
||||
date: Optional date filter (YYYY-MM-DD)
|
||||
model: Optional model signature filter
|
||||
|
||||
Returns:
|
||||
List of position records with holdings
|
||||
"""
|
||||
try:
|
||||
conn = get_db_connection(app.state.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Build query with filters
|
||||
query = """
|
||||
SELECT
|
||||
p.id,
|
||||
p.job_id,
|
||||
p.date,
|
||||
p.model,
|
||||
p.action_id,
|
||||
p.action_type,
|
||||
p.symbol,
|
||||
p.amount,
|
||||
p.price,
|
||||
p.cash,
|
||||
p.portfolio_value,
|
||||
p.daily_profit,
|
||||
p.daily_return_pct,
|
||||
p.created_at
|
||||
FROM positions p
|
||||
WHERE 1=1
|
||||
"""
|
||||
params = []
|
||||
|
||||
if job_id:
|
||||
query += " AND p.job_id = ?"
|
||||
params.append(job_id)
|
||||
|
||||
if date:
|
||||
query += " AND p.date = ?"
|
||||
params.append(date)
|
||||
|
||||
if model:
|
||||
query += " AND p.model = ?"
|
||||
params.append(model)
|
||||
|
||||
query += " ORDER BY p.date, p.model, p.action_id"
|
||||
|
||||
cursor.execute(query, params)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
position_id = row[0]
|
||||
|
||||
# Get holdings for this position
|
||||
cursor.execute("""
|
||||
SELECT symbol, quantity
|
||||
FROM holdings
|
||||
WHERE position_id = ?
|
||||
ORDER BY symbol
|
||||
""", (position_id,))
|
||||
|
||||
holdings = [{"symbol": h[0], "quantity": h[1]} for h in cursor.fetchall()]
|
||||
|
||||
results.append({
|
||||
"id": row[0],
|
||||
"job_id": row[1],
|
||||
"date": row[2],
|
||||
"model": row[3],
|
||||
"action_id": row[4],
|
||||
"action_type": row[5],
|
||||
"symbol": row[6],
|
||||
"amount": row[7],
|
||||
"price": row[8],
|
||||
"cash": row[9],
|
||||
"portfolio_value": row[10],
|
||||
"daily_profit": row[11],
|
||||
"daily_return_pct": row[12],
|
||||
"created_at": row[13],
|
||||
"holdings": holdings
|
||||
})
|
||||
|
||||
conn.close()
|
||||
|
||||
return {"results": results, "count": len(results)}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to query results: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
|
||||
async def health_check():
|
||||
"""
|
||||
Health check endpoint.
|
||||
|
||||
Verifies database connectivity and service status.
|
||||
|
||||
Returns:
|
||||
Health status and timestamp
|
||||
"""
|
||||
try:
|
||||
# Test database connection
|
||||
conn = get_db_connection(app.state.db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
database_status = "connected"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Database health check failed: {e}")
|
||||
database_status = "disconnected"
|
||||
|
||||
return HealthResponse(
|
||||
status="healthy" if database_status == "connected" else "unhealthy",
|
||||
database=database_status,
|
||||
timestamp=datetime.utcnow().isoformat() + "Z"
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# Create default app instance
|
||||
app = create_app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8080)
|
||||
Reference in New Issue
Block a user