mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-03 09:47:23 -04:00
feat: transform to REST API service with SQLite persistence (v0.3.0)
Major architecture transformation from batch-only to API service with
database persistence for Windmill integration.
## REST API Implementation
- POST /simulate/trigger - Start simulation jobs
- GET /simulate/status/{job_id} - Monitor job progress
- GET /results - Query results with filters (job_id, date, model)
- GET /health - Service health checks
## Database Layer
- SQLite persistence with 6 tables (jobs, job_details, positions,
holdings, reasoning_logs, tool_usage)
- Foreign key constraints with cascade deletes
- Replaces JSONL file storage
## Backend Components
- JobManager: Job lifecycle management with concurrency control
- RuntimeConfigManager: Thread-safe isolated runtime configs
- ModelDayExecutor: Single model-day execution engine
- SimulationWorker: Date-sequential, model-parallel orchestration
## Testing
- 102 unit and integration tests (85% coverage)
- Database: 98% coverage
- Job manager: 98% coverage
- API endpoints: 81% coverage
- Pydantic models: 100% coverage
- TDD approach throughout
## Docker Deployment
- Dual-mode: API server (persistent) + batch (one-time)
- Health checks with 30s interval
- Volume persistence for database and logs
- Separate entrypoints for each mode
## Validation Tools
- scripts/validate_docker_build.sh - Build validation
- scripts/test_api_endpoints.sh - Complete API testing
- scripts/test_batch_mode.sh - Batch mode validation
- DOCKER_API.md - Deployment guide
- TESTING_GUIDE.md - Testing procedures
## Configuration
- API_PORT environment variable (default: 8080)
- Backwards compatible with existing configs
- FastAPI, uvicorn, pydantic>=2.0 dependencies
Co-Authored-By: AI Assistant <noreply@example.com>
This commit is contained in:
0
api/__init__.py
Normal file
0
api/__init__.py
Normal file
307
api/database.py
Normal file
307
api/database.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""
|
||||
Database utilities and schema management for AI-Trader API.
|
||||
|
||||
This module provides:
|
||||
- SQLite connection management
|
||||
- Database schema initialization (6 tables)
|
||||
- ACID-compliant transaction support
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import os
|
||||
|
||||
|
||||
def get_db_connection(db_path: str = "data/jobs.db") -> sqlite3.Connection:
    """Open a SQLite connection configured for the API service.

    Ensures the parent directory of ``db_path`` exists, then opens a
    connection configured with:

    - foreign-key enforcement ON (referential integrity / cascade deletes),
    - ``sqlite3.Row`` row factory (dict-like column access),
    - ``check_same_thread=False`` so FastAPI worker threads may use it.

    Args:
        db_path: Path to the SQLite database file.

    Returns:
        A configured ``sqlite3.Connection``.
    """
    # Create the containing directory on first use so callers never have
    # to pre-provision the data folder.
    Path(db_path).parent.mkdir(parents=True, exist_ok=True)

    connection = sqlite3.connect(db_path, check_same_thread=False)
    connection.row_factory = sqlite3.Row
    connection.execute("PRAGMA foreign_keys = ON")
    return connection
|
||||
|
||||
|
||||
def initialize_database(db_path: str = "data/jobs.db") -> None:
    """
    Create all database tables and indexes (idempotent).

    Tables created:
        1. jobs           - High-level job metadata and status
        2. job_details    - Per model-day execution tracking
        3. positions      - Trading positions and P&L metrics
        4. holdings       - Portfolio holdings per position
        5. reasoning_logs - AI decision logs (optional, for detail=full)
        6. tool_usage     - Tool usage statistics

    Args:
        db_path: Path to SQLite database file
    """
    conn = get_db_connection(db_path)
    # Fix: close the connection even if a DDL statement raises, so a
    # schema error cannot leak an open handle.
    try:
        cursor = conn.cursor()

        # Table 1: Jobs - Job metadata and lifecycle
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS jobs (
                job_id TEXT PRIMARY KEY,
                config_path TEXT NOT NULL,
                status TEXT NOT NULL CHECK(status IN ('pending', 'running', 'completed', 'partial', 'failed')),
                date_range TEXT NOT NULL,
                models TEXT NOT NULL,
                created_at TEXT NOT NULL,
                started_at TEXT,
                updated_at TEXT,
                completed_at TEXT,
                total_duration_seconds REAL,
                error TEXT
            )
        """)

        # Table 2: Job Details - Per model-day execution
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS job_details (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                job_id TEXT NOT NULL,
                date TEXT NOT NULL,
                model TEXT NOT NULL,
                status TEXT NOT NULL CHECK(status IN ('pending', 'running', 'completed', 'failed')),
                started_at TEXT,
                completed_at TEXT,
                duration_seconds REAL,
                error TEXT,
                FOREIGN KEY (job_id) REFERENCES jobs(job_id) ON DELETE CASCADE
            )
        """)

        # Table 3: Positions - Trading positions and P&L
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS positions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                job_id TEXT NOT NULL,
                date TEXT NOT NULL,
                model TEXT NOT NULL,
                action_id INTEGER NOT NULL,
                action_type TEXT CHECK(action_type IN ('buy', 'sell', 'no_trade')),
                symbol TEXT,
                amount INTEGER,
                price REAL,
                cash REAL NOT NULL,
                portfolio_value REAL NOT NULL,
                daily_profit REAL,
                daily_return_pct REAL,
                cumulative_profit REAL,
                cumulative_return_pct REAL,
                created_at TEXT NOT NULL,
                FOREIGN KEY (job_id) REFERENCES jobs(job_id) ON DELETE CASCADE
            )
        """)

        # Table 4: Holdings - Portfolio holdings
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS holdings (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                position_id INTEGER NOT NULL,
                symbol TEXT NOT NULL,
                quantity INTEGER NOT NULL,
                FOREIGN KEY (position_id) REFERENCES positions(id) ON DELETE CASCADE
            )
        """)

        # Table 5: Reasoning Logs - AI decision logs (optional)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS reasoning_logs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                job_id TEXT NOT NULL,
                date TEXT NOT NULL,
                model TEXT NOT NULL,
                step_number INTEGER NOT NULL,
                timestamp TEXT NOT NULL,
                role TEXT CHECK(role IN ('user', 'assistant', 'tool')),
                content TEXT,
                tool_name TEXT,
                FOREIGN KEY (job_id) REFERENCES jobs(job_id) ON DELETE CASCADE
            )
        """)

        # Table 6: Tool Usage - Tool usage statistics
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS tool_usage (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                job_id TEXT NOT NULL,
                date TEXT NOT NULL,
                model TEXT NOT NULL,
                tool_name TEXT NOT NULL,
                call_count INTEGER NOT NULL DEFAULT 1,
                total_duration_seconds REAL,
                FOREIGN KEY (job_id) REFERENCES jobs(job_id) ON DELETE CASCADE
            )
        """)

        # Create indexes for performance
        _create_indexes(cursor)

        conn.commit()
    finally:
        conn.close()
|
||||
|
||||
|
||||
def _create_indexes(cursor: sqlite3.Cursor) -> None:
    """Create database indexes for query performance (idempotent)."""
    index_ddl = (
        # Jobs table
        "CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status)",
        "CREATE INDEX IF NOT EXISTS idx_jobs_created_at ON jobs(created_at DESC)",
        # Job details table (the UNIQUE index enforces one row per model-day)
        "CREATE INDEX IF NOT EXISTS idx_job_details_job_id ON job_details(job_id)",
        "CREATE INDEX IF NOT EXISTS idx_job_details_status ON job_details(status)",
        "CREATE UNIQUE INDEX IF NOT EXISTS idx_job_details_unique ON job_details(job_id, date, model)",
        # Positions table
        "CREATE INDEX IF NOT EXISTS idx_positions_job_id ON positions(job_id)",
        "CREATE INDEX IF NOT EXISTS idx_positions_date ON positions(date)",
        "CREATE INDEX IF NOT EXISTS idx_positions_model ON positions(model)",
        "CREATE INDEX IF NOT EXISTS idx_positions_date_model ON positions(date, model)",
        "CREATE UNIQUE INDEX IF NOT EXISTS idx_positions_unique ON positions(job_id, date, model, action_id)",
        # Holdings table
        "CREATE INDEX IF NOT EXISTS idx_holdings_position_id ON holdings(position_id)",
        "CREATE INDEX IF NOT EXISTS idx_holdings_symbol ON holdings(symbol)",
        # Reasoning logs table
        "CREATE INDEX IF NOT EXISTS idx_reasoning_logs_job_date_model ON reasoning_logs(job_id, date, model)",
        # Tool usage table
        "CREATE INDEX IF NOT EXISTS idx_tool_usage_job_date_model ON tool_usage(job_id, date, model)",
    )
    for ddl in index_ddl:
        cursor.execute(ddl)
|
||||
|
||||
|
||||
def drop_all_tables(db_path: str = "data/jobs.db") -> None:
    """
    Drop all database tables. USE WITH CAUTION.

    This is primarily for testing and development.

    Args:
        db_path: Path to SQLite database file
    """
    conn = get_db_connection(db_path)
    # Fix: guarantee the connection is closed even if a DROP fails.
    try:
        cursor = conn.cursor()

        # Children are dropped before parents so the order is safe under
        # foreign-key enforcement.
        tables = (
            'tool_usage',
            'reasoning_logs',
            'holdings',
            'positions',
            'job_details',
            'jobs',
        )
        for table in tables:
            # Table names come from the fixed tuple above, never from user
            # input, so the f-string interpolation is safe.
            cursor.execute(f"DROP TABLE IF EXISTS {table}")

        conn.commit()
    finally:
        conn.close()
|
||||
|
||||
|
||||
def vacuum_database(db_path: str = "data/jobs.db") -> None:
    """
    Reclaim disk space after deletions.

    Should be run periodically after cleanup operations.

    Args:
        db_path: Path to SQLite database file
    """
    conn = get_db_connection(db_path)
    # Fix: close the connection even if VACUUM raises (e.g. database busy).
    try:
        conn.execute("VACUUM")
    finally:
        conn.close()
|
||||
|
||||
|
||||
def get_database_stats(db_path: str = "data/jobs.db") -> dict:
    """
    Get database statistics for monitoring.

    Returns:
        Dictionary with table row counts and database size

    Example:
        {
            "database_size_mb": 12.5,
            "jobs": 150,
            "job_details": 3000,
            "positions": 15000,
            "holdings": 45000,
            "reasoning_logs": 300000,
            "tool_usage": 12000
        }
    """
    conn = get_db_connection(db_path)
    # Fix: close the connection even if a COUNT query fails (e.g. when a
    # table is missing on an uninitialized database).
    try:
        cursor = conn.cursor()
        stats: dict = {}

        # Database file size; 0 when the file has not been created yet.
        if os.path.exists(db_path):
            size_bytes = os.path.getsize(db_path)
            stats["database_size_mb"] = round(size_bytes / (1024 * 1024), 2)
        else:
            stats["database_size_mb"] = 0

        # Row counts per table. Names come from this fixed tuple, never
        # user input, so the f-string interpolation is safe.
        for table in ('jobs', 'job_details', 'positions', 'holdings', 'reasoning_logs', 'tool_usage'):
            cursor.execute(f"SELECT COUNT(*) FROM {table}")
            stats[table] = cursor.fetchone()[0]

        return stats
    finally:
        conn.close()
|
||||
625
api/job_manager.py
Normal file
625
api/job_manager.py
Normal file
@@ -0,0 +1,625 @@
|
||||
"""
|
||||
Job lifecycle manager for simulation orchestration.
|
||||
|
||||
This module provides:
|
||||
- Job creation and validation
|
||||
- Status transitions (state machine)
|
||||
- Progress tracking across model-days
|
||||
- Concurrency control (single job at a time)
|
||||
- Job retrieval and queries
|
||||
- Cleanup operations
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
from api.database import get_db_connection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JobManager:
    """
    Manages simulation job lifecycle and orchestration.

    Responsibilities:
    - Create jobs with date ranges and model lists
    - Track job status (pending -> running -> completed/partial/failed)
    - Monitor progress across model-days
    - Enforce single-job concurrency
    - Provide job queries and retrieval
    - Cleanup old jobs

    State Machine:
        pending -> running -> completed (all succeeded)
                           -> partial (some failed)
                           -> failed (job-level error)
    """

    # Column list shared by every jobs-table SELECT so row order always
    # matches what _row_to_job() expects. Previously this list and the
    # row->dict mapping were copy-pasted in four methods.
    _JOB_COLUMNS = (
        "job_id, config_path, status, date_range, models, "
        "created_at, started_at, updated_at, completed_at, "
        "total_duration_seconds, error"
    )

    def __init__(self, db_path: str = "data/jobs.db"):
        """
        Initialize JobManager.

        Args:
            db_path: Path to SQLite database
        """
        self.db_path = db_path

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _utc_now() -> str:
        """Current UTC time as an ISO-8601 string with a trailing 'Z'."""
        return datetime.utcnow().isoformat() + "Z"

    @staticmethod
    def _row_to_job(row) -> Dict[str, Any]:
        """Map a jobs row (in _JOB_COLUMNS order) to a plain dict.

        date_range and models are persisted as JSON text and decoded here.
        """
        return {
            "job_id": row[0],
            "config_path": row[1],
            "status": row[2],
            "date_range": json.loads(row[3]),
            "models": json.loads(row[4]),
            "created_at": row[5],
            "started_at": row[6],
            "updated_at": row[7],
            "completed_at": row[8],
            "total_duration_seconds": row[9],
            "error": row[10],
        }

    @staticmethod
    def _seconds_between(start_iso: Optional[str], end_iso: str) -> Optional[float]:
        """Elapsed seconds between two 'Z'-suffixed ISO timestamps.

        Returns None when start_iso is missing (never started).
        """
        if not start_iso:
            return None
        start = datetime.fromisoformat(start_iso.replace("Z", ""))
        end = datetime.fromisoformat(end_iso.replace("Z", ""))
        return (end - start).total_seconds()

    def _select_one_job(self, cursor, clause: str, params: tuple) -> Optional[Dict[str, Any]]:
        """Run a single-row jobs SELECT with *clause* appended and map it."""
        cursor.execute(f"SELECT {self._JOB_COLUMNS} FROM jobs {clause}", params)
        row = cursor.fetchone()
        return self._row_to_job(row) if row else None

    # ------------------------------------------------------------------
    # Job creation
    # ------------------------------------------------------------------

    def create_job(
        self,
        config_path: str,
        date_range: List[str],
        models: List[str]
    ) -> str:
        """
        Create new simulation job.

        Args:
            config_path: Path to configuration file
            date_range: List of dates to simulate (YYYY-MM-DD)
            models: List of model signatures to execute

        Returns:
            job_id: UUID of created job

        Raises:
            ValueError: If another job is already running/pending
        """
        if not self.can_start_new_job():
            raise ValueError("Another simulation job is already running or pending")

        job_id = str(uuid.uuid4())
        created_at = self._utc_now()

        conn = get_db_connection(self.db_path)
        cursor = conn.cursor()

        try:
            # Insert job
            cursor.execute("""
                INSERT INTO jobs (
                    job_id, config_path, status, date_range, models, created_at
                )
                VALUES (?, ?, ?, ?, ?, ?)
            """, (
                job_id,
                config_path,
                "pending",
                json.dumps(date_range),
                json.dumps(models),
                created_at
            ))

            # One job_details row per (date, model) combination; each is
            # tracked independently through the detail state machine.
            for date in date_range:
                for model in models:
                    cursor.execute("""
                        INSERT INTO job_details (
                            job_id, date, model, status
                        )
                        VALUES (?, ?, ?, ?)
                    """, (job_id, date, model, "pending"))

            conn.commit()
            logger.info(f"Created job {job_id} with {len(date_range)} dates and {len(models)} models")

            return job_id

        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Job retrieval
    # ------------------------------------------------------------------

    def get_job(self, job_id: str) -> Optional[Dict[str, Any]]:
        """
        Get job by ID.

        Args:
            job_id: Job UUID

        Returns:
            Job data dict or None if not found
        """
        conn = get_db_connection(self.db_path)
        try:
            return self._select_one_job(conn.cursor(), "WHERE job_id = ?", (job_id,))
        finally:
            conn.close()

    def get_current_job(self) -> Optional[Dict[str, Any]]:
        """
        Get most recent job.

        Returns:
            Most recent job data or None if no jobs exist
        """
        conn = get_db_connection(self.db_path)
        try:
            return self._select_one_job(conn.cursor(), "ORDER BY created_at DESC LIMIT 1", ())
        finally:
            conn.close()

    def find_job_by_date_range(self, date_range: List[str]) -> Optional[Dict[str, Any]]:
        """
        Find job with matching date range.

        Matching compares the exact JSON encoding of *date_range*, so
        element order must match what create_job stored.

        Args:
            date_range: List of dates to match

        Returns:
            Job data or None if not found
        """
        conn = get_db_connection(self.db_path)
        try:
            return self._select_one_job(
                conn.cursor(),
                "WHERE date_range = ? ORDER BY created_at DESC LIMIT 1",
                (json.dumps(date_range),)
            )
        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Status updates
    # ------------------------------------------------------------------

    def update_job_status(
        self,
        job_id: str,
        status: str,
        error: Optional[str] = None
    ) -> None:
        """
        Update job status.

        Args:
            job_id: Job UUID
            status: New status (pending/running/completed/partial/failed)
            error: Optional error message
        """
        conn = get_db_connection(self.db_path)
        cursor = conn.cursor()

        try:
            updated_at = self._utc_now()

            # Set timestamps based on status
            if status == "running":
                cursor.execute("""
                    UPDATE jobs
                    SET status = ?, started_at = ?, updated_at = ?
                    WHERE job_id = ?
                """, (status, updated_at, updated_at, job_id))

            elif status in ("completed", "partial", "failed"):
                # Terminal state: compute total duration from started_at.
                cursor.execute("""
                    SELECT started_at FROM jobs WHERE job_id = ?
                """, (job_id,))
                row = cursor.fetchone()
                duration_seconds = self._seconds_between(row[0] if row else None, updated_at)

                cursor.execute("""
                    UPDATE jobs
                    SET status = ?, completed_at = ?, updated_at = ?,
                        total_duration_seconds = ?, error = ?
                    WHERE job_id = ?
                """, (status, updated_at, updated_at, duration_seconds, error, job_id))

            else:
                # Just update status
                cursor.execute("""
                    UPDATE jobs
                    SET status = ?, updated_at = ?, error = ?
                    WHERE job_id = ?
                """, (status, updated_at, error, job_id))

            conn.commit()
            logger.debug(f"Updated job {job_id} status to {status}")

        finally:
            conn.close()

    def update_job_detail_status(
        self,
        job_id: str,
        date: str,
        model: str,
        status: str,
        error: Optional[str] = None
    ) -> None:
        """
        Update model-day status and auto-update job status.

        When the last detail finishes, the parent job is rolled up to
        completed (no failures), partial (mixed), or failed (all failed).

        Args:
            job_id: Job UUID
            date: Trading date (YYYY-MM-DD)
            model: Model signature
            status: New status (pending/running/completed/failed)
            error: Optional error message
        """
        conn = get_db_connection(self.db_path)
        cursor = conn.cursor()

        try:
            updated_at = self._utc_now()

            if status == "running":
                cursor.execute("""
                    UPDATE job_details
                    SET status = ?, started_at = ?
                    WHERE job_id = ? AND date = ? AND model = ?
                """, (status, updated_at, job_id, date, model))

                # Promote the parent job to running on its first detail.
                cursor.execute("""
                    UPDATE jobs
                    SET status = 'running', started_at = COALESCE(started_at, ?), updated_at = ?
                    WHERE job_id = ? AND status = 'pending'
                """, (updated_at, updated_at, job_id))

            elif status in ("completed", "failed"):
                # Calculate duration for this detail row.
                cursor.execute("""
                    SELECT started_at FROM job_details
                    WHERE job_id = ? AND date = ? AND model = ?
                """, (job_id, date, model))
                row = cursor.fetchone()
                duration_seconds = self._seconds_between(row[0] if row else None, updated_at)

                cursor.execute("""
                    UPDATE job_details
                    SET status = ?, completed_at = ?, duration_seconds = ?, error = ?
                    WHERE job_id = ? AND date = ? AND model = ?
                """, (status, updated_at, duration_seconds, error, job_id, date, model))

                # Check whether every detail has reached a terminal state.
                cursor.execute("""
                    SELECT
                        COUNT(*) as total,
                        SUM(CASE WHEN status = 'completed' THEN 1 ELSE 0 END) as completed,
                        SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed
                    FROM job_details
                    WHERE job_id = ?
                """, (job_id,))
                total, completed, failed = cursor.fetchone()

                if completed + failed == total:
                    # All done - determine the final roll-up status.
                    if failed == 0:
                        final_status = "completed"
                    elif completed > 0:
                        final_status = "partial"
                    else:
                        final_status = "failed"

                    cursor.execute("""
                        SELECT started_at FROM jobs WHERE job_id = ?
                    """, (job_id,))
                    row = cursor.fetchone()
                    job_duration = self._seconds_between(row[0] if row else None, updated_at)

                    cursor.execute("""
                        UPDATE jobs
                        SET status = ?, completed_at = ?, updated_at = ?, total_duration_seconds = ?
                        WHERE job_id = ?
                    """, (final_status, updated_at, updated_at, job_duration, job_id))

            conn.commit()
            logger.debug(f"Updated job_detail {job_id}/{date}/{model} to {status}")

        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Progress queries
    # ------------------------------------------------------------------

    def get_job_details(self, job_id: str) -> List[Dict[str, Any]]:
        """
        Get all model-day execution details for a job.

        Args:
            job_id: Job UUID

        Returns:
            List of job_detail records with date, model, status, error
        """
        conn = get_db_connection(self.db_path)
        cursor = conn.cursor()

        try:
            cursor.execute("""
                SELECT date, model, status, error, started_at, completed_at, duration_seconds
                FROM job_details
                WHERE job_id = ?
                ORDER BY date, model
            """, (job_id,))

            return [
                {
                    "date": row[0],
                    "model": row[1],
                    "status": row[2],
                    "error": row[3],
                    "started_at": row[4],
                    "completed_at": row[5],
                    "duration_seconds": row[6],
                }
                for row in cursor.fetchall()
            ]

        finally:
            conn.close()

    def get_job_progress(self, job_id: str) -> Dict[str, Any]:
        """
        Get job progress summary.

        Args:
            job_id: Job UUID

        Returns:
            Progress dict with total_model_days, completed, failed, current, details
        """
        conn = get_db_connection(self.db_path)
        cursor = conn.cursor()

        try:
            cursor.execute("""
                SELECT
                    COUNT(*) as total,
                    SUM(CASE WHEN status = 'completed' THEN 1 ELSE 0 END) as completed,
                    SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed
                FROM job_details
                WHERE job_id = ?
            """, (job_id,))
            total, completed, failed = cursor.fetchone()

            # Currently running model-day, if any.
            cursor.execute("""
                SELECT date, model
                FROM job_details
                WHERE job_id = ? AND status = 'running'
                LIMIT 1
            """, (job_id,))
            current_row = cursor.fetchone()
            current = {"date": current_row[0], "model": current_row[1]} if current_row else None

            cursor.execute("""
                SELECT date, model, status, duration_seconds, error
                FROM job_details
                WHERE job_id = ?
                ORDER BY date, model
            """, (job_id,))
            details = [
                {
                    "date": row[0],
                    "model": row[1],
                    "status": row[2],
                    "duration_seconds": row[3],
                    "error": row[4],
                }
                for row in cursor.fetchall()
            ]

            # SUM() yields NULL when there are no detail rows; normalize to 0.
            return {
                "total_model_days": total,
                "completed": completed or 0,
                "failed": failed or 0,
                "current": current,
                "details": details,
            }

        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Concurrency and cleanup
    # ------------------------------------------------------------------

    def can_start_new_job(self) -> bool:
        """
        Check if new job can be started.

        Returns:
            True if no jobs are pending/running, False otherwise
        """
        conn = get_db_connection(self.db_path)
        cursor = conn.cursor()

        try:
            cursor.execute("""
                SELECT COUNT(*)
                FROM jobs
                WHERE status IN ('pending', 'running')
            """)
            return cursor.fetchone()[0] == 0

        finally:
            conn.close()

    def get_running_jobs(self) -> List[Dict[str, Any]]:
        """
        Get all running/pending jobs.

        Returns:
            List of job dicts
        """
        conn = get_db_connection(self.db_path)
        cursor = conn.cursor()

        try:
            cursor.execute(f"""
                SELECT {self._JOB_COLUMNS}
                FROM jobs
                WHERE status IN ('pending', 'running')
                ORDER BY created_at DESC
            """)
            return [self._row_to_job(row) for row in cursor.fetchall()]

        finally:
            conn.close()

    def cleanup_old_jobs(self, days: int = 30) -> Dict[str, int]:
        """
        Delete jobs older than threshold.

        Only terminal jobs (completed/partial/failed) are removed; rows in
        child tables go with them via the foreign-key cascade deletes.

        Args:
            days: Delete jobs older than this many days

        Returns:
            Dict with jobs_deleted count
        """
        conn = get_db_connection(self.db_path)
        cursor = conn.cursor()

        try:
            cutoff_date = (datetime.utcnow() - timedelta(days=days)).isoformat() + "Z"

            # Get count before deletion
            cursor.execute("""
                SELECT COUNT(*)
                FROM jobs
                WHERE created_at < ? AND status IN ('completed', 'partial', 'failed')
            """, (cutoff_date,))
            count = cursor.fetchone()[0]

            cursor.execute("""
                DELETE FROM jobs
                WHERE created_at < ? AND status IN ('completed', 'partial', 'failed')
            """, (cutoff_date,))

            conn.commit()
            logger.info(f"Cleaned up {count} jobs older than {days} days")

            return {"jobs_deleted": count}

        finally:
            conn.close()
|
||||
366
api/main.py
Normal file
366
api/main.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""
|
||||
FastAPI REST API for AI-Trader simulation service.
|
||||
|
||||
Provides endpoints for:
|
||||
- Triggering simulation jobs
|
||||
- Checking job status
|
||||
- Querying results
|
||||
- Health checks
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from api.job_manager import JobManager
|
||||
from api.simulation_worker import SimulationWorker
|
||||
from api.database import get_db_connection
|
||||
import threading
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Pydantic models for request/response validation
|
||||
class SimulateTriggerRequest(BaseModel):
    """Request body for POST /simulate/trigger."""
    # Field(...) marks all three fields as required; min_length=1 rejects
    # empty lists at validation time.
    config_path: str = Field(..., description="Path to configuration file")
    date_range: List[str] = Field(..., min_length=1, description="List of trading dates (YYYY-MM-DD)")
    models: List[str] = Field(..., min_length=1, description="List of model signatures to simulate")

    @field_validator("date_range")
    @classmethod
    def validate_date_range(cls, v: List[str]) -> List[str]:
        """Validate date format.

        Each entry must parse with strptime as YYYY-MM-DD; otherwise a
        ValueError is raised, which pydantic reports as a validation error
        (HTTP 422 per the /simulate/trigger endpoint docs).
        """
        for date in v:
            try:
                datetime.strptime(date, "%Y-%m-%d")
            except ValueError:
                raise ValueError(f"Invalid date format: {date}. Expected YYYY-MM-DD")
        return v
|
||||
|
||||
|
||||
class SimulateTriggerResponse(BaseModel):
    """Response body for POST /simulate/trigger."""
    job_id: str            # UUID of the newly created job
    status: str            # initial job status ("pending" at creation time)
    total_model_days: int  # len(date_range) * len(models)
    message: str           # human-readable confirmation message
|
||||
|
||||
|
||||
class JobProgress(BaseModel):
    """Job progress information (counts of model-day executions)."""
    total_model_days: int  # total model-day combinations in the job
    completed: int         # model-days that finished successfully
    failed: int            # model-days that errored
    pending: int           # model-days not yet finished
|
||||
|
||||
|
||||
class JobStatusResponse(BaseModel):
    """Response body for GET /simulate/status/{job_id}."""
    job_id: str
    status: str  # pending/running/completed/partial/failed
    progress: JobProgress
    date_range: List[str]
    models: List[str]
    created_at: str
    # Lifecycle fields stay None until the job reaches that stage.
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
    total_duration_seconds: Optional[float] = None
    error: Optional[str] = None  # job-level error message, if any
    details: List[Dict[str, Any]]  # per model-day records (date, model, status, ...)
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
    """Response body for GET /health."""
    status: str     # overall service status
    database: str   # database connectivity indicator
    timestamp: str  # time the health check was performed
|
||||
|
||||
|
||||
def create_app(db_path: str = "data/jobs.db") -> FastAPI:
    """
    Create FastAPI application instance.

    Args:
        db_path: Path to SQLite database

    Returns:
        Configured FastAPI app with /simulate/trigger, /simulate/status/{job_id},
        /results, and /health endpoints registered.
    """
    app = FastAPI(
        title="AI-Trader Simulation API",
        description="REST API for triggering and monitoring AI trading simulations",
        version="1.0.0"
    )

    # Store db_path in app state so every endpoint uses the same database.
    app.state.db_path = db_path

    @app.post("/simulate/trigger", response_model=SimulateTriggerResponse, status_code=200)
    async def trigger_simulation(request: SimulateTriggerRequest):
        """
        Trigger a new simulation job.

        Creates a job with specified config, dates, and models.
        Job runs asynchronously in a background daemon thread.

        Raises:
            HTTPException 400: If another job is already running or config invalid
            HTTPException 422: If request validation fails
        """
        try:
            # Validate config path exists before touching the database.
            if not Path(request.config_path).exists():
                raise HTTPException(
                    status_code=400,
                    detail=f"Config path does not exist: {request.config_path}"
                )

            job_manager = JobManager(db_path=app.state.db_path)

            # Concurrency guard: only one job may be running/pending at a time.
            if not job_manager.can_start_new_job():
                raise HTTPException(
                    status_code=400,
                    detail="Another simulation job is already running or pending. Please wait for it to complete."
                )

            # Create job record (status: pending).
            job_id = job_manager.create_job(
                config_path=request.config_path,
                date_range=request.date_range,
                models=request.models
            )

            # Start worker in background thread (skipped under test_mode so
            # tests can drive the job lifecycle deterministically).
            if not getattr(app.state, "test_mode", False):
                def run_worker():
                    worker = SimulationWorker(job_id=job_id, db_path=app.state.db_path)
                    worker.run()

                thread = threading.Thread(target=run_worker, daemon=True)
                thread.start()

            logger.info(f"Triggered simulation job {job_id}")

            return SimulateTriggerResponse(
                job_id=job_id,
                status="pending",
                total_model_days=len(request.date_range) * len(request.models),
                message=f"Simulation job {job_id} created and started"
            )

        except HTTPException:
            raise
        except ValueError as e:
            logger.error(f"Validation error: {e}")
            raise HTTPException(status_code=400, detail=str(e))
        except Exception as e:
            logger.error(f"Failed to trigger simulation: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

    @app.get("/simulate/status/{job_id}", response_model=JobStatusResponse)
    async def get_job_status(job_id: str):
        """
        Get status and progress of a simulation job.

        Args:
            job_id: Job UUID

        Returns:
            Job status, progress, and model-day details

        Raises:
            HTTPException 404: If job not found
        """
        try:
            job_manager = JobManager(db_path=app.state.db_path)

            job = job_manager.get_job(job_id)
            if not job:
                raise HTTPException(status_code=404, detail=f"Job {job_id} not found")

            progress = job_manager.get_job_progress(job_id)
            details = job_manager.get_job_details(job_id)

            # Pending is derived, not stored: total - completed - failed.
            pending = progress["total_model_days"] - progress["completed"] - progress["failed"]

            return JobStatusResponse(
                job_id=job["job_id"],
                status=job["status"],
                progress=JobProgress(
                    total_model_days=progress["total_model_days"],
                    completed=progress["completed"],
                    failed=progress["failed"],
                    pending=pending
                ),
                date_range=job["date_range"],
                models=job["models"],
                created_at=job["created_at"],
                started_at=job.get("started_at"),
                completed_at=job.get("completed_at"),
                total_duration_seconds=job.get("total_duration_seconds"),
                error=job.get("error"),
                details=details
            )

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Failed to get job status: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

    @app.get("/results")
    async def get_results(
        job_id: Optional[str] = Query(None, description="Filter by job ID"),
        date: Optional[str] = Query(None, description="Filter by date (YYYY-MM-DD)"),
        model: Optional[str] = Query(None, description="Filter by model signature")
    ):
        """
        Query simulation results.

        Supports filtering by job_id, date, and/or model.
        Returns position data with holdings.

        Args:
            job_id: Optional job UUID filter
            date: Optional date filter (YYYY-MM-DD)
            model: Optional model signature filter

        Returns:
            List of position records with holdings
        """
        conn = None
        try:
            conn = get_db_connection(app.state.db_path)
            cursor = conn.cursor()

            # Build query with optional filters; parameters are bound with
            # placeholders, never interpolated.
            query = """
                SELECT
                    p.id,
                    p.job_id,
                    p.date,
                    p.model,
                    p.action_id,
                    p.action_type,
                    p.symbol,
                    p.amount,
                    p.price,
                    p.cash,
                    p.portfolio_value,
                    p.daily_profit,
                    p.daily_return_pct,
                    p.created_at
                FROM positions p
                WHERE 1=1
            """
            params = []

            if job_id:
                query += " AND p.job_id = ?"
                params.append(job_id)

            if date:
                query += " AND p.date = ?"
                params.append(date)

            if model:
                query += " AND p.model = ?"
                params.append(model)

            query += " ORDER BY p.date, p.model, p.action_id"

            cursor.execute(query, params)
            rows = cursor.fetchall()

            results = []
            for row in rows:
                position_id = row[0]

                # Fetch holdings for this position (one extra query per row).
                cursor.execute("""
                    SELECT symbol, quantity
                    FROM holdings
                    WHERE position_id = ?
                    ORDER BY symbol
                """, (position_id,))

                holdings = [{"symbol": h[0], "quantity": h[1]} for h in cursor.fetchall()]

                results.append({
                    "id": row[0],
                    "job_id": row[1],
                    "date": row[2],
                    "model": row[3],
                    "action_id": row[4],
                    "action_type": row[5],
                    "symbol": row[6],
                    "amount": row[7],
                    "price": row[8],
                    "cash": row[9],
                    "portfolio_value": row[10],
                    "daily_profit": row[11],
                    "daily_return_pct": row[12],
                    "created_at": row[13],
                    "holdings": holdings
                })

            return {"results": results, "count": len(results)}

        except Exception as e:
            logger.error(f"Failed to query results: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
        finally:
            # BUGFIX: previously the connection was only closed on the success
            # path and leaked whenever a query raised.
            if conn is not None:
                conn.close()

    @app.get("/health", response_model=HealthResponse)
    async def health_check():
        """
        Health check endpoint.

        Verifies database connectivity and service status.

        Returns:
            Health status and timestamp
        """
        try:
            conn = get_db_connection(app.state.db_path)
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT 1")
                cursor.fetchone()
            finally:
                # BUGFIX: close even if the probe query raises.
                conn.close()

            database_status = "connected"

        except Exception as e:
            logger.error(f"Database health check failed: {e}")
            database_status = "disconnected"

        # NOTE(review): datetime.utcnow() is deprecated in Python 3.12+; kept
        # here to preserve the exact "...Z" timestamp format callers expect.
        return HealthResponse(
            status="healthy" if database_status == "connected" else "unhealthy",
            database=database_status,
            timestamp=datetime.utcnow().isoformat() + "Z"
        )

    return app
|
||||
|
||||
|
||||
# Module-level default application instance (used by ASGI servers that
# import `app` directly).
app = create_app()


if __name__ == "__main__":
    # Allow running the API directly as a script.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8080)
|
||||
342
api/model_day_executor.py
Normal file
342
api/model_day_executor.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
Single model-day execution engine.
|
||||
|
||||
This module provides:
|
||||
- Isolated execution of one model for one trading day
|
||||
- Runtime config management per execution
|
||||
- Result persistence to SQLite (positions, holdings, reasoning)
|
||||
- Automatic status updates via JobManager
|
||||
- Cleanup of temporary resources
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Any, Optional, List, TYPE_CHECKING
|
||||
from pathlib import Path
|
||||
|
||||
from api.runtime_manager import RuntimeConfigManager
|
||||
from api.job_manager import JobManager
|
||||
from api.database import get_db_connection
|
||||
|
||||
# Lazy import to avoid loading heavy dependencies during testing
|
||||
if TYPE_CHECKING:
|
||||
from agent.base_agent.base_agent import BaseAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelDayExecutor:
    """
    Executes a single model for a single trading day.

    Responsibilities:
    - Create isolated runtime config
    - Initialize and run trading agent
    - Persist results to SQLite
    - Update job status
    - Cleanup resources

    Lifecycle:
        1. __init__() → Create runtime config
        2. execute() → Run agent, write results, update status
        3. cleanup → Delete runtime config (always, via finally)
    """

    def __init__(
        self,
        job_id: str,
        date: str,
        model_sig: str,
        config_path: str,
        db_path: str = "data/jobs.db",
        data_dir: str = "data"
    ):
        """
        Initialize ModelDayExecutor.

        Args:
            job_id: Job UUID
            date: Trading date (YYYY-MM-DD)
            model_sig: Model signature
            config_path: Path to configuration file
            db_path: Path to SQLite database
            data_dir: Data directory for runtime configs
        """
        self.job_id = job_id
        self.date = date
        self.model_sig = model_sig
        self.config_path = config_path
        self.db_path = db_path
        self.data_dir = data_dir

        # Isolated runtime config prevents concurrent model-days from
        # clobbering each other's TODAY_DATE/SIGNATURE/IF_TRADE state.
        self.runtime_manager = RuntimeConfigManager(data_dir=data_dir)
        self.runtime_config_path = self.runtime_manager.create_runtime_config(
            job_id=job_id,
            model_sig=model_sig,
            date=date
        )

        self.job_manager = JobManager(db_path=db_path)

        logger.info(f"Initialized executor for {model_sig} on {date} (job: {job_id})")

    def execute(self) -> Dict[str, Any]:
        """
        Execute trading session and persist results.

        Returns:
            Result dict with success status and metadata

        Process:
            1. Update job_detail status to 'running'
            2. Initialize and run trading agent
            3. Write results to SQLite
            4. Update job_detail status to 'completed' or 'failed'
            5. Cleanup runtime config (always)
        """
        try:
            # Mark this model-day as running.
            self.job_manager.update_job_detail_status(
                self.job_id,
                self.date,
                self.model_sig,
                "running"
            )

            # NOTE(review): this env var is process-wide — concurrent
            # executors in the same process would overwrite each other.
            # Confirm the worker runs model-days for one process sequentially.
            os.environ["RUNTIME_ENV_PATH"] = self.runtime_config_path

            agent = self._initialize_agent()

            logger.info(f"Running trading session for {self.model_sig} on {self.date}")
            session_result = agent.run_trading_session(self.date)

            self._write_results_to_db(agent, session_result)

            self.job_manager.update_job_detail_status(
                self.job_id,
                self.date,
                self.model_sig,
                "completed"
            )

            logger.info(f"Successfully completed {self.model_sig} on {self.date}")

            return {
                "success": True,
                "job_id": self.job_id,
                "date": self.date,
                "model": self.model_sig,
                "session_result": session_result
            }

        except Exception as e:
            error_msg = f"Execution failed: {str(e)}"
            logger.error(f"{self.model_sig} on {self.date}: {error_msg}", exc_info=True)

            # Record the failure so the job can finish as 'partial'.
            self.job_manager.update_job_detail_status(
                self.job_id,
                self.date,
                self.model_sig,
                "failed",
                error=error_msg
            )

            return {
                "success": False,
                "job_id": self.job_id,
                "date": self.date,
                "model": self.model_sig,
                "error": error_msg
            }

        finally:
            # Always cleanup the temporary runtime config.
            self.runtime_manager.cleanup_runtime_config(self.runtime_config_path)

    def _initialize_agent(self):
        """
        Initialize trading agent with config.

        Returns:
            Configured BaseAgent instance

        Raises:
            ValueError: If this executor's model signature is absent from config.
        """
        # Lazy import to avoid loading heavy dependencies during testing
        from agent.base_agent.base_agent import BaseAgent

        import json
        with open(self.config_path, 'r') as f:
            config = json.load(f)

        # Locate the config entry matching this executor's model signature.
        model_config = None
        for model in config.get("models", []):
            if model.get("signature") == self.model_sig:
                model_config = model
                break

        if not model_config:
            raise ValueError(f"Model {self.model_sig} not found in config")

        agent = BaseAgent(
            model_name=model_config.get("basemodel"),
            signature=self.model_sig,
            config=config
        )

        # Register agent (creates initial position if needed)
        agent.register_agent()

        return agent

    def _write_results_to_db(self, agent, session_result: Dict[str, Any]) -> None:
        """
        Write execution results to SQLite.

        Args:
            agent: Trading agent instance
            session_result: Result from run_trading_session()

        Writes to:
            - positions: Position record with action and P&L
            - holdings: Current portfolio holdings
            - reasoning_logs: AI reasoning steps (if available)
            - tool_usage: Tool usage stats (if available)
        """
        conn = get_db_connection(self.db_path)
        cursor = conn.cursor()

        try:
            # Optional agent accessors are probed defensively with hasattr.
            positions = agent.get_positions() if hasattr(agent, 'get_positions') else {}
            last_trade = agent.get_last_trade() if hasattr(agent, 'get_last_trade') else None

            current_prices = agent.get_current_prices() if hasattr(agent, 'get_current_prices') else {}
            total_value = self._calculate_portfolio_value(positions, current_prices)

            # Previous day's value for P&L; falls back to the initial stake.
            cursor.execute("""
                SELECT portfolio_value
                FROM positions
                WHERE job_id = ? AND model = ? AND date < ?
                ORDER BY date DESC
                LIMIT 1
            """, (self.job_id, self.model_sig, self.date))

            row = cursor.fetchone()
            previous_value = row[0] if row else 10000.0  # Initial portfolio value

            daily_profit = total_value - previous_value
            daily_return_pct = (daily_profit / previous_value * 100) if previous_value > 0 else 0

            # Next action_id = max existing + 1 for this (job, model).
            cursor.execute("""
                SELECT COALESCE(MAX(action_id), 0) + 1
                FROM positions
                WHERE job_id = ? AND model = ?
            """, (self.job_id, self.model_sig))

            action_id = cursor.fetchone()[0]

            action_type = last_trade.get("action") if last_trade else "no_trade"
            symbol = last_trade.get("symbol") if last_trade else None
            amount = last_trade.get("amount") if last_trade else None
            price = last_trade.get("price") if last_trade else None
            cash = positions.get("CASH", 0.0)

            from datetime import datetime
            created_at = datetime.utcnow().isoformat() + "Z"

            cursor.execute("""
                INSERT INTO positions (
                    job_id, date, model, action_id, action_type, symbol,
                    amount, price, cash, portfolio_value, daily_profit, daily_return_pct, created_at
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                self.job_id, self.date, self.model_sig, action_id, action_type,
                symbol, amount, price, cash, total_value,
                daily_profit, daily_return_pct, created_at
            ))

            position_id = cursor.lastrowid

            # Insert holdings. (Renamed loop variable so it no longer shadows
            # the trade `symbol` above.)
            for held_symbol, quantity in positions.items():
                cursor.execute("""
                    INSERT INTO holdings (position_id, symbol, quantity)
                    VALUES (?, ?, ?)
                """, (position_id, held_symbol, float(quantity)))

            # Insert reasoning logs (if available)
            if hasattr(agent, 'get_reasoning_steps'):
                reasoning_steps = agent.get_reasoning_steps()
                for step in reasoning_steps:
                    cursor.execute("""
                        INSERT INTO reasoning_logs (
                            job_id, date, model, step_number, timestamp, content
                        )
                        VALUES (?, ?, ?, ?, ?, ?)
                    """, (
                        self.job_id, self.date, self.model_sig,
                        step.get("step"), created_at, step.get("reasoning")
                    ))

            # Insert tool usage (if available).
            # BUGFIX: the condition previously tested hasattr(...) twice for
            # the same attribute; a single check is sufficient.
            if hasattr(agent, 'get_tool_usage'):
                tool_usage = agent.get_tool_usage()
                for tool_name, count in tool_usage.items():
                    cursor.execute("""
                        INSERT INTO tool_usage (
                            job_id, date, model, tool_name, call_count
                        )
                        VALUES (?, ?, ?, ?, ?)
                    """, (self.job_id, self.date, self.model_sig, tool_name, count))

            conn.commit()
            logger.debug(f"Wrote results to DB for {self.model_sig} on {self.date}")

        finally:
            conn.close()

    def _calculate_portfolio_value(
        self,
        positions: Dict[str, float],
        current_prices: Dict[str, float]
    ) -> float:
        """
        Calculate total portfolio value.

        Args:
            positions: Current holdings (symbol: quantity)
            current_prices: Current market prices (symbol: price)

        Returns:
            Total portfolio value in dollars. Symbols with no known price
            contribute 0 (price defaults to 0.0).
        """
        total = 0.0

        for symbol, quantity in positions.items():
            if symbol == "CASH":
                # CASH quantity is already a dollar amount.
                total += quantity
            else:
                price = current_prices.get(symbol, 0.0)
                total += quantity * price

        return total
|
||||
459
api/models.py
Normal file
459
api/models.py
Normal file
@@ -0,0 +1,459 @@
|
||||
"""
|
||||
Pydantic data models for AI-Trader API.
|
||||
|
||||
This module defines:
|
||||
- Request models (input validation)
|
||||
- Response models (output serialization)
|
||||
- Nested models for complex data structures
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List, Dict, Literal, Any
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
# ==================== Request Models ====================
|
||||
|
||||
class TriggerSimulationRequest(BaseModel):
    """Request model for POST /simulate/trigger endpoint."""

    # Server-side path to the simulation configuration file.
    config_path: str = Field(default="configs/default_config.json", description="Path to configuration file")

    class Config:
        json_schema_extra = {
            "example": {"config_path": "configs/default_config.json"}
        }
|
||||
|
||||
|
||||
class ResultsQueryParams(BaseModel):
    """Query parameters for GET /results endpoint."""

    # Required trading date; the regex enforces the YYYY-MM-DD shape.
    date: str = Field(..., pattern=r"^\d{4}-\d{2}-\d{2}$", description="Date in YYYY-MM-DD format")
    model: Optional[str] = Field(None, description="Model signature filter (optional)")
    detail: Literal["minimal", "full"] = Field(default="minimal", description="Response detail level")

    class Config:
        json_schema_extra = {
            "example": {"date": "2025-01-16", "model": "gpt-5", "detail": "minimal"}
        }
|
||||
|
||||
|
||||
# ==================== Nested Response Models ====================
|
||||
|
||||
class JobProgress(BaseModel):
    """Progress tracking for simulation jobs."""

    total_model_days: int = Field(..., description="Total number of model-days to execute")
    completed: int = Field(..., description="Number of model-days completed")
    failed: int = Field(..., description="Number of model-days that failed")
    # Present only while a model-day is actively executing.
    current: Optional[Dict[str, str]] = Field(None, description="Currently executing model-day (if any)")
    details: Optional[List[Dict]] = Field(None, description="Detailed progress for each model-day")

    class Config:
        json_schema_extra = {
            "example": {
                "total_model_days": 4,
                "completed": 2,
                "failed": 0,
                "current": {"date": "2025-01-16", "model": "gpt-5"},
                "details": [
                    {
                        "date": "2025-01-16",
                        "model": "gpt-5",
                        "status": "completed",
                        "duration_seconds": 45.2
                    }
                ]
            }
        }
|
||||
|
||||
|
||||
class DailyPnL(BaseModel):
    """Daily profit and loss metrics."""

    profit: float = Field(..., description="Daily profit in dollars")
    return_pct: float = Field(..., description="Daily return percentage")
    portfolio_value: float = Field(..., description="Total portfolio value")

    class Config:
        json_schema_extra = {
            "example": {"profit": 150.50, "return_pct": 1.51, "portfolio_value": 10150.50}
        }
|
||||
|
||||
|
||||
class Trade(BaseModel):
    """Individual trade record."""

    id: int = Field(..., description="Trade sequence ID")
    action: str = Field(..., description="Trade action (buy/sell)")
    symbol: str = Field(..., description="Stock symbol")
    amount: int = Field(..., description="Number of shares")
    price: Optional[float] = Field(None, description="Trade price per share")
    total: Optional[float] = Field(None, description="Total trade value")

    class Config:
        json_schema_extra = {
            "example": {
                "id": 1,
                "action": "buy",
                "symbol": "AAPL",
                "amount": 10,
                "price": 255.88,
                "total": 2558.80
            }
        }
|
||||
|
||||
|
||||
class AIReasoning(BaseModel):
    """AI reasoning and decision-making summary."""

    total_steps: int = Field(..., description="Total reasoning steps taken")
    stop_signal_received: bool = Field(..., description="Whether AI sent stop signal")
    reasoning_summary: str = Field(..., description="Summary of AI reasoning")
    tool_usage: Dict[str, int] = Field(..., description="Tool usage counts")

    class Config:
        json_schema_extra = {
            "example": {
                "total_steps": 15,
                "stop_signal_received": True,
                "reasoning_summary": "Market analysis indicates...",
                "tool_usage": {"search": 3, "get_price": 5, "math": 2, "trade": 1}
            }
        }
|
||||
|
||||
|
||||
class ModelResult(BaseModel):
    """Simulation results for a single model on a single date."""

    model: str = Field(..., description="Model signature")
    positions: Dict[str, float] = Field(..., description="Current positions (symbol: quantity)")
    daily_pnl: DailyPnL = Field(..., description="Daily P&L metrics")
    # The three fields below are populated only for detail=full requests.
    trades: Optional[List[Trade]] = Field(None, description="Trades executed (detail=full only)")
    ai_reasoning: Optional[AIReasoning] = Field(None, description="AI reasoning summary (detail=full only)")
    log_file_path: Optional[str] = Field(None, description="Path to detailed log file (detail=full only)")

    class Config:
        json_schema_extra = {
            "example": {
                "model": "gpt-5",
                "positions": {"AAPL": 10, "MSFT": 5, "CASH": 7500.0},
                "daily_pnl": {
                    "profit": 150.50,
                    "return_pct": 1.51,
                    "portfolio_value": 10150.50
                }
            }
        }
|
||||
|
||||
|
||||
# ==================== Response Models ====================
|
||||
|
||||
class TriggerSimulationResponse(BaseModel):
    """Response model for POST /simulate/trigger endpoint."""

    job_id: str = Field(..., description="Unique job identifier")
    status: str = Field(..., description="Job status (accepted/running/current)")
    date_range: List[str] = Field(..., description="Dates to be simulated")
    models: List[str] = Field(..., description="Models to execute")
    created_at: str = Field(..., description="Job creation timestamp (ISO 8601)")
    message: str = Field(..., description="Human-readable status message")
    progress: Optional[JobProgress] = Field(None, description="Progress (if job already running)")

    class Config:
        json_schema_extra = {
            "example": {
                "job_id": "550e8400-e29b-41d4-a716-446655440000",
                "status": "accepted",
                "date_range": ["2025-01-16", "2025-01-17"],
                "models": ["gpt-5", "claude-3.7-sonnet"],
                "created_at": "2025-01-20T14:30:00Z",
                "message": "Simulation job queued successfully"
            }
        }
|
||||
|
||||
|
||||
class JobStatusResponse(BaseModel):
    """Response model for GET /simulate/status/{job_id} endpoint."""

    job_id: str = Field(..., description="Job identifier")
    status: str = Field(..., description="Job status (pending/running/completed/partial/failed)")
    date_range: List[str] = Field(..., description="Dates being simulated")
    models: List[str] = Field(..., description="Models being executed")
    progress: JobProgress = Field(..., description="Execution progress")
    created_at: str = Field(..., description="Job creation timestamp")
    updated_at: Optional[str] = Field(None, description="Last update timestamp")
    completed_at: Optional[str] = Field(None, description="Job completion timestamp")
    total_duration_seconds: Optional[float] = Field(None, description="Total execution duration")

    class Config:
        json_schema_extra = {
            "example": {
                "job_id": "550e8400-e29b-41d4-a716-446655440000",
                "status": "running",
                "date_range": ["2025-01-16", "2025-01-17"],
                "models": ["gpt-5"],
                "progress": {
                    "total_model_days": 2,
                    "completed": 1,
                    "failed": 0,
                    "current": {"date": "2025-01-17", "model": "gpt-5"}
                },
                "created_at": "2025-01-20T14:30:00Z"
            }
        }
|
||||
|
||||
|
||||
class ResultsResponse(BaseModel):
    """Response model for GET /results endpoint."""

    date: str = Field(..., description="Trading date")
    results: List[ModelResult] = Field(..., description="Results for each model")

    class Config:
        json_schema_extra = {
            "example": {
                "date": "2025-01-16",
                "results": [
                    {
                        "model": "gpt-5",
                        "positions": {"AAPL": 10, "CASH": 7500.0},
                        "daily_pnl": {
                            "profit": 150.50,
                            "return_pct": 1.51,
                            "portfolio_value": 10150.50
                        }
                    }
                ]
            }
        }
|
||||
|
||||
|
||||
class HealthCheckResponse(BaseModel):
    """Response model for GET /health endpoint."""

    status: str = Field(..., description="Overall health status (healthy/unhealthy)")
    timestamp: str = Field(..., description="Health check timestamp")
    services: Dict[str, Dict] = Field(..., description="Status of each service")
    storage: Dict[str, Any] = Field(..., description="Storage status")
    database: Dict[str, Any] = Field(..., description="Database status")

    class Config:
        json_schema_extra = {
            "example": {
                "status": "healthy",
                "timestamp": "2025-01-20T14:30:00Z",
                "services": {
                    "mcp_math": {"status": "up", "url": "http://localhost:8000/mcp"},
                    "mcp_search": {"status": "up", "url": "http://localhost:8001/mcp"}
                },
                "storage": {
                    "data_directory": "/app/data",
                    "writable": True,
                    "free_space_mb": 15234
                },
                "database": {
                    "status": "connected",
                    "path": "/app/data/jobs.db"
                }
            }
        }
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
    """Standard error response model."""

    error: str = Field(..., description="Error code/type")
    message: str = Field(..., description="Human-readable error message")
    details: Optional[Dict] = Field(None, description="Additional error details")

    class Config:
        json_schema_extra = {
            "example": {
                "error": "invalid_date",
                "message": "Date must be in YYYY-MM-DD format",
                "details": {"provided": "2025/01/16"}
            }
        }
|
||||
131
api/runtime_manager.py
Normal file
131
api/runtime_manager.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
Runtime configuration manager for isolated model-day execution.
|
||||
|
||||
This module provides:
|
||||
- Isolated runtime config file creation per model-day
|
||||
- Prevention of state collisions between concurrent executions
|
||||
- Automatic cleanup of temporary config files
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RuntimeConfigManager:
    """
    Manages isolated runtime configuration files for concurrent model execution.

    Multiple models running concurrently would race on a single shared
    runtime_env.json (TODAY_DATE, SIGNATURE, IF_TRADE), so each model-day
    execution gets its own temporary config file:

        {data_dir}/runtime_env_{job_id[:8]}_{model}_{date}.json

    Lifecycle:
        1. create_runtime_config()  -> creates the temp file
        2. Executor sets RUNTIME_ENV_PATH to point at it
        3. Agent uses the isolated config via get_config_value/write_config_value
        4. cleanup_runtime_config() -> deletes the temp file
    """

    def __init__(self, data_dir: str = "data"):
        """
        Initialize RuntimeConfigManager.

        Args:
            data_dir: Directory for runtime config files (default: "data").
                      Created (with parents) if it does not already exist.
        """
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)

    def create_runtime_config(
        self,
        job_id: str,
        model_sig: str,
        date: str
    ) -> str:
        """
        Create an isolated runtime config file for one model-day execution.

        Args:
            job_id: Job UUID.
            model_sig: Model signature.
                NOTE(review): assumed to contain no path separators —
                confirm upstream validation.
            date: Trading date (YYYY-MM-DD).

        Returns:
            Path (as str) to the created runtime config file, e.g.
            "data/runtime_env_abc12345_gpt-5_2025-01-16.json".
        """
        # First 8 chars of job_id keep the filename short; slicing is safe
        # even when job_id is shorter than 8 chars, so no length check needed.
        filename = f"runtime_env_{job_id[:8]}_{model_sig}_{date}.json"
        config_path = self.data_dir / filename

        # Seed the file with the defaults for this model-day.
        initial_config = {
            "TODAY_DATE": date,
            "SIGNATURE": model_sig,
            "IF_TRADE": False,
            "JOB_ID": job_id,
        }

        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(initial_config, f, indent=4)

        logger.debug(f"Created runtime config: {config_path}")
        return str(config_path)

    def cleanup_runtime_config(self, config_path: str) -> None:
        """
        Delete a runtime config file after execution.

        Args:
            config_path: Path to the runtime config file.

        Note:
            Silently ignores a missing file (already cleaned up).
            Uses unlink(missing_ok=True) instead of an exists()+unlink()
            pair to avoid the TOCTOU race between concurrent cleanups.
        """
        try:
            Path(config_path).unlink(missing_ok=True)
            logger.debug(f"Cleaned up runtime config: {config_path}")
        except OSError as e:
            # Best-effort cleanup: log and continue rather than fail the job.
            logger.warning(f"Failed to cleanup runtime config {config_path}: {e}")

    def cleanup_all_runtime_configs(self) -> int:
        """
        Delete all runtime config files in data_dir (maintenance/startup).

        Returns:
            Number of files deleted.

        Use case:
            - On API startup, to clear stale configs from previous runs
            - Periodic maintenance
        """
        count = 0
        for config_file in self.data_dir.glob("runtime_env_*.json"):
            try:
                config_file.unlink()
                count += 1
                logger.debug(f"Deleted stale runtime config: {config_file}")
            except OSError as e:
                logger.warning(f"Failed to delete {config_file}: {e}")

        if count > 0:
            logger.info(f"Cleaned up {count} stale runtime config files")

        return count
|
||||
210
api/simulation_worker.py
Normal file
210
api/simulation_worker.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""
|
||||
Simulation job orchestration worker.
|
||||
|
||||
This module provides:
|
||||
- Job execution orchestration
|
||||
- Date-sequential, model-parallel execution
|
||||
- Progress tracking and status updates
|
||||
- Error handling and recovery
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from api.job_manager import JobManager
|
||||
from api.model_day_executor import ModelDayExecutor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SimulationWorker:
    """
    Orchestrates execution of a single simulation job.

    Execution model:
        Dates are processed strictly one at a time (sequential), while every
        model for the current date runs concurrently in a thread pool. The
        worker only advances to the next date once all models for the current
        date have finished.

    Status outcomes:
        completed — no model-day failed
        partial   — some model-days failed, at least one succeeded
        failed    — nothing succeeded, or a job-level error occurred
    """

    def __init__(self, job_id: str, db_path: str = "data/jobs.db", max_workers: int = 4):
        """
        Initialize SimulationWorker.

        Args:
            job_id: Job UUID to execute.
            db_path: Path to the SQLite database.
            max_workers: Upper bound on concurrent model executions per date.
        """
        self.job_id = job_id
        self.db_path = db_path
        self.max_workers = max_workers
        self.job_manager = JobManager(db_path=db_path)

        logger.info(f"Initialized worker for job {job_id}")

    def run(self) -> Dict[str, Any]:
        """
        Execute the simulation job from start to finish.

        Returns:
            On success: dict with job_id, final status, and progress counters.
            On job-level failure: dict with success=False and the error text.

        Error handling:
            Individual model-day failures are recorded by the executor and do
            not stop the job; only job-level errors (e.g. missing job record)
            mark the whole job as failed.
        """
        try:
            job = self.job_manager.get_job(self.job_id)
            if not job:
                raise ValueError(f"Job {self.job_id} not found")

            date_range = job["date_range"]
            models = job["models"]
            config_path = job["config_path"]

            logger.info(f"Starting job {self.job_id}: {len(date_range)} dates, {len(models)} models")

            # Sequential over dates; parallelism happens inside _execute_date.
            for date in date_range:
                logger.info(f"Processing date {date} with {len(models)} models")
                self._execute_date(date, models, config_path)

            progress = self.job_manager.get_job_progress(self.job_id)
            final_status = self._final_status(progress)

            # The per-detail status updates performed by the executor already
            # drive the job's stored status via JobManager's transition logic,
            # so no explicit update_job_status call is needed on success.

            logger.info(f"Job {self.job_id} finished with status: {final_status}")

            return {
                "success": True,
                "job_id": self.job_id,
                "status": final_status,
                "total_model_days": progress["total_model_days"],
                "completed": progress["completed"],
                "failed": progress["failed"],
            }

        except Exception as e:
            error_msg = f"Job execution failed: {str(e)}"
            logger.error(f"Job {self.job_id}: {error_msg}", exc_info=True)

            # Record the job-level failure before reporting it to the caller.
            self.job_manager.update_job_status(self.job_id, "failed", error=error_msg)

            return {"success": False, "job_id": self.job_id, "error": error_msg}

    @staticmethod
    def _final_status(progress: Dict[str, Any]) -> str:
        """Map progress counters to the job's final status string."""
        if progress["failed"] == 0:
            return "completed"
        return "partial" if progress["completed"] > 0 else "failed"

    def _execute_date(self, date: str, models: List[str], config_path: str) -> None:
        """
        Run every model for one date concurrently and wait for all to finish.

        Args:
            date: Trading date (YYYY-MM-DD).
            models: Model signatures to execute for this date.
            config_path: Path to the configuration file.
        """
        with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
            pending = [
                pool.submit(self._execute_model_day, date, name, config_path)
                for name in models
            ]

            for fut in as_completed(pending):
                try:
                    result = fut.result()
                except Exception as e:
                    logger.error(f"Exception in model execution: {e}", exc_info=True)
                    continue
                if result["success"]:
                    logger.debug(f"Completed {result['model']} on {result['date']}")
                else:
                    logger.warning(f"Failed {result['model']} on {result['date']}: {result.get('error')}")

    def _execute_model_day(self, date: str, model: str, config_path: str) -> Dict[str, Any]:
        """
        Run one model for one date, converting any exception into a
        failure-result dict so a single crash cannot take down the pool.

        Args:
            date: Trading date (YYYY-MM-DD).
            model: Model signature.
            config_path: Path to the configuration file.

        Returns:
            Execution result dict from ModelDayExecutor, or a synthetic
            {"success": False, ...} dict when the executor itself raised.
        """
        try:
            return ModelDayExecutor(
                job_id=self.job_id,
                date=date,
                model_sig=model,
                config_path=config_path,
                db_path=self.db_path,
            ).execute()
        except Exception as e:
            logger.error(f"Failed to execute {model} on {date}: {e}", exc_info=True)
            return {
                "success": False,
                "job_id": self.job_id,
                "date": date,
                "model": model,
                "error": str(e),
            }

    def get_job_info(self) -> Dict[str, Any]:
        """
        Fetch the current job record from the database.

        Returns:
            Job data dict as stored by JobManager.
        """
        return self.job_manager.get_job(self.job_id)
|
||||
Reference in New Issue
Block a user