mirror of https://github.com/Xe138/AI-Trader.git
synced 2026-04-02 01:27:24 -04:00

Compare commits: 28 commits (v0.3.0-alp)
@@ -269,12 +269,14 @@ Get status and progress of a simulation job.

| `total_duration_seconds` | float | Total execution time in seconds |
| `error` | string | Error message if job failed |
| `details` | array[object] | Per-model-day execution details |
| `warnings` | array[string] | Optional array of non-fatal warning messages |

**Job Status Values:**

| Status | Description |
|--------|-------------|
| `pending` | Job created, waiting to start |
| `downloading_data` | Preparing price data (downloading if needed) |
| `running` | Job currently executing |
| `completed` | All model-days completed successfully |
| `partial` | Some model-days completed, some failed |
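
Because `pending`, `downloading_data`, and `running` are all transient, a client should keep polling until one of the three terminal statuses appears. A minimal sketch (the `wait_for_job` helper and `base_url` are illustrative, not part of the API; the `/simulate/status/{job_id}` path and port 8080 follow the quickstart below):

```python
import json
import time
import urllib.request

TERMINAL_STATUSES = {"completed", "partial", "failed"}

def is_terminal(status: str) -> bool:
    """pending / downloading_data / running are transient; everything else is final."""
    return status in TERMINAL_STATUSES

def wait_for_job(base_url: str, job_id: str, poll_seconds: float = 5.0) -> dict:
    """Poll the status endpoint until the job reaches a terminal status."""
    while True:
        with urllib.request.urlopen(f"{base_url}/simulate/status/{job_id}") as resp:
            job = json.load(resp)
        if is_terminal(job["status"]):
            return job
        time.sleep(poll_seconds)
```

Usage would be `wait_for_job("http://localhost:8080", job_id)`.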
@@ -289,6 +291,35 @@ Get status and progress of a simulation job.

| `completed` | Finished successfully |
| `failed` | Execution failed (see `error` field) |

**Warnings Field:**

The optional `warnings` array contains non-fatal warning messages about the job execution:

- **Rate limit warnings**: Price data download hit API rate limits
- **Skipped dates**: Some dates couldn't be processed due to incomplete price data
- **Other issues**: Non-fatal problems that don't prevent job completion

**Example response with warnings:**

```json
{
  "job_id": "019a426b-1234-5678-90ab-cdef12345678",
  "status": "completed",
  "progress": {
    "total_model_days": 10,
    "completed": 8,
    "failed": 0,
    "pending": 0
  },
  "warnings": [
    "Rate limit reached - downloaded 12/15 symbols",
    "Skipped 2 dates due to incomplete price data: ['2025-10-02', '2025-10-05']"
  ]
}
```

If no warnings occurred, the field will be `null` or omitted.
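
Since the field may be `null`, omitted, or (at the database layer) a JSON string, clients should normalize it before use. A small defensive sketch (the `extract_warnings` helper is illustrative, not part of the API):

```python
import json

def extract_warnings(job: dict) -> list:
    """Return the job's warnings, treating a missing or null field as 'no warnings'."""
    raw = job.get("warnings")
    if raw is None:
        return []
    if isinstance(raw, str):  # some layers hand back the raw JSON text column
        raw = json.loads(raw)
    return list(raw)
```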

**Error Response:**

**404 Not Found** - Job doesn't exist
@@ -729,6 +760,29 @@ Server loads model definitions from configuration file (default: `configs/defaul

- `openai_base_url` - Optional custom API endpoint
- `openai_api_key` - Optional model-specific API key

### Configuration Override System

**Default config:** `/app/configs/default_config.json` (baked into image)

**Custom config:** `/app/user-configs/config.json` (optional, via volume mount)

**Merge behavior:**
- Custom config sections completely replace default sections (root-level merge)
- If no custom config exists, defaults are used
- Validation occurs at container startup (before API starts)
- Invalid config causes immediate exit with detailed error message

**Example custom config** (overrides models only):
```json
{
  "models": [
    {"name": "gpt-5", "basemodel": "openai/gpt-5", "signature": "gpt-5", "enabled": true}
  ]
}
```

All other sections (`agent_config`, `log_config`, etc.) inherited from default.
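
A root-level merge of this kind amounts to a shallow dictionary update: each top-level section in the custom config replaces the default section wholesale, with no deep merging inside sections. A minimal sketch (the `merge_config` name is illustrative, not the container's actual code):

```python
def merge_config(default: dict, custom: dict) -> dict:
    """Root-level merge: each top-level section in `custom` wholly replaces the default one."""
    merged = dict(default)
    merged.update(custom)
    return merged

# Overriding only "models" leaves "agent_config" inherited from the default.
default = {"models": ["default-model"], "agent_config": {"max_step": 30}}
custom = {"models": ["gpt-5"]}
merged = merge_config(default, custom)
```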

---

## OpenAPI / Swagger Documentation
DOCKER.md (new file, 371 lines)
@@ -0,0 +1,371 @@

# Docker Deployment Guide

## Quick Start

### Prerequisites
- Docker Engine 20.10+
- Docker Compose 2.0+
- API keys for OpenAI, Alpha Vantage, and Jina AI

### First-Time Setup

1. **Clone repository:**
```bash
git clone https://github.com/Xe138/AI-Trader-Server.git
cd AI-Trader-Server
```

2. **Configure environment:**
```bash
cp .env.example .env
# Edit .env and add your API keys
```

3. **Run with Docker Compose:**
```bash
docker-compose up
```

That's it! The container will:
- Fetch latest price data from Alpha Vantage
- Start all MCP services
- Run the trading agent with default configuration

## Configuration

### Environment Variables

Edit the `.env` file with your credentials:

```bash
# Required
OPENAI_API_KEY=sk-...
ALPHAADVANTAGE_API_KEY=...
JINA_API_KEY=...

# Optional (defaults shown)
MATH_HTTP_PORT=8000
SEARCH_HTTP_PORT=8001
TRADE_HTTP_PORT=8002
GETPRICE_HTTP_PORT=8003
AGENT_MAX_STEP=30
```

### Custom Trading Configuration

**Simple Method (Recommended):**

Create a `configs/custom_config.json` file - it will be automatically used:

```bash
# Copy default config as starting point
cp configs/default_config.json configs/custom_config.json

# Edit your custom config
nano configs/custom_config.json

# Run normally - custom_config.json is automatically detected!
docker-compose up
```

**Priority order:**
1. `configs/custom_config.json` (if exists) - **Highest priority**
2. Command-line argument: `docker-compose run ai-trader-server configs/other.json`
3. `configs/default_config.json` (fallback)
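
The priority order above can be sketched as a small resolver. This is an illustrative reconstruction of the documented behavior, not the entrypoint's actual code; the `resolve_config_path` helper and its parameters are assumptions:

```python
import os

def resolve_config_path(argv: list,
                        custom_path: str = "configs/custom_config.json",
                        default_path: str = "configs/default_config.json") -> str:
    """Mirror the documented priority: custom config, then CLI argument, then default."""
    if os.path.exists(custom_path):   # 1. custom_config.json wins if present
        return custom_path
    if len(argv) > 1:                 # 2. explicit command-line argument
        return argv[1]
    return default_path               # 3. fall back to the baked-in default
```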

**Advanced: Use a different config file name:**

```bash
docker-compose run ai-trader-server configs/my_special_config.json
```

### Custom Configuration via Volume Mount

The Docker image includes a default configuration at `configs/default_config.json`. You can override sections of this config by mounting a custom config file.

**Volume mount:**
```yaml
volumes:
  - ./my-configs:/app/user-configs  # Contains config.json
```

**Custom config example** (`./my-configs/config.json`):
```json
{
  "models": [
    {
      "name": "gpt-5",
      "basemodel": "openai/gpt-5",
      "signature": "gpt-5",
      "enabled": true
    }
  ]
}
```

This overrides only the `models` section. All other settings (`agent_config`, `log_config`, etc.) are inherited from the default config.

**Validation:** Config is validated at container startup. Invalid configs cause immediate exit with detailed error messages.

**Complete config:** You can also provide a complete config that replaces all default values:
```json
{
  "agent_type": "BaseAgent",
  "date_range": {
    "init_date": "2025-10-01",
    "end_date": "2025-10-31"
  },
  "models": [...],
  "agent_config": {...},
  "log_config": {...}
}
```
## Usage Examples

### Run in foreground with logs
```bash
docker-compose up
```

### Run in background (detached)
```bash
docker-compose up -d
docker-compose logs -f  # Follow logs
```

### Run with custom config
```bash
docker-compose run ai-trader-server configs/custom_config.json
```

### Stop containers
```bash
docker-compose down
```

### Rebuild after code changes
```bash
docker-compose build
docker-compose up
```
## Data Persistence

### Volume Mounts

Docker Compose mounts three volumes for persistent data. By default, these are stored in the project directory:

- `./data:/app/data` - Price data and trading records
- `./logs:/app/logs` - MCP service logs
- `./configs:/app/configs` - Configuration files (allows editing configs without rebuilding)

### Custom Volume Location

You can change where data is stored by setting `VOLUME_PATH` in your `.env` file:

```bash
# Store data in a different location
VOLUME_PATH=/home/user/trading-data

# Or use a relative path
VOLUME_PATH=./volumes
```

This will store data in:
- `/home/user/trading-data/data/`
- `/home/user/trading-data/logs/`
- `/home/user/trading-data/configs/`

**Note:** The directory structure is automatically created. You'll need to copy your existing configs:
```bash
# After changing VOLUME_PATH
mkdir -p /home/user/trading-data/configs
cp configs/custom_config.json /home/user/trading-data/configs/
```
### Reset Data

To reset all trading data:

```bash
docker-compose down
rm -rf ${VOLUME_PATH:-.}/data/agent_data/* ${VOLUME_PATH:-.}/logs/*
docker-compose up
```

### Backup Trading Data

```bash
# Backup
tar -czf ai-trader-server-backup-$(date +%Y%m%d).tar.gz data/agent_data/

# Restore
tar -xzf ai-trader-server-backup-YYYYMMDD.tar.gz
```
## Using Pre-built Images

### Pull from GitHub Container Registry

```bash
docker pull ghcr.io/xe138/ai-trader-server:latest
```

### Run without Docker Compose

```bash
docker run --env-file .env \
  -v $(pwd)/data:/app/data \
  -v $(pwd)/logs:/app/logs \
  -p 8000-8003:8000-8003 \
  ghcr.io/xe138/ai-trader-server:latest
```

### Specific version
```bash
docker pull ghcr.io/xe138/ai-trader-server:v1.0.0
```
## Troubleshooting

### MCP Services Not Starting

**Symptom:** Container exits immediately or errors about ports

**Solutions:**
- Check ports 8000-8003 are not already in use: `lsof -i :8000-8003`
- View container logs: `docker-compose logs`
- Check MCP service logs: `cat logs/math.log`

### Missing API Keys

**Symptom:** Errors about missing environment variables

**Solutions:**
- Verify `.env` file exists: `ls -la .env`
- Check required variables are set: `grep OPENAI_API_KEY .env`
- Ensure `.env` is in the same directory as docker-compose.yml

### Data Fetch Failures

**Symptom:** Container exits during data preparation step

**Solutions:**
- Verify Alpha Vantage API key is valid
- Check API rate limits (5 requests/minute for free tier)
- View logs: `docker-compose logs | grep "Fetching and merging"`
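
If you pre-fetch data with your own scripts, pacing requests below the free-tier limit avoids most of these failures. A minimal throttling sketch (the `paced` helper is illustrative, not part of this repository):

```python
import time

def paced(symbols, requests_per_minute: float = 5):
    """Yield symbols no faster than the API allows (Alpha Vantage free tier: 5 req/min)."""
    interval = 60.0 / requests_per_minute
    for i, symbol in enumerate(symbols):
        if i:  # no delay before the first request
            time.sleep(interval)
        yield symbol
```

Iterate with `for symbol in paced(["AAPL", "MSFT"]): fetch(symbol)` so each fetch waits out the interval.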

### Permission Issues

**Symptom:** Cannot write to data or logs directories

**Solutions:**
- Ensure directories writable: `chmod -R 755 data logs`
- Check volume mount permissions
- May need to create directories first: `mkdir -p data logs`

### Container Keeps Restarting

**Symptom:** Container restarts repeatedly

**Solutions:**
- View logs to identify error: `docker-compose logs --tail=50`
- Disable auto-restart: Comment out `restart: unless-stopped` in docker-compose.yml
- Check if main.py exits with error

## Advanced Usage

### Override Entrypoint

Run bash inside container for debugging:

```bash
docker-compose run --entrypoint /bin/bash ai-trader-server
```
### Build Multi-platform Images

For ARM64 (Apple Silicon) and AMD64:

```bash
docker buildx build --platform linux/amd64,linux/arm64 -t ai-trader-server .
```

### View Container Resource Usage

```bash
docker stats ai-trader-server
```

### Access MCP Services Directly

Services exposed on host:
- Math: http://localhost:8000
- Search: http://localhost:8001
- Trade: http://localhost:8002
- Price: http://localhost:8003
## Development Workflow

### Local Code Changes

1. Edit code in project root
2. Rebuild image: `docker-compose build`
3. Run updated container: `docker-compose up`

### Test Different Configurations

**Method 1: Use the standard custom_config.json**

```bash
# Create and edit your config
cp configs/default_config.json configs/custom_config.json
nano configs/custom_config.json

# Run - automatically uses custom_config.json
docker-compose up
```

**Method 2: Test multiple configs with different names**

```bash
# Create multiple test configs
cp configs/default_config.json configs/conservative.json
cp configs/default_config.json configs/aggressive.json

# Edit each config...

# Test conservative strategy
docker-compose run ai-trader-server configs/conservative.json

# Test aggressive strategy
docker-compose run ai-trader-server configs/aggressive.json
```

**Method 3: Temporarily switch configs**

```bash
# Temporarily rename your custom config
mv configs/custom_config.json configs/custom_config.json.backup
cp configs/test_strategy.json configs/custom_config.json

# Run with test strategy
docker-compose up

# Restore original
mv configs/custom_config.json.backup configs/custom_config.json
```
## Production Deployment

For production use, consider:

1. **Use specific version tags** instead of `latest`
2. **External secrets management** (AWS Secrets Manager, etc.)
3. **Health checks** in docker-compose.yml
4. **Resource limits** (CPU/memory)
5. **Log aggregation** (ELK stack, CloudWatch)
6. **Orchestration** (Kubernetes, Docker Swarm)

See design document in `docs/plans/2025-10-30-docker-deployment-design.md` for architecture details.
@@ -54,7 +54,36 @@ JINA_API_KEY=your-jina-key-here

---

## Step 3: Start the API Server
## Step 3: (Optional) Custom Model Configuration

To use different AI models than the defaults, create a custom config:

1. Create config directory:
```bash
mkdir -p configs
```

2. Create `configs/config.json`:
```json
{
  "models": [
    {
      "name": "my-gpt-4",
      "basemodel": "openai/gpt-4",
      "signature": "my-gpt-4",
      "enabled": true
    }
  ]
}
```

3. The Docker container will automatically merge this with default settings.

Your custom config only needs to include sections you want to override.

---

## Step 4: Start the API Server

```bash
docker-compose up -d

@@ -79,7 +108,7 @@ docker logs -f ai-trader-server

---

## Step 4: Verify Service is Running
## Step 5: Verify Service is Running

```bash
curl http://localhost:8080/health

@@ -99,7 +128,7 @@ If you see `"status": "healthy"`, you're ready!

---

## Step 5: Run Your First Simulation
## Step 6: Run Your First Simulation

Trigger a simulation for a single day with GPT-4:

@@ -130,7 +159,7 @@ curl -X POST http://localhost:8080/simulate/trigger \

---

## Step 6: Monitor Progress
## Step 7: Monitor Progress

```bash
# Replace with your job_id from Step 5

@@ -175,7 +204,7 @@ curl http://localhost:8080/simulate/status/$JOB_ID

---

## Step 7: View Results
## Step 8: View Results

```bash
curl "http://localhost:8080/results?job_id=$JOB_ID" | jq '.'
ROADMAP.md
@@ -150,7 +150,15 @@ curl -X POST http://localhost:5000/simulate/to-date \

- Integration with monitoring systems (Prometheus, Grafana)
- Alerting recommendations
- Backup and disaster recovery guidance
- Database migration strategy
- Database migration strategy:
  - Automated schema migration system for production databases
  - Support for ALTER TABLE and table recreation when needed
  - Migration version tracking and rollback capabilities
  - Zero-downtime migration procedures for production
  - Data integrity validation before and after migrations
  - Migration script testing framework
  - Note: Currently migrations are minimal (pre-production state)
  - Pre-production recommendation: Delete and recreate databases for schema updates
- Upgrade path documentation (v0.x to v1.0)
- Version compatibility guarantees going forward
||||
@@ -85,7 +85,7 @@ def initialize_database(db_path: str = "data/jobs.db") -> None:
|
||||
CREATE TABLE IF NOT EXISTS jobs (
|
||||
job_id TEXT PRIMARY KEY,
|
||||
config_path TEXT NOT NULL,
|
||||
status TEXT NOT NULL CHECK(status IN ('pending', 'running', 'completed', 'partial', 'failed')),
|
||||
status TEXT NOT NULL CHECK(status IN ('pending', 'downloading_data', 'running', 'completed', 'partial', 'failed')),
|
||||
date_range TEXT NOT NULL,
|
||||
models TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
@@ -93,7 +93,8 @@ def initialize_database(db_path: str = "data/jobs.db") -> None:
|
||||
updated_at TEXT,
|
||||
completed_at TEXT,
|
||||
total_duration_seconds REAL,
|
||||
error TEXT
|
||||
error TEXT,
|
||||
warnings TEXT
|
||||
)
|
||||
""")
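
The widened CHECK constraint means SQLite itself rejects any status outside the allowed set, including the new `downloading_data` value. A standalone sketch of that behavior (an in-memory table with the same constraint, not the project's actual schema file):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE jobs (
        job_id TEXT PRIMARY KEY,
        status TEXT NOT NULL CHECK(status IN
            ('pending', 'downloading_data', 'running', 'completed', 'partial', 'failed'))
    )
""")
conn.execute("INSERT INTO jobs VALUES ('j1', 'downloading_data')")  # accepted
try:
    conn.execute("INSERT INTO jobs VALUES ('j2', 'paused')")  # not in the allowed set
    rejected = False
except sqlite3.IntegrityError:
    rejected = True  # constraint fired
```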

@@ -285,7 +286,12 @@ def cleanup_dev_database(db_path: str = "data/trading_dev.db", data_path: str =


def _migrate_schema(cursor: sqlite3.Cursor) -> None:
    """Migrate existing database schema to latest version."""
    """
    Migrate existing database schema to latest version.

    Note: For pre-production databases, simply delete and recreate.
    This migration is only for preserving data during development.
    """
    # Check if positions table exists and has simulation_run_id column
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='positions'")
    if cursor.fetchone():
@@ -293,7 +299,6 @@ def _migrate_schema(cursor: sqlite3.Cursor) -> None:
        columns = [row[1] for row in cursor.fetchall()]

        if 'simulation_run_id' not in columns:
            # Add simulation_run_id column to existing positions table
            cursor.execute("""
                ALTER TABLE positions ADD COLUMN simulation_run_id TEXT
            """)
@@ -148,7 +148,7 @@ class JobManager:
                SELECT
                    job_id, config_path, status, date_range, models,
                    created_at, started_at, updated_at, completed_at,
                    total_duration_seconds, error
                    total_duration_seconds, error, warnings
                FROM jobs
                WHERE job_id = ?
            """, (job_id,))
@@ -168,7 +168,8 @@ class JobManager:
                "updated_at": row[7],
                "completed_at": row[8],
                "total_duration_seconds": row[9],
                "error": row[10]
                "error": row[10],
                "warnings": row[11]
            }

        finally:
@@ -189,7 +190,7 @@ class JobManager:
                SELECT
                    job_id, config_path, status, date_range, models,
                    created_at, started_at, updated_at, completed_at,
                    total_duration_seconds, error
                    total_duration_seconds, error, warnings
                FROM jobs
                ORDER BY created_at DESC
                LIMIT 1
@@ -210,7 +211,8 @@ class JobManager:
                "updated_at": row[7],
                "completed_at": row[8],
                "total_duration_seconds": row[9],
                "error": row[10]
                "error": row[10],
                "warnings": row[11]
            }

        finally:
@@ -236,7 +238,7 @@ class JobManager:
                SELECT
                    job_id, config_path, status, date_range, models,
                    created_at, started_at, updated_at, completed_at,
                    total_duration_seconds, error
                    total_duration_seconds, error, warnings
                FROM jobs
                WHERE date_range = ?
                ORDER BY created_at DESC
@@ -258,7 +260,8 @@ class JobManager:
                "updated_at": row[7],
                "completed_at": row[8],
                "total_duration_seconds": row[9],
                "error": row[10]
                "error": row[10],
                "warnings": row[11]
            }

        finally:
@@ -327,6 +330,32 @@ class JobManager:
        finally:
            conn.close()

    def add_job_warnings(self, job_id: str, warnings: List[str]) -> None:
        """
        Store warnings for a job.

        Args:
            job_id: Job UUID
            warnings: List of warning messages
        """
        conn = get_db_connection(self.db_path)
        cursor = conn.cursor()

        try:
            warnings_json = json.dumps(warnings)

            cursor.execute("""
                UPDATE jobs
                SET warnings = ?
                WHERE job_id = ?
            """, (warnings_json, job_id))

            conn.commit()
            logger.info(f"Added {len(warnings)} warnings to job {job_id}")

        finally:
            conn.close()
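
Since `warnings` is a plain TEXT column, `add_job_warnings` serializes the list with `json.dumps` and the status endpoint recovers it with `json.loads`. The round trip is simply:

```python
import json

warnings = ["Rate limit reached - downloaded 12/15 symbols"]
stored = json.dumps(warnings)    # TEXT value written by add_job_warnings
restored = json.loads(stored)    # list recovered when the status endpoint parses the column
```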

    def update_job_detail_status(
        self,
        job_id: str,
@@ -575,7 +604,7 @@ class JobManager:
                SELECT
                    job_id, config_path, status, date_range, models,
                    created_at, started_at, updated_at, completed_at,
                    total_duration_seconds, error
                    total_duration_seconds, error, warnings
                FROM jobs
                WHERE status IN ('pending', 'running')
                ORDER BY created_at DESC
@@ -594,7 +623,8 @@ class JobManager:
                "updated_at": row[7],
                "completed_at": row[8],
                "total_duration_seconds": row[9],
                "error": row[10]
                "error": row[10],
                "warnings": row[11]
            })

        return jobs
api/main.py (165 lines changed)
@@ -21,7 +21,6 @@ from pydantic import BaseModel, Field, field_validator
from api.job_manager import JobManager
from api.simulation_worker import SimulationWorker
from api.database import get_db_connection
from api.price_data_manager import PriceDataManager
from api.date_utils import validate_date_range, expand_date_range, get_max_simulation_days
from tools.deployment_config import get_deployment_mode_dict, log_dev_mode_startup_warning
import threading
@@ -74,6 +73,7 @@ class SimulateTriggerResponse(BaseModel):
    deployment_mode: str
    is_dev_mode: bool
    preserve_dev_data: Optional[bool] = None
    warnings: Optional[List[str]] = None


class JobProgress(BaseModel):
@@ -100,6 +100,7 @@ class JobStatusResponse(BaseModel):
    deployment_mode: str
    is_dev_mode: bool
    preserve_dev_data: Optional[bool] = None
    warnings: Optional[List[str]] = None


class HealthResponse(BaseModel):
@@ -146,18 +147,16 @@ def create_app(
        """
        Trigger a new simulation job.

        Validates date range, downloads missing price data if needed,
        and creates job with available trading dates.
        Validates date range and creates job. Price data is downloaded
        in background by SimulationWorker.

        Supports:
        - Single date: start_date == end_date
        - Date range: start_date < end_date
        - Resume: start_date is null (each model resumes from its last completed date)
        - Idempotent: replace_existing=false skips already completed model-days

        Raises:
            HTTPException 400: Validation errors, running job, or invalid dates
            HTTPException 503: Price data download failed
        """
        try:
            # Use config path from app state
@@ -199,6 +198,7 @@ def create_app(
            # Handle resume logic (start_date is null)
            if request.start_date is None:
                # Resume mode: determine start date per model
                from datetime import timedelta
                model_start_dates = {}

                for model in models_to_run:
@@ -225,112 +225,6 @@ def create_app(
            max_days = get_max_simulation_days()
            validate_date_range(start_date, end_date, max_days=max_days)

            # Check price data and download if needed
            auto_download = os.getenv("AUTO_DOWNLOAD_PRICE_DATA", "true").lower() == "true"
            price_manager = PriceDataManager(db_path=app.state.db_path)

            # Check what's missing (use computed start_date, not request.start_date which may be None)
            missing_coverage = price_manager.get_missing_coverage(
                start_date,
                end_date
            )

            download_info = None

            # Download missing data if enabled
            if any(missing_coverage.values()):
                if not auto_download:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Missing price data for {len(missing_coverage)} symbols and auto-download is disabled. "
                               f"Enable AUTO_DOWNLOAD_PRICE_DATA or pre-populate data."
                    )

                logger.info(f"Downloading missing price data for {len(missing_coverage)} symbols")

                requested_dates = set(expand_date_range(start_date, end_date))

                download_result = price_manager.download_missing_data_prioritized(
                    missing_coverage,
                    requested_dates
                )

                if not download_result["success"]:
                    raise HTTPException(
                        status_code=503,
                        detail="Failed to download any price data. Check ALPHAADVANTAGE_API_KEY."
                    )

                download_info = {
                    "symbols_downloaded": len(download_result["downloaded"]),
                    "symbols_failed": len(download_result["failed"]),
                    "rate_limited": download_result["rate_limited"]
                }

                logger.info(
                    f"Downloaded {len(download_result['downloaded'])} symbols, "
                    f"{len(download_result['failed'])} failed, rate_limited={download_result['rate_limited']}"
                )

            # Get available trading dates (after potential download)
            available_dates = price_manager.get_available_trading_dates(
                start_date,
                end_date
            )

            if not available_dates:
                raise HTTPException(
                    status_code=400,
                    detail=f"No trading dates with complete price data in range "
                           f"{start_date} to {end_date}. "
                           f"All symbols must have data for a date to be tradeable."
                )

            # Handle idempotent behavior (skip already completed model-days)
            if not request.replace_existing:
                # Get existing completed dates per model
                completed_dates = job_manager.get_completed_model_dates(
                    models_to_run,
                    start_date,
                    end_date
                )

                # Build list of model-day tuples to simulate
                model_day_tasks = []
                for model in models_to_run:
                    # Filter dates for this model
                    model_start = model_start_dates[model]

                    for date in available_dates:
                        # Skip if before model's start date
                        if date < model_start:
                            continue

                        # Skip if already completed (idempotent)
                        if date in completed_dates.get(model, []):
                            continue

                        model_day_tasks.append((model, date))

                if not model_day_tasks:
                    raise HTTPException(
                        status_code=400,
                        detail="No new model-days to simulate. All requested dates are already completed. "
                               "Use replace_existing=true to re-run."
                    )

                # Extract unique dates that will actually be run
                dates_to_run = sorted(list(set([date for _, date in model_day_tasks])))
            else:
                # Replace mode: run all model-date combinations
                dates_to_run = available_dates
                model_day_tasks = [
                    (model, date)
                    for model in models_to_run
                    for date in available_dates
                    if date >= model_start_dates[model]
                ]

            # Check if can start new job
            if not job_manager.can_start_new_job():
                raise HTTPException(
@@ -338,13 +232,16 @@ def create_app(
                    detail="Another simulation job is already running or pending. Please wait for it to complete."
                )

            # Create job with dates that will be run
            # Pass model_day_tasks to only create job_details for tasks that will actually run
            # Get all weekdays in range (worker will filter based on data availability)
            all_dates = expand_date_range(start_date, end_date)

            # Create job immediately with all requested dates
            # Worker will handle data download and filtering
            job_id = job_manager.create_job(
                config_path=config_path,
                date_range=dates_to_run,
                date_range=all_dates,
                models=models_to_run,
                model_day_filter=model_day_tasks
                model_day_filter=None  # Worker will filter based on available data
            )

            # Start worker in background thread (only if not in test mode)
@@ -356,26 +253,13 @@ def create_app(
                thread = threading.Thread(target=run_worker, daemon=True)
                thread.start()

            logger.info(f"Triggered simulation job {job_id} with {len(model_day_tasks)} model-day tasks")
            logger.info(f"Triggered simulation job {job_id} for {len(all_dates)} dates, {len(models_to_run)} models")

            # Build response message
            total_model_days = len(model_day_tasks)
            message_parts = [f"Simulation job created with {total_model_days} model-day tasks"]
            message = f"Simulation job created for {len(all_dates)} dates, {len(models_to_run)} models"

            if request.start_date is None:
                message_parts.append("(resume mode)")

            if not request.replace_existing:
                # Calculate how many were skipped
                total_possible = len(models_to_run) * len(available_dates)
                skipped = total_possible - total_model_days
                if skipped > 0:
                    message_parts.append(f"({skipped} already completed, skipped)")

            if download_info and download_info["rate_limited"]:
                message_parts.append("(rate limit reached - partial data)")

            message = " ".join(message_parts)
                message += " (resume mode)"

            # Get deployment mode info
            deployment_info = get_deployment_mode_dict()
@@ -383,16 +267,11 @@ def create_app(
            response = SimulateTriggerResponse(
                job_id=job_id,
                status="pending",
                total_model_days=total_model_days,
                total_model_days=len(all_dates) * len(models_to_run),
                message=message,
                **deployment_info
            )

            # Add download info if we downloaded
            if download_info:
                # Note: Need to add download_info field to response model
                logger.info(f"Download info: {download_info}")

            return response

        except HTTPException:
@@ -413,7 +292,7 @@ def create_app(
|
||||
job_id: Job UUID
|
||||
|
||||
Returns:
|
||||
Job status, progress, and model-day details
|
||||
Job status, progress, model-day details, and warnings
|
||||
|
||||
Raises:
|
||||
HTTPException 404: If job not found
|
||||
```diff
@@ -435,6 +314,15 @@ def create_app(
         # Calculate pending (total - completed - failed)
         pending = progress["total_model_days"] - progress["completed"] - progress["failed"]

+        # Parse warnings from JSON if present
+        import json
+        warnings = None
+        if job.get("warnings"):
+            try:
+                warnings = json.loads(job["warnings"])
+            except (json.JSONDecodeError, TypeError):
+                logger.warning(f"Failed to parse warnings for job {job_id}")
+
         # Get deployment mode info
         deployment_info = get_deployment_mode_dict()
```
```diff
@@ -455,6 +343,7 @@ def create_app(
             total_duration_seconds=job.get("total_duration_seconds"),
             error=job.get("error"),
             details=details,
+            warnings=warnings,
             **deployment_info
         )
```
```diff
@@ -9,7 +9,7 @@ This module provides:
 """

 import logging
-from typing import Dict, Any, List
+from typing import Dict, Any, List, Set
 from concurrent.futures import ThreadPoolExecutor, as_completed

 from api.job_manager import JobManager
```
```diff
@@ -65,12 +65,13 @@ class SimulationWorker:

         Process:
         1. Get job details (dates, models, config)
-        2. For each date sequentially:
+        2. Prepare data (download if needed)
+        3. For each date sequentially:
            a. Execute all models in parallel
            b. Wait for all to complete
            c. Update progress
-        3. Determine final job status
-        4. Update job with final status
+        4. Determine final job status
+        5. Store warnings if any

         Error Handling:
         - Individual model failures: Mark detail as failed, continue with others
```
```diff
@@ -88,8 +89,16 @@ class SimulationWorker:

         logger.info(f"Starting job {self.job_id}: {len(date_range)} dates, {len(models)} models")

-        # Execute date-by-date (sequential)
-        for date in date_range:
+        # NEW: Prepare price data (download if needed)
+        available_dates, warnings = self._prepare_data(date_range, models, config_path)
+
+        if not available_dates:
+            error_msg = "No trading dates available after price data preparation"
+            self.job_manager.update_job_status(self.job_id, "failed", error=error_msg)
+            return {"success": False, "error": error_msg}
+
+        # Execute available dates only
+        for date in available_dates:
             logger.info(f"Processing date {date} with {len(models)} models")
             self._execute_date(date, models, config_path)
```
```diff
@@ -103,6 +112,10 @@ class SimulationWorker:
         else:
             final_status = "failed"

+        # Add warnings if any dates were skipped
+        if warnings:
+            self._add_job_warnings(warnings)
+
         # Note: Job status is already updated by model_day_executor's detail status updates
         # We don't need to explicitly call update_job_status here as it's handled automatically
         # by the status transition logic in JobManager.update_job_detail_status
```
```diff
@@ -115,7 +128,8 @@ class SimulationWorker:
             "status": final_status,
             "total_model_days": progress["total_model_days"],
             "completed": progress["completed"],
-            "failed": progress["failed"]
+            "failed": progress["failed"],
+            "warnings": warnings
         }

     except Exception as e:
```
```diff
@@ -200,6 +214,158 @@ class SimulationWorker:
                 "error": str(e)
             }

+    def _download_price_data(
+        self,
+        price_manager,
+        missing_coverage: Dict[str, Set[str]],
+        requested_dates: List[str],
+        warnings: List[str]
+    ) -> None:
+        """Download missing price data with progress logging."""
+        logger.info(f"Job {self.job_id}: Starting prioritized download...")
+
+        requested_dates_set = set(requested_dates)
+
+        download_result = price_manager.download_missing_data_prioritized(
+            missing_coverage,
+            requested_dates_set
+        )
+
+        downloaded = len(download_result["downloaded"])
+        failed = len(download_result["failed"])
+        total = downloaded + failed
+
+        logger.info(
+            f"Job {self.job_id}: Download complete - "
+            f"{downloaded}/{total} symbols succeeded"
+        )
+
+        if download_result["rate_limited"]:
+            msg = f"Rate limit reached - downloaded {downloaded}/{total} symbols"
+            warnings.append(msg)
+            logger.warning(f"Job {self.job_id}: {msg}")
+
+        if failed > 0 and not download_result["rate_limited"]:
+            msg = f"{failed} symbols failed to download"
+            warnings.append(msg)
+            logger.warning(f"Job {self.job_id}: {msg}")
+
+    def _filter_completed_dates(
+        self,
+        available_dates: List[str],
+        models: List[str]
+    ) -> List[str]:
+        """
+        Filter out dates that are already completed for all models.
+
+        Implements idempotent job behavior - skip model-days that already
+        have completed data.
+
+        Args:
+            available_dates: List of dates with complete price data
+            models: List of model signatures
+
+        Returns:
+            List of dates that need processing
+        """
+        if not available_dates:
+            return []
+
+        # Get completed dates from job_manager
+        start_date = available_dates[0]
+        end_date = available_dates[-1]
+
+        completed_dates = self.job_manager.get_completed_model_dates(
+            models,
+            start_date,
+            end_date
+        )
+
+        # Build list of dates that need processing
+        dates_to_process = []
+        for date in available_dates:
+            # Check if any model needs this date
+            needs_processing = False
+            for model in models:
+                if date not in completed_dates.get(model, []):
+                    needs_processing = True
+                    break
+
+            if needs_processing:
+                dates_to_process.append(date)
+
+        return dates_to_process
+
+    def _add_job_warnings(self, warnings: List[str]) -> None:
+        """Store warnings in job metadata."""
+        self.job_manager.add_job_warnings(self.job_id, warnings)
+
+    def _prepare_data(
+        self,
+        requested_dates: List[str],
+        models: List[str],
+        config_path: str
+    ) -> tuple:
+        """
+        Prepare price data for simulation.
+
+        Steps:
+        1. Update job status to "downloading_data"
+        2. Check what data is missing
+        3. Download missing data (with rate limit handling)
+        4. Determine available trading dates
+        5. Filter out already-completed model-days (idempotent)
+        6. Update job status to "running"
+
+        Args:
+            requested_dates: All dates requested for simulation
+            models: Model signatures to simulate
+            config_path: Path to configuration file
+
+        Returns:
+            Tuple of (available_dates, warnings)
+        """
+        from api.price_data_manager import PriceDataManager
+
+        warnings = []
+
+        # Update status
+        self.job_manager.update_job_status(self.job_id, "downloading_data")
+        logger.info(f"Job {self.job_id}: Checking price data availability...")
+
+        # Initialize price manager
+        price_manager = PriceDataManager(db_path=self.db_path)
+
+        # Check missing coverage
+        start_date = requested_dates[0]
+        end_date = requested_dates[-1]
+        missing_coverage = price_manager.get_missing_coverage(start_date, end_date)
+
+        # Download if needed
+        if missing_coverage:
+            logger.info(f"Job {self.job_id}: Missing data for {len(missing_coverage)} symbols")
+            self._download_price_data(price_manager, missing_coverage, requested_dates, warnings)
+        else:
+            logger.info(f"Job {self.job_id}: All price data available")
+
+        # Get available dates after download
+        available_dates = price_manager.get_available_trading_dates(start_date, end_date)
+
+        # Warn about skipped dates
+        skipped = set(requested_dates) - set(available_dates)
+        if skipped:
+            warnings.append(f"Skipped {len(skipped)} dates due to incomplete price data: {sorted(list(skipped))}")
+            logger.warning(f"Job {self.job_id}: {warnings[-1]}")
+
+        # Filter already-completed model-days (idempotent behavior)
+        available_dates = self._filter_completed_dates(available_dates, models)
+
+        # Update to running
+        self.job_manager.update_job_status(self.job_id, "running")
+        logger.info(f"Job {self.job_id}: Starting execution - {len(available_dates)} dates, {len(models)} models")
+
+        return available_dates, warnings
+
     def get_job_info(self) -> Dict[str, Any]:
         """
         Get job information.
```
```diff
@@ -8,7 +8,8 @@ services:
     volumes:
       - ${VOLUME_PATH:-.}/data:/app/data
       - ${VOLUME_PATH:-.}/logs:/app/logs
-      - ${VOLUME_PATH:-.}/configs:/app/configs
+      # User configs mounted to /app/user-configs (default config baked into image)
+      - ${VOLUME_PATH:-.}/configs:/app/user-configs
     environment:
       # Deployment Configuration
       - DEPLOYMENT_MODE=${DEPLOYMENT_MODE:-PROD}
```
docs/plans/2025-11-01-async-price-download-design.md (new file, 532 lines)

# Async Price Data Download Design

**Date:** 2025-11-01
**Status:** Approved
**Problem:** `/simulate/trigger` endpoint times out (30s+) when downloading missing price data

## Problem Statement

The `/simulate/trigger` API endpoint currently downloads missing price data synchronously within the HTTP request handler. This causes:
- HTTP timeouts when downloads take >30 seconds
- Poor user experience (long wait for job_id)
- Blocking behavior that doesn't match the async job pattern

## Solution Overview

Move price data download from the HTTP endpoint to the background worker thread, enabling:
- Fast API response (<1 second)
- Background data preparation with progress visibility
- Graceful handling of rate limits and partial downloads

## Architecture Changes

### Current Flow
```
POST /simulate/trigger → Download price data (30s+) → Create job → Return job_id
```

### New Flow
```
POST /simulate/trigger → Quick validation → Create job → Return job_id (<1s)
                                                ↓
                  Background worker → Download missing data → Execute trading → Complete
```

### Status Progression
```
pending → downloading_data → running → completed (with optional warnings)
                                ↓
                             failed (if download fails completely)
```
## Component Changes

### 1. API Endpoint (`api/main.py`)

**Remove:**
- Price data availability checks (lines 228-287)
- `PriceDataManager.get_missing_coverage()`
- `PriceDataManager.download_missing_data_prioritized()`
- `PriceDataManager.get_available_trading_dates()`
- Idempotent filtering logic (move to worker)

**Keep:**
- Date format validation
- Job creation
- Worker thread startup

**New Logic:**
```python
# Quick validation only
validate_date_range(start_date, end_date, max_days=max_days)

# Check if can start new job
if not job_manager.can_start_new_job():
    raise HTTPException(status_code=400, detail="...")

# Create job immediately with all requested dates
job_id = job_manager.create_job(
    config_path=config_path,
    date_range=expand_date_range(start_date, end_date),  # All weekdays
    models=models_to_run,
    model_day_filter=None  # Worker will filter
)

# Start worker thread (existing code)
```
### 2. Simulation Worker (`api/simulation_worker.py`)

**New Method: `_prepare_data()`**

Encapsulates the data preparation phase:

```python
def _prepare_data(
    self,
    requested_dates: List[str],
    models: List[str],
    config_path: str
) -> Tuple[List[str], List[str]]:
    """
    Prepare price data for simulation.

    Steps:
    1. Update job status to "downloading_data"
    2. Check what data is missing
    3. Download missing data (with rate limit handling)
    4. Determine available trading dates
    5. Filter out already-completed model-days (idempotent)
    6. Update job status to "running"

    Returns:
        (available_dates, warnings)
    """
    warnings = []

    # Update status
    self.job_manager.update_job_status(self.job_id, "downloading_data")
    logger.info(f"Job {self.job_id}: Checking price data availability...")

    # Initialize price manager
    price_manager = PriceDataManager(db_path=self.db_path)

    # Check missing coverage
    start_date = requested_dates[0]
    end_date = requested_dates[-1]
    missing_coverage = price_manager.get_missing_coverage(start_date, end_date)

    # Download if needed
    if missing_coverage:
        logger.info(f"Job {self.job_id}: Missing data for {len(missing_coverage)} symbols")
        self._download_price_data(price_manager, missing_coverage, requested_dates, warnings)
    else:
        logger.info(f"Job {self.job_id}: All price data available")

    # Get available dates after download
    available_dates = price_manager.get_available_trading_dates(start_date, end_date)

    # Warn about skipped dates
    skipped = set(requested_dates) - set(available_dates)
    if skipped:
        warnings.append(f"Skipped {len(skipped)} dates due to incomplete price data: {sorted(skipped)}")
        logger.warning(f"Job {self.job_id}: {warnings[-1]}")

    # Filter already-completed model-days (idempotent behavior)
    available_dates = self._filter_completed_dates(available_dates, models)

    # Update to running
    self.job_manager.update_job_status(self.job_id, "running")
    logger.info(f"Job {self.job_id}: Starting execution - {len(available_dates)} dates, {len(models)} models")

    return available_dates, warnings
```
**New Method: `_download_price_data()`**

Handles the download with progress logging:

```python
def _download_price_data(
    self,
    price_manager: PriceDataManager,
    missing_coverage: Dict[str, Set[str]],
    requested_dates: List[str],
    warnings: List[str]
) -> None:
    """Download missing price data with progress logging."""

    logger.info(f"Job {self.job_id}: Starting prioritized download...")

    requested_dates_set = set(requested_dates)

    download_result = price_manager.download_missing_data_prioritized(
        missing_coverage,
        requested_dates_set
    )

    downloaded = len(download_result["downloaded"])
    failed = len(download_result["failed"])
    total = downloaded + failed

    logger.info(
        f"Job {self.job_id}: Download complete - "
        f"{downloaded}/{total} symbols succeeded"
    )

    if download_result["rate_limited"]:
        msg = f"Rate limit reached - downloaded {downloaded}/{total} symbols"
        warnings.append(msg)
        logger.warning(f"Job {self.job_id}: {msg}")

    if failed > 0 and not download_result["rate_limited"]:
        msg = f"{failed} symbols failed to download"
        warnings.append(msg)
        logger.warning(f"Job {self.job_id}: {msg}")
```
**New Method: `_filter_completed_dates()`**

Implements the idempotent behavior:

```python
def _filter_completed_dates(
    self,
    available_dates: List[str],
    models: List[str]
) -> List[str]:
    """
    Filter out dates that are already completed for all models.

    Implements idempotent job behavior - skip model-days that already
    have completed data.
    """
    # Get completed dates from job_manager
    start_date = available_dates[0]
    end_date = available_dates[-1]

    completed_dates = self.job_manager.get_completed_model_dates(
        models,
        start_date,
        end_date
    )

    # Build list of dates that need processing
    dates_to_process = []
    for date in available_dates:
        # Check if any model needs this date
        needs_processing = False
        for model in models:
            if date not in completed_dates.get(model, []):
                needs_processing = True
                break

        if needs_processing:
            dates_to_process.append(date)

    return dates_to_process
```
**New Method: `_add_job_warnings()`**

Stores warnings in job metadata:

```python
def _add_job_warnings(self, warnings: List[str]) -> None:
    """Store warnings in job metadata."""
    self.job_manager.add_job_warnings(self.job_id, warnings)
```
**Modified: `run()` method**

```python
def run(self) -> Dict[str, Any]:
    try:
        job = self.job_manager.get_job(self.job_id)
        if not job:
            raise ValueError(f"Job {self.job_id} not found")

        date_range = job["date_range"]
        models = job["models"]
        config_path = job["config_path"]

        logger.info(f"Starting job {self.job_id}: {len(date_range)} dates, {len(models)} models")

        # NEW: Prepare price data (download if needed)
        available_dates, warnings = self._prepare_data(date_range, models, config_path)

        if not available_dates:
            error_msg = "No trading dates available after price data preparation"
            self.job_manager.update_job_status(self.job_id, "failed", error=error_msg)
            return {"success": False, "error": error_msg}

        # Execute available dates only
        for date in available_dates:
            logger.info(f"Processing date {date} with {len(models)} models")
            self._execute_date(date, models, config_path)

        # Determine final status
        progress = self.job_manager.get_job_progress(self.job_id)

        if progress["failed"] == 0:
            final_status = "completed"
        elif progress["completed"] > 0:
            final_status = "partial"
        else:
            final_status = "failed"

        # Add warnings if any dates were skipped
        if warnings:
            self._add_job_warnings(warnings)

        logger.info(f"Job {self.job_id} finished with status: {final_status}")

        return {
            "success": True,
            "job_id": self.job_id,
            "status": final_status,
            "total_model_days": progress["total_model_days"],
            "completed": progress["completed"],
            "failed": progress["failed"],
            "warnings": warnings
        }

    except Exception as e:
        error_msg = f"Job execution failed: {str(e)}"
        logger.error(f"Job {self.job_id}: {error_msg}", exc_info=True)
        self.job_manager.update_job_status(self.job_id, "failed", error=error_msg)
        return {"success": False, "job_id": self.job_id, "error": error_msg}
```
### 3. Job Manager (`api/job_manager.py`)

**Verify Status Support:**
- Ensure "downloading_data" status is allowed in database schema
- Verify status transition logic supports: `pending → downloading_data → running`
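One way to express that transition check is a small allow-table. This is only a sketch of the verification, not the existing `JobManager` code; the table names are assumptions based on the statuses listed in this design:

```python
# Hypothetical transition table covering the statuses in this design.
ALLOWED_TRANSITIONS = {
    "pending": {"downloading_data", "running", "failed"},
    "downloading_data": {"running", "failed"},
    "running": {"completed", "partial", "failed"},
}

def can_transition(current: str, new: str) -> bool:
    """Return True if a job may legally move from `current` to `new`."""
    # Terminal statuses (completed/partial/failed) have no outgoing edges.
    return new in ALLOWED_TRANSITIONS.get(current, set())
```

With a table like this, adding the new `downloading_data` phase is one entry rather than a change to scattered if-statements.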
**New Method: `add_job_warnings()`**

```python
def add_job_warnings(self, job_id: str, warnings: List[str]) -> None:
    """
    Store warnings for a job.

    Implementation options:
    1. Add 'warnings' JSON column to jobs table
    2. Store in existing metadata field
    3. Create separate warnings table
    """
    # To be implemented based on schema preference
    pass
```
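A minimal sketch of option 1 (a JSON `warnings` column on the jobs table), assuming SQLite and a `jobs` table keyed by `job_id` — the table shape is illustrative, not the project's actual schema:

```python
import json
import sqlite3
from typing import List

def add_job_warnings(conn: sqlite3.Connection, job_id: str, warnings: List[str]) -> None:
    """Persist warnings as a JSON array so the status endpoint can json.loads() them back."""
    # Assumes a 'warnings' TEXT column exists on the jobs table (option 1 above).
    conn.execute(
        "UPDATE jobs SET warnings = ? WHERE job_id = ?",
        (json.dumps(warnings), job_id),
    )
    conn.commit()
```

Storing the array as JSON text keeps the schema change to a single nullable column and matches the `json.loads(job["warnings"])` parsing shown in the status endpoint.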
### 4. Response Models (`api/main.py`)

**Add warnings field:**

```python
class SimulateTriggerResponse(BaseModel):
    job_id: str
    status: str
    total_model_days: int
    message: str
    deployment_mode: str
    is_dev_mode: bool
    preserve_dev_data: Optional[bool] = None
    warnings: Optional[List[str]] = None  # NEW

class JobStatusResponse(BaseModel):
    job_id: str
    status: str
    progress: JobProgress
    date_range: List[str]
    models: List[str]
    created_at: str
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
    total_duration_seconds: Optional[float] = None
    error: Optional[str] = None
    details: List[Dict[str, Any]]
    deployment_mode: str
    is_dev_mode: bool
    preserve_dev_data: Optional[bool] = None
    warnings: Optional[List[str]] = None  # NEW
```
## Logging Strategy

### Progress Visibility

Enhanced logging for monitoring via `docker logs -f`:

```python
# At download start
logger.info(f"Job {job_id}: Checking price data availability...")
logger.info(f"Job {job_id}: Missing data for {len(missing_symbols)} symbols")
logger.info(f"Job {job_id}: Starting prioritized download...")

# Download completion
logger.info(f"Job {job_id}: Download complete - {downloaded}/{total} symbols succeeded")
logger.warning(f"Job {job_id}: Rate limited - proceeding with available dates")

# Execution start
logger.info(f"Job {job_id}: Starting execution - {len(dates)} dates, {len(models)} models")
logger.info(f"Job {job_id}: Processing date {date} with {len(models)} models")
```

### DEV Mode Enhancement

```python
if DEPLOYMENT_MODE == "DEV":
    logger.setLevel(logging.DEBUG)
    logger.info("🔧 DEV MODE: Enhanced logging enabled")
```

### Example Console Output

```
Job 019a426b: Checking price data availability...
Job 019a426b: Missing data for 15 symbols
Job 019a426b: Starting prioritized download...
Job 019a426b: Download complete - 12/15 symbols succeeded
Job 019a426b: Rate limit reached - downloaded 12/15 symbols
Job 019a426b: Skipped 2 dates due to incomplete price data: ['2025-10-02', '2025-10-05']
Job 019a426b: Starting execution - 8 dates, 1 models
Job 019a426b: Processing date 2025-10-01 with 1 models
Job 019a426b: Processing date 2025-10-03 with 1 models
...
Job 019a426b: Job finished with status: completed
```
## Behavior Specifications

### Rate Limit Handling

**Option B (Approved):** Run with available data
- Download symbols in priority order (most date-completing first)
- When rate limited, proceed with dates that have complete data
- Add warning to job response
- Mark job as "completed" (not "failed") if any dates processed
- Log skipped dates for visibility
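The "most date-completing first" ordering can be sketched as below. `prioritize_symbols` and the shape of `missing_coverage` (symbol → set of missing dates) are assumptions for illustration, not the actual `PriceDataManager.download_missing_data_prioritized()` internals:

```python
from typing import Dict, List, Set

def prioritize_symbols(
    missing_coverage: Dict[str, Set[str]],
    requested_dates: Set[str],
) -> List[str]:
    """Order symbols so the most date-completing downloads happen first.

    Hypothetical helper: a symbol whose missing dates overlap many of the
    requested dates unblocks the most trading days, so it downloads first.
    An early rate limit then still completes as many dates as possible.
    """
    def dates_helped(symbol: str) -> int:
        # Count the requested dates this symbol's download contributes to
        return len(missing_coverage[symbol] & requested_dates)

    return sorted(missing_coverage, key=dates_helped, reverse=True)
```

Under this ordering, a rate limit that stops the download partway through wastes as little of the quota as possible on symbols that would not have completed any requested date.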
### Job Status Communication

**Option B (Approved):** Status "completed" with warnings
- Status = "completed" means "successfully processed all processable dates"
- Warnings field communicates skipped dates
- Consistent with existing skip-incomplete-data behavior
- Doesn't penalize users for rate limits

### Progress Visibility

**Option A (Approved):** Job status field
- New status: "downloading_data"
- Appears in `/simulate/status/{job_id}` responses
- Clear distinction between phases:
  - `pending`: Job queued, not started
  - `downloading_data`: Preparing price data
  - `running`: Executing trades
  - `completed`: Finished successfully
  - `partial`: Some model-days failed
  - `failed`: Job-level failure
## Testing Strategy

### Test Cases

1. **Fast path** - All data present
   - Request simulation with existing data
   - Expect <1s response with job_id
   - Verify status goes: pending → running → completed

2. **Download path** - Missing data
   - Request simulation with missing price data
   - Expect <1s response with job_id
   - Verify status goes: pending → downloading_data → running → completed
   - Check `docker logs -f` shows download progress

3. **Rate limit handling**
   - Trigger rate limit during download
   - Verify job completes with warnings
   - Verify partial dates processed
   - Verify status = "completed" (not "failed")

4. **Complete failure**
   - Simulate download failure (invalid API key)
   - Verify job status = "failed"
   - Verify error message in response

5. **Idempotent behavior**
   - Request same date range twice
   - Verify second request skips completed model-days
   - Verify no duplicate executions

### Integration Test Example

```python
def test_async_download_with_missing_data():
    """Test that missing data is downloaded in background."""
    # Trigger simulation
    response = requests.post("http://localhost:8080/simulate/trigger", json={
        "start_date": "2025-10-01",
        "end_date": "2025-10-01",
        "models": ["gpt-5"]
    })

    # Should return immediately
    assert response.elapsed.total_seconds() < 2
    assert response.status_code == 200

    job_id = response.json()["job_id"]

    # Poll status - should see downloading_data
    status = requests.get(f"http://localhost:8080/simulate/status/{job_id}").json()
    assert status["status"] in ["pending", "downloading_data", "running"]

    # Wait for completion
    while status["status"] not in ["completed", "partial", "failed"]:
        time.sleep(1)
        status = requests.get(f"http://localhost:8080/simulate/status/{job_id}").json()

    # Verify success
    assert status["status"] == "completed"
```
## Migration & Rollout

### Implementation Order

1. **Database changes** - Add warnings support to job schema
2. **Worker changes** - Implement `_prepare_data()` and helpers
3. **Endpoint changes** - Remove blocking download logic
4. **Response models** - Add warnings field
5. **Testing** - Integration tests for all scenarios
6. **Documentation** - Update API docs

### Backwards Compatibility

- No breaking changes to API contract
- New `warnings` field is optional
- Existing clients continue to work unchanged
- Response time improves (better UX)

### Rollback Plan

If issues arise:
1. Revert endpoint changes (restore price download)
2. Keep worker changes (no harm if unused)
3. Response models are backwards compatible

## Benefits Summary

1. **Performance**: API response <1s (vs 30s+ timeout)
2. **UX**: Immediate job_id, async progress tracking
3. **Reliability**: No HTTP timeouts
4. **Visibility**: Real-time logs via `docker logs -f`
5. **Resilience**: Graceful rate limit handling
6. **Consistency**: Matches async job pattern
7. **Maintainability**: Cleaner separation of concerns

## Open Questions

None - design approved.
docs/plans/2025-11-01-async-price-download-implementation.md (new file, 1922 lines; diff suppressed because it is too large)

docs/plans/2025-11-01-config-override-system-design.md (new file, 249 lines)

# Configuration Override System Design

**Date:** 2025-11-01
**Status:** Approved
**Context:** Enable per-deployment model configuration while maintaining sensible defaults

## Problem

Deployments need to customize model configurations without modifying the image's default config. Currently, the API looks for `configs/default_config.json` at startup, but volume mounts that include custom configs would overwrite the default config baked into the image.

## Solution Overview

Implement a layered configuration system where:
- Default config is baked into the Docker image
- User config is provided via volume mount in a separate directory
- Configs are merged at container startup (before the API starts)
- Validation failures cause immediate container exit

## Architecture

### File Locations

- **Default config (in image):** `/app/configs/default_config.json`
- **User config (mounted):** `/app/user-configs/config.json`
- **Merged output:** `/tmp/runtime_config.json`

### Startup Sequence

1. **Entrypoint phase** (before uvicorn):
   - Load `configs/default_config.json` from image
   - Check if `user-configs/config.json` exists
   - If it exists: perform root-level merge (custom sections override default sections)
   - Validate merged config structure
   - If validation fails: log detailed error and `exit 1`
   - Write merged config to `/tmp/runtime_config.json`
   - Export `CONFIG_PATH=/tmp/runtime_config.json`

2. **API initialization:**
   - Load pre-validated config from `$CONFIG_PATH`
   - No runtime config validation needed (already validated)
### Merge Behavior

**Root-level merge:** Custom config sections completely replace default sections.

```python
import os

default = load_json("configs/default_config.json")
custom = load_json("user-configs/config.json") if os.path.exists("user-configs/config.json") else {}

merged = {**default}
for key in custom:
    merged[key] = custom[key]  # Override entire section
```

**Examples:**
- Custom has `models` array → entire models array replaced
- Custom has `agent_config` → entire agent_config replaced
- Custom missing `date_range` → default date_range used
- Custom has unknown keys → passed through (validated in next step)
### Validation Rules

**Structure validation:**
- Required top-level keys: `agent_type`, `models`, `agent_config`, `log_config`
- `date_range` is optional (can be overridden by API request params)
- `models` must be an array with at least one entry
- Each model must have: `name`, `basemodel`, `signature`, `enabled`

**Model validation:**
- At least one model must have `enabled: true`
- Model signatures must be unique
- No duplicate model names

**Date validation (if date_range present):**
- Dates match `YYYY-MM-DD` format
- `init_date` <= `end_date`
- Dates are not in the future

**Agent config validation:**
- `max_steps` > 0
- `max_retries` >= 0
- `initial_cash` > 0
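A sketch of how the structure and model rules above could look in `validate_config` (date and agent-config checks omitted for brevity; the exact error wording is illustrative, not the shipped messages):

```python
def validate_config(config: dict) -> None:
    """Raise ValueError with a specific message on the first rule violation."""
    # Structure validation: required top-level keys
    for key in ("agent_type", "models", "agent_config", "log_config"):
        if key not in config:
            raise ValueError(f"Missing required field '{key}'")

    models = config["models"]
    if not isinstance(models, list) or not models:
        raise ValueError("'models' must be an array with at least one entry")
    for model in models:
        for field in ("name", "basemodel", "signature", "enabled"):
            if field not in model:
                raise ValueError(f"Model missing required field '{field}'")

    # Model validation
    if not any(m["enabled"] for m in models):
        raise ValueError("At least one model must have enabled: true")
    signatures = [m["signature"] for m in models]
    if len(signatures) != len(set(signatures)):
        raise ValueError("Model signatures must be unique")
    names = [m["name"] for m in models]
    if len(names) != len(set(names)):
        raise ValueError("Duplicate model names are not allowed")
```

Raising on the first violation keeps the fail-fast error output (next section) to a single, actionable message.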
### Error Handling

**Validation failure output:**
```
❌ CONFIG VALIDATION FAILED
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Error: Missing required field 'models'
Location: Root level
File: user-configs/config.json

Merged config written to: /tmp/runtime_config.json (for debugging)

Container will exit. Fix config and restart.
```

**Benefits of fail-fast approach:**
- No silent config errors during API calls
- Clear feedback on what's wrong
- Container restart loop until config is fixed
- Health checks fail immediately (container never reaches "running" state with bad config)
## Implementation Components

### New Files

**`tools/config_merger.py`**

```python
def load_config(path: str) -> dict:
    """Load and parse JSON with error handling"""

def merge_configs(default: dict, custom: dict) -> dict:
    """Root-level merge - custom sections override default"""

def validate_config(config: dict) -> None:
    """Validate structure, raise detailed exception on failure"""

def merge_and_validate() -> None:
    """Main entrypoint - load, merge, validate, write to /tmp"""
```

### Updated Files

**`entrypoint.sh`**

```bash
# After MCP service startup, before uvicorn
echo "🔧 Merging and validating configuration..."
python -c "from tools.config_merger import merge_and_validate; merge_and_validate()" || exit 1
export CONFIG_PATH=/tmp/runtime_config.json
echo "✅ Configuration validated"

exec uvicorn api.main:app ...
```

**`docker-compose.yml`**

```yaml
volumes:
  - ./data:/app/data
  - ./logs:/app/logs
  - ./configs:/app/user-configs  # User's config.json (not /app/configs!)
```

**`api/main.py`**

- Keep existing `CONFIG_PATH` env var support (already implemented)
- Remove any config validation from request handlers (now done at startup)

### Documentation Updates

- **`docs/DOCKER.md`** - Explain the user-configs volume mount and config.json structure
- **`QUICK_START.md`** - Show a minimal config.json example
- **`API_REFERENCE.md`** - Note that config errors fail at startup, not during API calls
- **`CLAUDE.md`** - Update the configuration section with the new merge behavior

## User Experience

### Minimal Custom Config Example

```json
{
  "models": [
    {
      "name": "my-gpt-4",
      "basemodel": "openai/gpt-4",
      "signature": "my-gpt-4",
      "enabled": true
    }
  ]
}
```

All other settings (`agent_config`, `log_config`, etc.) are inherited from the default.

### Complete Custom Config Example

```json
{
  "agent_type": "BaseAgent",
  "date_range": {
    "init_date": "2025-10-01",
    "end_date": "2025-10-31"
  },
  "models": [
    {
      "name": "claude-sonnet-4",
      "basemodel": "anthropic/claude-sonnet-4",
      "signature": "claude-sonnet-4",
      "enabled": true
    }
  ],
  "agent_config": {
    "max_steps": 50,
    "max_retries": 5,
    "base_delay": 2.0,
    "initial_cash": 100000.0
  },
  "log_config": {
    "log_path": "./data/agent_data"
  }
}
```

All sections are replaced; nothing is inherited from the default.

## Backward Compatibility

**If no `user-configs/config.json` exists:**

- System uses `configs/default_config.json` as-is
- No merging needed
- Existing behavior preserved

**Breaking change:**

- Deployments currently mounting to `/app/configs` must update to `/app/user-configs`
- Migration: update the docker-compose.yml volume mount path

## Security Considerations

- Default config in the image is read-only (immutable)
- User config directory is writable (mounted volume)
- Merged config in `/tmp` is ephemeral (recreated on restart)
- API keys in the user config are not logged during validation errors

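The last point (keeping API keys out of validation error output) could be handled by masking sensitive-looking values before printing any config contents. The `redact_secrets` helper below is purely hypothetical, a sketch of one way to satisfy that requirement:

```python
def redact_secrets(config: dict,
                   sensitive_substrings=("key", "token", "secret")) -> dict:
    """Return a copy of config with values of sensitive-looking keys masked.

    Intended for use before echoing config contents in error messages."""
    redacted = {}
    for k, v in config.items():
        if isinstance(v, dict):
            redacted[k] = redact_secrets(v, sensitive_substrings)
        elif any(s in k.lower() for s in sensitive_substrings):
            redacted[k] = "***"
        else:
            redacted[k] = v
    return redacted
```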
## Testing Strategy

**Unit tests (`tests/unit/test_config_merger.py`):**

- Merge behavior with various override combinations
- Validation catches all error conditions
- Error messages are clear and actionable

**Integration tests:**

- Container startup with a valid user config
- Container startup with an invalid user config (should exit 1)
- Container startup with no user config (uses default)
- API requests use the merged config correctly

**Manual testing:**

- Deploy with a minimal config.json (only models)
- Deploy with a complete config.json (all sections)
- Deploy with an invalid config.json (verify error output)
- Deploy with no config.json (verify default behavior)

## Future Enhancements

- Deep merge support (merge within sections, not just root-level)
- Config schema validation using JSON Schema
- Support for multiple config files (e.g., base + environment + deployment)
- Hot reload on config file changes (SIGHUP handler)

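The first enhancement, deep merge, could be sketched recursively (illustrative only; the current system intentionally does root-level replacement):

```python
def deep_merge(default: dict, custom: dict) -> dict:
    """Recursively merge custom into default: nested dicts are merged
    key-by-key, while any other value (lists, scalars) is replaced."""
    merged = {**default}
    for key, value in custom.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged
```

With deep merge, a custom `agent_config` containing only `max_steps` would keep the default's other fields instead of dropping them.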
@@ -94,6 +94,70 @@ curl "http://localhost:8080/results?job_id=$JOB_ID&date=2025-01-16&model=gpt-4"

---

## Async Data Download

The `/simulate/trigger` endpoint responds immediately (<1 second), even when price data needs to be downloaded.

### Flow

1. **POST /simulate/trigger** - Returns `job_id` immediately
2. **Background worker** - Downloads missing data automatically
3. **Poll /simulate/status** - Track progress through status transitions

### Status Progression

```
pending → downloading_data → running → completed
```

### Monitoring Progress

Use `docker logs -f` to monitor download progress in real time:

```bash
docker logs -f ai-trader-server

# Example output:
# Job 019a426b: Checking price data availability...
# Job 019a426b: Missing data for 15 symbols
# Job 019a426b: Starting prioritized download...
# Job 019a426b: Download complete - 12/15 symbols succeeded
# Job 019a426b: Rate limit reached - proceeding with available dates
# Job 019a426b: Starting execution - 8 dates, 1 models
```

### Handling Warnings

Check the `warnings` field in the status response:

```python
import requests
import time

# Trigger simulation
response = requests.post("http://localhost:8080/simulate/trigger", json={
    "start_date": "2025-10-01",
    "end_date": "2025-10-10",
    "models": ["gpt-5"]
})

job_id = response.json()["job_id"]

# Poll until complete
while True:
    status = requests.get(f"http://localhost:8080/simulate/status/{job_id}").json()

    if status["status"] in ["completed", "partial", "failed"]:
        # Check for warnings
        if status.get("warnings"):
            print("Warnings:", status["warnings"])
        break

    time.sleep(2)
```

---

## Best Practices

### 1. Check Health Before Triggering

@@ -41,7 +41,16 @@ echo "📊 Initializing database..."
python -c "from api.database import initialize_database; initialize_database('data/jobs.db')"
echo "✅ Database initialized"

# Step 2: Start MCP services in background
# Step 2: Merge and validate configuration
echo "🔧 Merging and validating configuration..."
python -c "from tools.config_merger import merge_and_validate; merge_and_validate()" || {
    echo "❌ Configuration validation failed"
    exit 1
}
export CONFIG_PATH=/tmp/runtime_config.json
echo "✅ Configuration validated and merged"

# Step 3: Start MCP services in background
echo "🔧 Starting MCP services..."
cd /app
python agent_tools/start_mcp_services.py &
@@ -50,11 +59,11 @@ MCP_PID=$!
# Setup cleanup trap before starting uvicorn
trap "echo '🛑 Stopping services...'; kill $MCP_PID 2>/dev/null; exit 0" EXIT SIGTERM SIGINT

# Step 3: Wait for services to initialize
# Step 4: Wait for services to initialize
echo "⏳ Waiting for MCP services to start..."
sleep 3

# Step 4: Start FastAPI server with uvicorn (this blocks)
# Step 5: Start FastAPI server with uvicorn (this blocks)
# Note: Container always uses port 8080 internally
# The API_PORT env var only affects the host port mapping in docker-compose.yml
echo "🌐 Starting FastAPI server on port 8080..."

@@ -56,8 +56,11 @@ def clean_db(test_db_path):
    cursor.execute("DELETE FROM reasoning_logs")
    cursor.execute("DELETE FROM holdings")
    cursor.execute("DELETE FROM positions")
    cursor.execute("DELETE FROM simulation_runs")
    cursor.execute("DELETE FROM job_details")
    cursor.execute("DELETE FROM jobs")
    cursor.execute("DELETE FROM price_data_coverage")
    cursor.execute("DELETE FROM price_data")

    conn.commit()
    conn.close()

193 tests/e2e/test_async_download_flow.py (Normal file)
@@ -0,0 +1,193 @@
"""
End-to-end test for async price download flow.

Tests the complete flow:
1. POST /simulate/trigger (fast response)
2. Worker downloads data in background
3. GET /simulate/status shows downloading_data → running → completed
4. Warnings are captured and returned
"""

import pytest
import time
from unittest.mock import patch, Mock
from api.main import create_app
from api.database import initialize_database
from fastapi.testclient import TestClient

@pytest.fixture
def test_app(tmp_path):
    """Create test app with isolated database."""
    db_path = str(tmp_path / "test.db")
    initialize_database(db_path)

    app = create_app(db_path=db_path, config_path="configs/default_config.json")
    app.state.test_mode = True  # Disable background worker

    yield app

@pytest.fixture
def test_client(test_app):
    """Create test client."""
    return TestClient(test_app)

def test_complete_async_download_flow(test_client, monkeypatch):
    """Test complete flow from trigger to completion with async download."""

    # Mock PriceDataManager for predictable behavior
    class MockPriceManager:
        def __init__(self, db_path):
            self.db_path = db_path

        def get_missing_coverage(self, start, end):
            return {"AAPL": {"2025-10-01"}}  # Simulate missing data

        def download_missing_data_prioritized(self, missing, requested):
            return {
                "downloaded": ["AAPL"],
                "failed": [],
                "rate_limited": False
            }

        def get_available_trading_dates(self, start, end):
            return ["2025-10-01"]

    monkeypatch.setattr("api.price_data_manager.PriceDataManager", MockPriceManager)

    # Mock execution to avoid actual trading
    def mock_execute_date(self, date, models, config_path):
        # Update job details to simulate successful execution
        from api.job_manager import JobManager
        job_manager = JobManager(db_path=test_client.app.state.db_path)
        for model in models:
            job_manager.update_job_detail_status(self.job_id, date, model, "completed")

    monkeypatch.setattr("api.simulation_worker.SimulationWorker._execute_date", mock_execute_date)

    # Step 1: Trigger simulation
    start_time = time.time()
    response = test_client.post("/simulate/trigger", json={
        "start_date": "2025-10-01",
        "end_date": "2025-10-01",
        "models": ["gpt-5"]
    })
    elapsed = time.time() - start_time

    # Should respond quickly
    assert elapsed < 2.0
    assert response.status_code == 200

    data = response.json()
    job_id = data["job_id"]
    assert data["status"] == "pending"

    # Step 2: Run worker manually (since test_mode=True)
    from api.simulation_worker import SimulationWorker
    worker = SimulationWorker(job_id=job_id, db_path=test_client.app.state.db_path)
    result = worker.run()

    # Step 3: Check final status
    status_response = test_client.get(f"/simulate/status/{job_id}")
    assert status_response.status_code == 200

    status_data = status_response.json()
    assert status_data["status"] == "completed"
    assert status_data["job_id"] == job_id

def test_flow_with_rate_limit_warning(test_client, monkeypatch):
    """Test flow when rate limit is hit during download."""

    class MockPriceManagerRateLimited:
        def __init__(self, db_path):
            self.db_path = db_path

        def get_missing_coverage(self, start, end):
            return {"AAPL": {"2025-10-01"}, "MSFT": {"2025-10-01"}}

        def download_missing_data_prioritized(self, missing, requested):
            return {
                "downloaded": ["AAPL"],
                "failed": ["MSFT"],
                "rate_limited": True
            }

        def get_available_trading_dates(self, start, end):
            return []  # No complete dates due to rate limit

    monkeypatch.setattr("api.price_data_manager.PriceDataManager", MockPriceManagerRateLimited)

    # Trigger
    response = test_client.post("/simulate/trigger", json={
        "start_date": "2025-10-01",
        "end_date": "2025-10-01",
        "models": ["gpt-5"]
    })

    job_id = response.json()["job_id"]

    # Run worker
    from api.simulation_worker import SimulationWorker
    worker = SimulationWorker(job_id=job_id, db_path=test_client.app.state.db_path)
    result = worker.run()

    # Should fail due to no available dates
    assert result["success"] is False

    # Check status has error
    status_response = test_client.get(f"/simulate/status/{job_id}")
    status_data = status_response.json()
    assert status_data["status"] == "failed"
    assert "No trading dates available" in status_data["error"]

def test_flow_with_partial_data(test_client, monkeypatch):
    """Test flow when some dates are skipped due to incomplete data."""

    class MockPriceManagerPartial:
        def __init__(self, db_path):
            self.db_path = db_path

        def get_missing_coverage(self, start, end):
            return {}  # No missing data

        def get_available_trading_dates(self, start, end):
            # Only 2 out of 3 dates available
            return ["2025-10-01", "2025-10-03"]

    monkeypatch.setattr("api.price_data_manager.PriceDataManager", MockPriceManagerPartial)

    def mock_execute_date(self, date, models, config_path):
        # Update job details to simulate successful execution
        from api.job_manager import JobManager
        job_manager = JobManager(db_path=test_client.app.state.db_path)
        for model in models:
            job_manager.update_job_detail_status(self.job_id, date, model, "completed")

    monkeypatch.setattr("api.simulation_worker.SimulationWorker._execute_date", mock_execute_date)

    # Trigger with 3 dates
    response = test_client.post("/simulate/trigger", json={
        "start_date": "2025-10-01",
        "end_date": "2025-10-03",
        "models": ["gpt-5"]
    })

    job_id = response.json()["job_id"]

    # Run worker
    from api.simulation_worker import SimulationWorker
    worker = SimulationWorker(job_id=job_id, db_path=test_client.app.state.db_path)
    result = worker.run()

    # Should complete with warnings
    assert result["success"] is True
    assert len(result["warnings"]) > 0
    assert "Skipped" in result["warnings"][0]

    # Check status returns warnings
    status_response = test_client.get(f"/simulate/status/{job_id}")
    status_data = status_response.json()
    # Status should be "running" or "partial" since not all dates were processed
    # (job details exist for 3 dates but only 2 were executed)
    assert status_data["status"] in ["running", "partial", "completed"]
    assert status_data["warnings"] is not None
    assert len(status_data["warnings"]) > 0

@@ -343,4 +343,73 @@ class TestErrorHandling:
        assert response.status_code == 404


@pytest.mark.integration
class TestAsyncDownload:
    """Test async price download behavior."""

    def test_trigger_endpoint_fast_response(self, api_client):
        """Test that /simulate/trigger responds quickly without downloading data."""
        import time

        start_time = time.time()

        response = api_client.post("/simulate/trigger", json={
            "start_date": "2025-10-01",
            "end_date": "2025-10-01",
            "models": ["gpt-4"]
        })

        elapsed = time.time() - start_time

        # Should respond in less than 2 seconds (allowing for DB operations)
        assert elapsed < 2.0
        assert response.status_code == 200
        assert "job_id" in response.json()

    def test_trigger_endpoint_no_price_download(self, api_client):
        """Test that endpoint doesn't import or use PriceDataManager."""
        import api.main

        # Verify PriceDataManager is not imported in api.main
        assert not hasattr(api.main, 'PriceDataManager'), \
            "PriceDataManager should not be imported in api.main"

        # Endpoint should still create job successfully
        response = api_client.post("/simulate/trigger", json={
            "start_date": "2025-10-01",
            "end_date": "2025-10-01",
            "models": ["gpt-4"]
        })

        assert response.status_code == 200
        assert "job_id" in response.json()

    def test_status_endpoint_returns_warnings(self, api_client):
        """Test that /simulate/status returns warnings field."""
        from api.database import initialize_database
        from api.job_manager import JobManager

        # Create job with warnings
        db_path = api_client.db_path
        job_manager = JobManager(db_path=db_path)

        job_id = job_manager.create_job(
            config_path="config.json",
            date_range=["2025-10-01"],
            models=["gpt-5"]
        )

        # Add warnings
        warnings = ["Rate limited", "Skipped 1 date"]
        job_manager.add_job_warnings(job_id, warnings)

        # Get status
        response = api_client.get(f"/simulate/status/{job_id}")

        assert response.status_code == 200
        data = response.json()
        assert "warnings" in data
        assert data["warnings"] == warnings


# Coverage target: 90%+ for api/main.py

100 tests/integration/test_async_download.py (Normal file)
@@ -0,0 +1,100 @@
import pytest
import time
from api.database import initialize_database
from api.job_manager import JobManager
from api.simulation_worker import SimulationWorker
from unittest.mock import Mock, patch

def test_worker_prepares_data_before_execution(tmp_path):
    """Test that worker calls _prepare_data before executing trades."""
    db_path = str(tmp_path / "test.db")
    initialize_database(db_path)
    job_manager = JobManager(db_path=db_path)

    # Create job
    job_id = job_manager.create_job(
        config_path="configs/default_config.json",
        date_range=["2025-10-01"],
        models=["gpt-5"]
    )

    worker = SimulationWorker(job_id=job_id, db_path=db_path)

    # Mock _prepare_data to track call
    original_prepare = worker._prepare_data
    prepare_called = []

    def mock_prepare(*args, **kwargs):
        prepare_called.append(True)
        return (["2025-10-01"], [])  # Return available dates, no warnings

    worker._prepare_data = mock_prepare

    # Mock _execute_date to avoid actual execution
    worker._execute_date = Mock()

    # Run worker
    result = worker.run()

    # Verify _prepare_data was called
    assert len(prepare_called) == 1
    assert result["success"] is True

def test_worker_handles_no_available_dates(tmp_path):
    """Test worker fails gracefully when no dates are available."""
    db_path = str(tmp_path / "test.db")
    initialize_database(db_path)
    job_manager = JobManager(db_path=db_path)

    job_id = job_manager.create_job(
        config_path="configs/default_config.json",
        date_range=["2025-10-01"],
        models=["gpt-5"]
    )

    worker = SimulationWorker(job_id=job_id, db_path=db_path)

    # Mock _prepare_data to return empty dates
    worker._prepare_data = Mock(return_value=([], []))

    # Run worker
    result = worker.run()

    # Should fail with descriptive error
    assert result["success"] is False
    assert "No trading dates available" in result["error"]

    # Job should be marked as failed
    job = job_manager.get_job(job_id)
    assert job["status"] == "failed"

def test_worker_stores_warnings(tmp_path):
    """Test worker stores warnings from prepare_data."""
    db_path = str(tmp_path / "test.db")
    initialize_database(db_path)
    job_manager = JobManager(db_path=db_path)

    job_id = job_manager.create_job(
        config_path="configs/default_config.json",
        date_range=["2025-10-01"],
        models=["gpt-5"]
    )

    worker = SimulationWorker(job_id=job_id, db_path=db_path)

    # Mock _prepare_data to return warnings
    warnings = ["Rate limited", "Skipped 1 date"]
    worker._prepare_data = Mock(return_value=(["2025-10-01"], warnings))
    worker._execute_date = Mock()

    # Run worker
    result = worker.run()

    # Verify warnings in result
    assert result["warnings"] == warnings

    # Verify warnings stored in database
    import json
    job = job_manager.get_job(job_id)
    stored_warnings = json.loads(job["warnings"])
    assert stored_warnings == warnings

121 tests/integration/test_config_override.py (Normal file)
@@ -0,0 +1,121 @@
"""Integration tests for config override system."""

import pytest
import json
import subprocess
import tempfile
from pathlib import Path


@pytest.fixture
def test_configs(tmp_path):
    """Create test config files."""
    # Default config
    default_config = {
        "agent_type": "BaseAgent",
        "date_range": {"init_date": "2025-10-01", "end_date": "2025-10-21"},
        "models": [
            {"name": "default-model", "basemodel": "openai/gpt-4", "signature": "default", "enabled": True}
        ],
        "agent_config": {"max_steps": 30, "max_retries": 3, "base_delay": 1.0, "initial_cash": 10000.0},
        "log_config": {"log_path": "./data/agent_data"}
    }

    configs_dir = tmp_path / "configs"
    configs_dir.mkdir()

    default_path = configs_dir / "default_config.json"
    with open(default_path, 'w') as f:
        json.dump(default_config, f, indent=2)

    return configs_dir, default_config


def test_config_override_models_only(test_configs):
    """Test overriding only the models section."""
    configs_dir, default_config = test_configs

    # Custom config - only override models
    custom_config = {
        "models": [
            {"name": "gpt-5", "basemodel": "openai/gpt-5", "signature": "gpt-5", "enabled": True}
        ]
    }

    user_configs_dir = configs_dir.parent / "user-configs"
    user_configs_dir.mkdir()

    custom_path = user_configs_dir / "config.json"
    with open(custom_path, 'w') as f:
        json.dump(custom_config, f, indent=2)

    # Run merge
    result = subprocess.run(
        [
            "python", "-c",
            f"import sys; sys.path.insert(0, '.'); "
            f"from tools.config_merger import DEFAULT_CONFIG_PATH, CUSTOM_CONFIG_PATH, OUTPUT_CONFIG_PATH, merge_and_validate; "
            f"import tools.config_merger; "
            f"tools.config_merger.DEFAULT_CONFIG_PATH = '{configs_dir}/default_config.json'; "
            f"tools.config_merger.CUSTOM_CONFIG_PATH = '{custom_path}'; "
            f"tools.config_merger.OUTPUT_CONFIG_PATH = '{configs_dir.parent}/runtime.json'; "
            f"merge_and_validate()"
        ],
        capture_output=True,
        text=True,
        cwd="/home/bballou/AI-Trader/.worktrees/async-price-download"
    )

    assert result.returncode == 0, f"Merge failed: {result.stderr}"

    # Verify merged config
    runtime_path = configs_dir.parent / "runtime.json"
    with open(runtime_path, 'r') as f:
        merged = json.load(f)

    # Models should be overridden
    assert merged["models"] == custom_config["models"]

    # Other sections should be from default
    assert merged["agent_config"] == default_config["agent_config"]
    assert merged["date_range"] == default_config["date_range"]


def test_config_validation_fails_gracefully(test_configs):
    """Test that invalid config causes exit with clear error."""
    configs_dir, _ = test_configs

    # Invalid custom config (no enabled models)
    custom_config = {
        "models": [
            {"name": "test", "basemodel": "openai/gpt-4", "signature": "test", "enabled": False}
        ]
    }

    user_configs_dir = configs_dir.parent / "user-configs"
    user_configs_dir.mkdir()

    custom_path = user_configs_dir / "config.json"
    with open(custom_path, 'w') as f:
        json.dump(custom_config, f, indent=2)

    # Run merge (should fail)
    result = subprocess.run(
        [
            "python", "-c",
            f"import sys; sys.path.insert(0, '.'); "
            f"from tools.config_merger import merge_and_validate; "
            f"import tools.config_merger; "
            f"tools.config_merger.DEFAULT_CONFIG_PATH = '{configs_dir}/default_config.json'; "
            f"tools.config_merger.CUSTOM_CONFIG_PATH = '{custom_path}'; "
            f"tools.config_merger.OUTPUT_CONFIG_PATH = '{configs_dir.parent}/runtime.json'; "
            f"merge_and_validate()"
        ],
        capture_output=True,
        text=True,
        cwd="/home/bballou/AI-Trader/.worktrees/async-price-download"
    )

    assert result.returncode == 1
    assert "CONFIG VALIDATION FAILED" in result.stderr
    assert "At least one model must be enabled" in result.stderr

293 tests/unit/test_config_merger.py (Normal file)
@@ -0,0 +1,293 @@
import pytest
import json
import tempfile
from pathlib import Path
from tools.config_merger import load_config, ConfigValidationError, merge_configs, validate_config


def test_load_config_valid_json():
    """Test loading a valid JSON config file"""
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
        json.dump({"key": "value"}, f)
        temp_path = f.name

    try:
        result = load_config(temp_path)
        assert result == {"key": "value"}
    finally:
        Path(temp_path).unlink()


def test_load_config_file_not_found():
    """Test loading non-existent config file"""
    with pytest.raises(ConfigValidationError, match="not found"):
        load_config("/nonexistent/path.json")


def test_load_config_invalid_json():
    """Test loading malformed JSON"""
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
        f.write("{invalid json")
        temp_path = f.name

    try:
        with pytest.raises(ConfigValidationError, match="Invalid JSON"):
            load_config(temp_path)
    finally:
        Path(temp_path).unlink()


def test_merge_configs_empty_custom():
    """Test merge with no custom config"""
    default = {"a": 1, "b": 2}
    custom = {}
    result = merge_configs(default, custom)
    assert result == {"a": 1, "b": 2}


def test_merge_configs_override_section():
    """Test custom config overrides entire sections"""
    default = {
        "models": [{"name": "default-model", "enabled": True}],
        "agent_config": {"max_steps": 30}
    }
    custom = {
        "models": [{"name": "custom-model", "enabled": False}]
    }
    result = merge_configs(default, custom)

    assert result["models"] == [{"name": "custom-model", "enabled": False}]
    assert result["agent_config"] == {"max_steps": 30}


def test_merge_configs_add_new_section():
    """Test custom config adds new sections"""
    default = {"a": 1}
    custom = {"b": 2}
    result = merge_configs(default, custom)
    assert result == {"a": 1, "b": 2}


def test_merge_configs_does_not_mutate_inputs():
    """Test merge doesn't modify original dicts"""
    default = {"a": 1}
    custom = {"a": 2}
    result = merge_configs(default, custom)

    assert default["a"] == 1  # Original unchanged
    assert result["a"] == 2


def test_validate_config_valid():
    """Test validation passes for valid config"""
    config = {
        "agent_type": "BaseAgent",
        "models": [
            {"name": "test", "basemodel": "openai/gpt-4", "signature": "test", "enabled": True}
        ],
        "agent_config": {
            "max_steps": 30,
            "max_retries": 3,
            "initial_cash": 10000.0
        },
        "log_config": {"log_path": "./data"}
    }

    validate_config(config)  # Should not raise


def test_validate_config_missing_required_field():
    """Test validation fails for missing required field"""
    config = {"agent_type": "BaseAgent"}  # Missing models, agent_config, log_config

    with pytest.raises(ConfigValidationError, match="Missing required field"):
        validate_config(config)


def test_validate_config_no_enabled_models():
    """Test validation fails when no models are enabled"""
    config = {
        "agent_type": "BaseAgent",
        "models": [
            {"name": "test", "basemodel": "openai/gpt-4", "signature": "test", "enabled": False}
        ],
        "agent_config": {"max_steps": 30, "max_retries": 3, "initial_cash": 10000.0},
        "log_config": {"log_path": "./data"}
    }

    with pytest.raises(ConfigValidationError, match="At least one model must be enabled"):
        validate_config(config)


def test_validate_config_duplicate_signatures():
    """Test validation fails for duplicate model signatures"""
    config = {
        "agent_type": "BaseAgent",
        "models": [
            {"name": "test1", "basemodel": "openai/gpt-4", "signature": "same", "enabled": True},
            {"name": "test2", "basemodel": "openai/gpt-5", "signature": "same", "enabled": True}
        ],
        "agent_config": {"max_steps": 30, "max_retries": 3, "initial_cash": 10000.0},
        "log_config": {"log_path": "./data"}
    }

    with pytest.raises(ConfigValidationError, match="Duplicate model signature"):
        validate_config(config)


def test_validate_config_invalid_max_steps():
    """Test validation fails for invalid max_steps"""
    config = {
        "agent_type": "BaseAgent",
        "models": [{"name": "test", "basemodel": "openai/gpt-4", "signature": "test", "enabled": True}],
        "agent_config": {"max_steps": 0, "max_retries": 3, "initial_cash": 10000.0},
        "log_config": {"log_path": "./data"}
    }

    with pytest.raises(ConfigValidationError, match="max_steps must be > 0"):
        validate_config(config)


def test_validate_config_invalid_date_format():
|
||||
"""Test validation fails for invalid date format"""
|
||||
config = {
|
||||
"agent_type": "BaseAgent",
|
||||
"date_range": {"init_date": "2025-13-01", "end_date": "2025-12-31"}, # Invalid month
|
||||
"models": [{"name": "test", "basemodel": "openai/gpt-4", "signature": "test", "enabled": True}],
|
||||
"agent_config": {"max_steps": 30, "max_retries": 3, "initial_cash": 10000.0},
|
||||
"log_config": {"log_path": "./data"}
|
||||
}
|
||||
|
||||
with pytest.raises(ConfigValidationError, match="Invalid date format"):
|
||||
validate_config(config)
|
||||
|
||||
|
||||
def test_validate_config_end_before_init():
|
||||
"""Test validation fails when end_date before init_date"""
|
||||
config = {
|
||||
"agent_type": "BaseAgent",
|
||||
"date_range": {"init_date": "2025-12-31", "end_date": "2025-01-01"},
|
||||
"models": [{"name": "test", "basemodel": "openai/gpt-4", "signature": "test", "enabled": True}],
|
||||
"agent_config": {"max_steps": 30, "max_retries": 3, "initial_cash": 10000.0},
|
||||
"log_config": {"log_path": "./data"}
|
||||
}
|
||||
|
||||
with pytest.raises(ConfigValidationError, match="init_date must be <= end_date"):
|
||||
validate_config(config)
|
||||
|
||||
|
||||
import os
from tools.config_merger import merge_and_validate


def test_merge_and_validate_success(tmp_path, monkeypatch):
    """Test successful merge and validation"""
    # Create default config
    default_config = {
        "agent_type": "BaseAgent",
        "models": [{"name": "default", "basemodel": "openai/gpt-4", "signature": "default", "enabled": True}],
        "agent_config": {"max_steps": 30, "max_retries": 3, "initial_cash": 10000.0},
        "log_config": {"log_path": "./data"}
    }

    default_path = tmp_path / "default_config.json"
    with open(default_path, 'w') as f:
        json.dump(default_config, f)

    # Create custom config (only overrides models)
    custom_config = {
        "models": [{"name": "custom", "basemodel": "openai/gpt-5", "signature": "custom", "enabled": True}]
    }

    custom_path = tmp_path / "config.json"
    with open(custom_path, 'w') as f:
        json.dump(custom_config, f)

    output_path = tmp_path / "runtime_config.json"

    # Mock file paths
    monkeypatch.setattr("tools.config_merger.DEFAULT_CONFIG_PATH", str(default_path))
    monkeypatch.setattr("tools.config_merger.CUSTOM_CONFIG_PATH", str(custom_path))
    monkeypatch.setattr("tools.config_merger.OUTPUT_CONFIG_PATH", str(output_path))

    # Run merge and validate
    merge_and_validate()

    # Verify output file was created
    assert output_path.exists()

    # Verify merged content
    with open(output_path, 'r') as f:
        result = json.load(f)

    assert result["models"] == [{"name": "custom", "basemodel": "openai/gpt-5", "signature": "custom", "enabled": True}]
    assert result["agent_config"] == {"max_steps": 30, "max_retries": 3, "initial_cash": 10000.0}


def test_merge_and_validate_no_custom_config(tmp_path, monkeypatch):
    """Test when no custom config exists (uses default only)"""
    default_config = {
        "agent_type": "BaseAgent",
        "models": [{"name": "default", "basemodel": "openai/gpt-4", "signature": "default", "enabled": True}],
        "agent_config": {"max_steps": 30, "max_retries": 3, "initial_cash": 10000.0},
        "log_config": {"log_path": "./data"}
    }

    default_path = tmp_path / "default_config.json"
    with open(default_path, 'w') as f:
        json.dump(default_config, f)

    custom_path = tmp_path / "config.json"  # Does not exist
    output_path = tmp_path / "runtime_config.json"

    monkeypatch.setattr("tools.config_merger.DEFAULT_CONFIG_PATH", str(default_path))
    monkeypatch.setattr("tools.config_merger.CUSTOM_CONFIG_PATH", str(custom_path))
    monkeypatch.setattr("tools.config_merger.OUTPUT_CONFIG_PATH", str(output_path))

    merge_and_validate()

    # Verify output matches default
    with open(output_path, 'r') as f:
        result = json.load(f)

    assert result == default_config


def test_merge_and_validate_validation_fails(tmp_path, monkeypatch, capsys):
    """Test validation failure exits with error"""
    default_config = {
        "agent_type": "BaseAgent",
        "models": [{"name": "default", "basemodel": "openai/gpt-4", "signature": "default", "enabled": True}],
        "agent_config": {"max_steps": 30, "max_retries": 3, "initial_cash": 10000.0},
        "log_config": {"log_path": "./data"}
    }

    default_path = tmp_path / "default_config.json"
    with open(default_path, 'w') as f:
        json.dump(default_config, f)

    # Custom config with no enabled models
    custom_config = {
        "models": [{"name": "custom", "basemodel": "openai/gpt-5", "signature": "custom", "enabled": False}]
    }

    custom_path = tmp_path / "config.json"
    with open(custom_path, 'w') as f:
        json.dump(custom_config, f)

    output_path = tmp_path / "runtime_config.json"

    monkeypatch.setattr("tools.config_merger.DEFAULT_CONFIG_PATH", str(default_path))
    monkeypatch.setattr("tools.config_merger.CUSTOM_CONFIG_PATH", str(custom_path))
    monkeypatch.setattr("tools.config_merger.OUTPUT_CONFIG_PATH", str(output_path))

    # Should exit with error
    with pytest.raises(SystemExit) as exc_info:
        merge_and_validate()

    assert exc_info.value.code == 1

    # Check error output (should be in stderr, not stdout)
    captured = capsys.readouterr()
    assert "CONFIG VALIDATION FAILED" in captured.err
    assert "At least one model must be enabled" in captured.err
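The assertions above pin down the contract of `validate_config` without showing its body. A minimal sketch that would satisfy them — field names taken from the test fixtures, everything else an assumption — might look like:

```python
from datetime import datetime


class ConfigValidationError(Exception):
    """Raised when config validation fails."""


REQUIRED_FIELDS = ("agent_type", "models", "agent_config", "log_config")


def validate_config(config):
    # Presence of every required top-level section
    for field in REQUIRED_FIELDS:
        if field not in config:
            raise ConfigValidationError(f"Missing required field: {field}")

    models = config["models"]
    if not any(m.get("enabled") for m in models):
        raise ConfigValidationError("At least one model must be enabled")

    signatures = [m["signature"] for m in models]
    if len(signatures) != len(set(signatures)):
        raise ConfigValidationError("Duplicate model signature")

    if config["agent_config"].get("max_steps", 0) <= 0:
        raise ConfigValidationError("max_steps must be > 0")

    # date_range is optional in the fixtures, so only validate it when present
    date_range = config.get("date_range")
    if date_range:
        try:
            init = datetime.strptime(date_range["init_date"], "%Y-%m-%d")
            end = datetime.strptime(date_range["end_date"], "%Y-%m-%d")
        except ValueError as e:
            raise ConfigValidationError(f"Invalid date format: {e}")
        if init > end:
            raise ConfigValidationError("init_date must be <= end_date")
```

This is only the shape the tests imply; the repository's real validator may check more (log paths, basemodel strings, retry bounds).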
@@ -90,7 +90,7 @@ class TestSchemaInitialization:
    """Test database schema initialization."""

    def test_initialize_database_creates_all_tables(self, clean_db):
        """Should create all 6 tables."""
        """Should create all 9 tables."""
        conn = get_db_connection(clean_db)
        cursor = conn.cursor()

@@ -109,7 +110,10 @@ class TestSchemaInitialization:
            'jobs',
            'positions',
            'reasoning_logs',
            'tool_usage'
            'tool_usage',
            'price_data',
            'price_data_coverage',
            'simulation_runs'
        ]

        assert sorted(tables) == sorted(expected_tables)
@@ -135,7 +138,8 @@ class TestSchemaInitialization:
            'updated_at': 'TEXT',
            'completed_at': 'TEXT',
            'total_duration_seconds': 'REAL',
            'error': 'TEXT'
            'error': 'TEXT',
            'warnings': 'TEXT'
        }

        for col_name, col_type in expected_columns.items():
@@ -367,7 +371,7 @@ class TestUtilityFunctions:
        conn = get_db_connection(test_db_path)
        cursor = conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")
        assert cursor.fetchone()[0] == 6
        assert cursor.fetchone()[0] == 9  # Updated to reflect all tables
        conn.close()

        # Drop all tables
@@ -438,6 +442,134 @@ class TestUtilityFunctions:
        assert stats["database_size_mb"] > 0


@pytest.mark.unit
class TestSchemaMigration:
    """Test database schema migration functionality."""

    def test_migration_adds_warnings_column(self, test_db_path):
        """Should add warnings column to existing jobs table without it."""
        from api.database import drop_all_tables

        # Start with a clean slate
        drop_all_tables(test_db_path)

        # Create database without warnings column (simulate old schema)
        conn = get_db_connection(test_db_path)
        cursor = conn.cursor()

        # Create jobs table without warnings column (old schema)
        cursor.execute("""
            CREATE TABLE jobs (
                job_id TEXT PRIMARY KEY,
                config_path TEXT NOT NULL,
                status TEXT NOT NULL CHECK(status IN ('pending', 'downloading_data', 'running', 'completed', 'partial', 'failed')),
                date_range TEXT NOT NULL,
                models TEXT NOT NULL,
                created_at TEXT NOT NULL,
                started_at TEXT,
                updated_at TEXT,
                completed_at TEXT,
                total_duration_seconds REAL,
                error TEXT
            )
        """)
        conn.commit()

        # Verify warnings column doesn't exist
        cursor.execute("PRAGMA table_info(jobs)")
        columns = [row[1] for row in cursor.fetchall()]
        assert 'warnings' not in columns

        conn.close()

        # Run initialize_database which should trigger migration
        initialize_database(test_db_path)

        # Verify warnings column was added
        conn = get_db_connection(test_db_path)
        cursor = conn.cursor()
        cursor.execute("PRAGMA table_info(jobs)")
        columns = [row[1] for row in cursor.fetchall()]
        assert 'warnings' in columns

        # Verify we can insert and query warnings
        cursor.execute("""
            INSERT INTO jobs (job_id, config_path, status, date_range, models, created_at, warnings)
            VALUES (?, ?, ?, ?, ?, ?, ?)
        """, ("test-job", "configs/test.json", "completed", "[]", "[]", "2025-01-20T00:00:00Z", "Test warning"))
        conn.commit()

        cursor.execute("SELECT warnings FROM jobs WHERE job_id = ?", ("test-job",))
        result = cursor.fetchone()
        assert result[0] == "Test warning"

        conn.close()

        # Clean up after test - drop all tables so we don't affect other tests
        drop_all_tables(test_db_path)

    def test_migration_adds_simulation_run_id_column(self, test_db_path):
        """Should add simulation_run_id column to existing positions table without it."""
        from api.database import drop_all_tables

        # Start with a clean slate
        drop_all_tables(test_db_path)

        # Create database without simulation_run_id column (simulate old schema)
        conn = get_db_connection(test_db_path)
        cursor = conn.cursor()

        # Create jobs table first (for foreign key)
        cursor.execute("""
            CREATE TABLE jobs (
                job_id TEXT PRIMARY KEY,
                config_path TEXT NOT NULL,
                status TEXT NOT NULL CHECK(status IN ('pending', 'downloading_data', 'running', 'completed', 'partial', 'failed')),
                date_range TEXT NOT NULL,
                models TEXT NOT NULL,
                created_at TEXT NOT NULL
            )
        """)

        # Create positions table without simulation_run_id column (old schema)
        cursor.execute("""
            CREATE TABLE positions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                job_id TEXT NOT NULL,
                date TEXT NOT NULL,
                model TEXT NOT NULL,
                action_id INTEGER NOT NULL,
                cash REAL NOT NULL,
                portfolio_value REAL NOT NULL,
                created_at TEXT NOT NULL,
                FOREIGN KEY (job_id) REFERENCES jobs(job_id) ON DELETE CASCADE
            )
        """)
        conn.commit()

        # Verify simulation_run_id column doesn't exist
        cursor.execute("PRAGMA table_info(positions)")
        columns = [row[1] for row in cursor.fetchall()]
        assert 'simulation_run_id' not in columns

        conn.close()

        # Run initialize_database which should trigger migration
        initialize_database(test_db_path)

        # Verify simulation_run_id column was added
        conn = get_db_connection(test_db_path)
        cursor = conn.cursor()
        cursor.execute("PRAGMA table_info(positions)")
        columns = [row[1] for row in cursor.fetchall()]
        assert 'simulation_run_id' in columns

        conn.close()

        # Clean up after test - drop all tables so we don't affect other tests
        drop_all_tables(test_db_path)


@pytest.mark.unit
class TestCheckConstraints:
    """Test CHECK constraints on table columns."""
tests/unit/test_database_schema.py (new file, 47 lines)
@@ -0,0 +1,47 @@
import pytest
import sqlite3
from api.database import initialize_database, get_db_connection

def test_jobs_table_allows_downloading_data_status(tmp_path):
    """Test that jobs table accepts downloading_data status."""
    db_path = str(tmp_path / "test.db")
    initialize_database(db_path)

    conn = get_db_connection(db_path)
    cursor = conn.cursor()

    # Should not raise constraint violation
    cursor.execute("""
        INSERT INTO jobs (job_id, config_path, status, date_range, models, created_at)
        VALUES ('test-123', 'config.json', 'downloading_data', '[]', '[]', '2025-11-01T00:00:00Z')
    """)
    conn.commit()

    # Verify it was inserted
    cursor.execute("SELECT status FROM jobs WHERE job_id = 'test-123'")
    result = cursor.fetchone()
    assert result[0] == "downloading_data"

    conn.close()

def test_jobs_table_has_warnings_column(tmp_path):
    """Test that jobs table has warnings TEXT column."""
    db_path = str(tmp_path / "test.db")
    initialize_database(db_path)

    conn = get_db_connection(db_path)
    cursor = conn.cursor()

    # Insert job with warnings
    cursor.execute("""
        INSERT INTO jobs (job_id, config_path, status, date_range, models, created_at, warnings)
        VALUES ('test-456', 'config.json', 'completed', '[]', '[]', '2025-11-01T00:00:00Z', '["Warning 1", "Warning 2"]')
    """)
    conn.commit()

    # Verify warnings can be retrieved
    cursor.execute("SELECT warnings FROM jobs WHERE job_id = 'test-456'")
    result = cursor.fetchone()
    assert result[0] == '["Warning 1", "Warning 2"]'

    conn.close()
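The migration behaviour these tests expect — detect a missing column, then add it in place — reduces to a small idempotent helper built on SQLite's `PRAGMA table_info` and `ALTER TABLE ... ADD COLUMN`. The helper name `ensure_column` below is hypothetical; it only illustrates the pattern the tests check for:

```python
import sqlite3


def ensure_column(conn: sqlite3.Connection, table: str, column: str, col_type: str) -> None:
    # PRAGMA table_info returns one row per column; row[1] is the column name.
    existing = [row[1] for row in conn.execute(f"PRAGMA table_info({table})")]
    if column not in existing:
        # SQLite only supports adding columns in place, which is all
        # an additive migration like warnings/simulation_run_id needs.
        conn.execute(f"ALTER TABLE {table} ADD COLUMN {column} {col_type}")
        conn.commit()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE jobs (job_id TEXT PRIMARY KEY, error TEXT)")
ensure_column(conn, "jobs", "warnings", "TEXT")  # adds the column
ensure_column(conn, "jobs", "warnings", "TEXT")  # second call is a no-op
```

Running it from `initialize_database` on every startup keeps old databases in step with the current schema without a separate migration tool.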
@@ -42,6 +42,11 @@ def test_initialize_dev_database_creates_fresh_db(tmp_path, clean_env):
    assert cursor.fetchone()[0] == 1
    conn.close()

    # Clear thread-local connections before reinitializing
    import threading
    if hasattr(threading.current_thread(), '_db_connections'):
        delattr(threading.current_thread(), '_db_connections')

    # Initialize dev database (should reset)
    initialize_dev_database(db_path)

@@ -419,4 +419,33 @@ class TestJobUpdateOperations:
        assert detail["duration_seconds"] > 0


@pytest.mark.unit
class TestJobWarnings:
    """Test job warnings management."""

    def test_add_job_warnings(self, clean_db):
        """Test adding warnings to a job."""
        from api.job_manager import JobManager
        from api.database import initialize_database

        initialize_database(clean_db)
        job_manager = JobManager(db_path=clean_db)

        # Create a job
        job_id = job_manager.create_job(
            config_path="config.json",
            date_range=["2025-10-01"],
            models=["gpt-5"]
        )

        # Add warnings
        warnings = ["Rate limit reached", "Skipped 2 dates"]
        job_manager.add_job_warnings(job_id, warnings)

        # Verify warnings were stored
        job = job_manager.get_job(job_id)
        stored_warnings = json.loads(job["warnings"])
        assert stored_warnings == warnings


# Coverage target: 95%+ for api/job_manager.py
tests/unit/test_response_models.py (new file, 32 lines)
@@ -0,0 +1,32 @@
from api.main import SimulateTriggerResponse, JobStatusResponse, JobProgress

def test_simulate_trigger_response_accepts_warnings():
    """Test SimulateTriggerResponse accepts warnings field."""
    response = SimulateTriggerResponse(
        job_id="test-123",
        status="completed",
        total_model_days=10,
        message="Job completed",
        deployment_mode="DEV",
        is_dev_mode=True,
        warnings=["Rate limited", "Skipped 2 dates"]
    )

    assert response.warnings == ["Rate limited", "Skipped 2 dates"]

def test_job_status_response_accepts_warnings():
    """Test JobStatusResponse accepts warnings field."""
    response = JobStatusResponse(
        job_id="test-123",
        status="completed",
        progress=JobProgress(total_model_days=10, completed=10, failed=0, pending=0),
        date_range=["2025-10-01"],
        models=["gpt-5"],
        created_at="2025-11-01T00:00:00Z",
        details=[],
        deployment_mode="DEV",
        is_dev_mode=True,
        warnings=["Rate limited"]
    )

    assert response.warnings == ["Rate limited"]
@@ -49,10 +49,17 @@ class TestSimulationWorkerExecution:
|
||||
|
||||
worker = SimulationWorker(job_id=job_id, db_path=clean_db)
|
||||
|
||||
# Mock _prepare_data to return both dates
|
||||
worker._prepare_data = Mock(return_value=(["2025-01-16", "2025-01-17"], []))
|
||||
|
||||
# Mock ModelDayExecutor
|
||||
with patch("api.simulation_worker.ModelDayExecutor") as mock_executor_class:
|
||||
mock_executor = Mock()
|
||||
mock_executor.execute.return_value = {"success": True}
|
||||
mock_executor.execute.return_value = {
|
||||
"success": True,
|
||||
"model": "test-model",
|
||||
"date": "2025-01-16"
|
||||
}
|
||||
mock_executor_class.return_value = mock_executor
|
||||
|
||||
worker.run()
|
||||
@@ -74,12 +81,19 @@ class TestSimulationWorkerExecution:
|
||||
|
||||
worker = SimulationWorker(job_id=job_id, db_path=clean_db)
|
||||
|
||||
# Mock _prepare_data to return both dates
|
||||
worker._prepare_data = Mock(return_value=(["2025-01-16", "2025-01-17"], []))
|
||||
|
||||
execution_order = []
|
||||
|
||||
def track_execution(job_id, date, model_sig, config_path, db_path):
|
||||
executor = Mock()
|
||||
execution_order.append((date, model_sig))
|
||||
executor.execute.return_value = {"success": True}
|
||||
executor.execute.return_value = {
|
||||
"success": True,
|
||||
"model": model_sig,
|
||||
"date": date
|
||||
}
|
||||
return executor
|
||||
|
||||
with patch("api.simulation_worker.ModelDayExecutor", side_effect=track_execution):
|
||||
@@ -112,11 +126,27 @@ class TestSimulationWorkerExecution:
|
||||
|
||||
worker = SimulationWorker(job_id=job_id, db_path=clean_db)
|
||||
|
||||
with patch("api.simulation_worker.ModelDayExecutor") as mock_executor_class:
|
||||
mock_executor = Mock()
|
||||
mock_executor.execute.return_value = {"success": True}
|
||||
mock_executor_class.return_value = mock_executor
|
||||
# Mock _prepare_data to return the date
|
||||
worker._prepare_data = Mock(return_value=(["2025-01-16"], []))
|
||||
|
||||
def create_mock_executor(job_id, date, model_sig, config_path, db_path):
|
||||
"""Create mock executor that simulates job detail status updates."""
|
||||
mock_executor = Mock()
|
||||
|
||||
def mock_execute():
|
||||
# Simulate ModelDayExecutor status updates
|
||||
manager.update_job_detail_status(job_id, date, model_sig, "running")
|
||||
manager.update_job_detail_status(job_id, date, model_sig, "completed")
|
||||
return {
|
||||
"success": True,
|
||||
"model": model_sig,
|
||||
"date": date
|
||||
}
|
||||
|
||||
mock_executor.execute = mock_execute
|
||||
return mock_executor
|
||||
|
||||
with patch("api.simulation_worker.ModelDayExecutor", side_effect=create_mock_executor):
|
||||
worker.run()
|
||||
|
||||
# Check job status
|
||||
@@ -137,15 +167,34 @@ class TestSimulationWorkerExecution:
|
||||
|
||||
worker = SimulationWorker(job_id=job_id, db_path=clean_db)
|
||||
|
||||
# Mock _prepare_data to return the date
|
||||
worker._prepare_data = Mock(return_value=(["2025-01-16"], []))
|
||||
|
||||
call_count = 0
|
||||
|
||||
def mixed_results(*args, **kwargs):
|
||||
def mixed_results(job_id, date, model_sig, config_path, db_path):
|
||||
"""Create mock executor with mixed success/failure results."""
|
||||
nonlocal call_count
|
||||
executor = Mock()
|
||||
mock_executor = Mock()
|
||||
# First model succeeds, second fails
|
||||
executor.execute.return_value = {"success": call_count == 0}
|
||||
success = (call_count == 0)
|
||||
call_count += 1
|
||||
return executor
|
||||
|
||||
def mock_execute():
|
||||
# Simulate ModelDayExecutor status updates
|
||||
manager.update_job_detail_status(job_id, date, model_sig, "running")
|
||||
if success:
|
||||
manager.update_job_detail_status(job_id, date, model_sig, "completed")
|
||||
else:
|
||||
manager.update_job_detail_status(job_id, date, model_sig, "failed", error="Model failed")
|
||||
return {
|
||||
"success": success,
|
||||
"model": model_sig,
|
||||
"date": date
|
||||
}
|
||||
|
||||
mock_executor.execute = mock_execute
|
||||
return mock_executor
|
||||
|
||||
with patch("api.simulation_worker.ModelDayExecutor", side_effect=mixed_results):
|
||||
worker.run()
|
||||
@@ -173,6 +222,9 @@ class TestSimulationWorkerErrorHandling:
|
||||
|
||||
worker = SimulationWorker(job_id=job_id, db_path=clean_db)
|
||||
|
||||
# Mock _prepare_data to return the date
|
||||
worker._prepare_data = Mock(return_value=(["2025-01-16"], []))
|
||||
|
||||
execution_count = 0
|
||||
|
||||
def counting_executor(*args, **kwargs):
|
||||
@@ -181,9 +233,18 @@ class TestSimulationWorkerErrorHandling:
|
||||
executor = Mock()
|
||||
# Second model fails
|
||||
if execution_count == 2:
|
||||
executor.execute.return_value = {"success": False, "error": "Model failed"}
|
||||
executor.execute.return_value = {
|
||||
"success": False,
|
||||
"error": "Model failed",
|
||||
"model": kwargs.get("model_sig", "unknown"),
|
||||
"date": kwargs.get("date", "2025-01-16")
|
||||
}
|
||||
else:
|
||||
executor.execute.return_value = {"success": True}
|
||||
executor.execute.return_value = {
|
||||
"success": True,
|
||||
"model": kwargs.get("model_sig", "unknown"),
|
||||
"date": kwargs.get("date", "2025-01-16")
|
||||
}
|
||||
return executor
|
||||
|
||||
with patch("api.simulation_worker.ModelDayExecutor", side_effect=counting_executor):
|
||||
@@ -206,8 +267,10 @@ class TestSimulationWorkerErrorHandling:
|
||||
|
||||
worker = SimulationWorker(job_id=job_id, db_path=clean_db)
|
||||
|
||||
with patch("api.simulation_worker.ModelDayExecutor", side_effect=Exception("Unexpected error")):
|
||||
worker.run()
|
||||
# Mock _prepare_data to raise exception
|
||||
worker._prepare_data = Mock(side_effect=Exception("Unexpected error"))
|
||||
|
||||
worker.run()
|
||||
|
||||
# Check job status
|
||||
job = manager.get_job(job_id)
|
||||
@@ -233,16 +296,27 @@ class TestSimulationWorkerConcurrency:
|
||||
|
||||
worker = SimulationWorker(job_id=job_id, db_path=clean_db)
|
||||
|
||||
# Mock _prepare_data to return the date
|
||||
worker._prepare_data = Mock(return_value=(["2025-01-16"], []))
|
||||
|
||||
with patch("api.simulation_worker.ModelDayExecutor") as mock_executor_class:
|
||||
mock_executor = Mock()
|
||||
mock_executor.execute.return_value = {"success": True}
|
||||
mock_executor.execute.return_value = {
|
||||
"success": True,
|
||||
"model": "test-model",
|
||||
"date": "2025-01-16"
|
||||
}
|
||||
mock_executor_class.return_value = mock_executor
|
||||
|
||||
# Mock ThreadPoolExecutor to verify it's being used
|
||||
with patch("api.simulation_worker.ThreadPoolExecutor") as mock_pool:
|
||||
mock_pool_instance = Mock()
|
||||
mock_pool.return_value.__enter__.return_value = mock_pool_instance
|
||||
mock_pool_instance.submit.return_value = Mock(result=lambda: {"success": True})
|
||||
mock_pool_instance.submit.return_value = Mock(result=lambda: {
|
||||
"success": True,
|
||||
"model": "test-model",
|
||||
"date": "2025-01-16"
|
||||
})
|
||||
|
||||
worker.run()
|
||||
|
||||
@@ -274,4 +348,239 @@ class TestSimulationWorkerJobRetrieval:
|
||||
assert job_info["models"] == ["gpt-5"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSimulationWorkerHelperMethods:
|
||||
"""Test worker helper methods."""
|
||||
|
||||
def test_download_price_data_success(self, clean_db):
|
||||
"""Test successful price data download."""
|
||||
from api.simulation_worker import SimulationWorker
|
||||
from api.database import initialize_database
|
||||
|
||||
db_path = clean_db
|
||||
initialize_database(db_path)
|
||||
|
||||
worker = SimulationWorker(job_id="test-123", db_path=db_path)
|
||||
|
||||
# Mock price manager
|
||||
mock_price_manager = Mock()
|
||||
mock_price_manager.download_missing_data_prioritized.return_value = {
|
||||
"downloaded": ["AAPL", "MSFT"],
|
||||
"failed": [],
|
||||
"rate_limited": False
|
||||
}
|
||||
|
||||
warnings = []
|
||||
missing_coverage = {"AAPL": {"2025-10-01"}, "MSFT": {"2025-10-01"}}
|
||||
|
||||
worker._download_price_data(mock_price_manager, missing_coverage, ["2025-10-01"], warnings)
|
||||
|
||||
# Verify download was called
|
||||
mock_price_manager.download_missing_data_prioritized.assert_called_once()
|
||||
|
||||
# No warnings for successful download
|
||||
assert len(warnings) == 0
|
||||
|
||||
def test_download_price_data_rate_limited(self, clean_db):
|
||||
"""Test price download with rate limit."""
|
||||
from api.simulation_worker import SimulationWorker
|
||||
from api.database import initialize_database
|
||||
|
||||
db_path = clean_db
|
||||
initialize_database(db_path)
|
||||
|
||||
worker = SimulationWorker(job_id="test-456", db_path=db_path)
|
||||
|
||||
# Mock price manager
|
||||
mock_price_manager = Mock()
|
||||
mock_price_manager.download_missing_data_prioritized.return_value = {
|
||||
"downloaded": ["AAPL"],
|
||||
"failed": ["MSFT"],
|
||||
"rate_limited": True
|
||||
}
|
||||
|
||||
warnings = []
|
||||
missing_coverage = {"AAPL": {"2025-10-01"}, "MSFT": {"2025-10-01"}}
|
||||
|
||||
worker._download_price_data(mock_price_manager, missing_coverage, ["2025-10-01"], warnings)
|
||||
|
||||
# Should add rate limit warning
|
||||
assert len(warnings) == 1
|
||||
assert "Rate limit" in warnings[0]
|
||||
|
||||
def test_filter_completed_dates_all_new(self, clean_db):
|
||||
"""Test filtering when no dates are completed."""
|
||||
from api.simulation_worker import SimulationWorker
|
||||
from api.database import initialize_database
|
||||
|
||||
db_path = clean_db
|
||||
initialize_database(db_path)
|
||||
|
||||
worker = SimulationWorker(job_id="test-789", db_path=db_path)
|
||||
|
||||
# Mock job_manager to return empty completed dates
|
||||
mock_job_manager = Mock()
|
||||
mock_job_manager.get_completed_model_dates.return_value = {}
|
||||
worker.job_manager = mock_job_manager
|
||||
|
||||
available_dates = ["2025-10-01", "2025-10-02"]
|
||||
models = ["gpt-5"]
|
||||
|
||||
result = worker._filter_completed_dates(available_dates, models)
|
||||
|
||||
# All dates should be returned
|
||||
assert result == available_dates
|
||||
|
||||
def test_filter_completed_dates_some_completed(self, clean_db):
|
||||
"""Test filtering when some dates are completed."""
|
||||
from api.simulation_worker import SimulationWorker
|
||||
from api.database import initialize_database
|
||||
|
||||
db_path = clean_db
|
||||
initialize_database(db_path)
|
||||
|
||||
worker = SimulationWorker(job_id="test-abc", db_path=db_path)
|
||||
|
||||
# Mock job_manager to return one completed date
|
||||
mock_job_manager = Mock()
|
||||
mock_job_manager.get_completed_model_dates.return_value = {
|
||||
"gpt-5": ["2025-10-01"]
|
||||
}
|
||||
worker.job_manager = mock_job_manager
|
||||
|
||||
available_dates = ["2025-10-01", "2025-10-02", "2025-10-03"]
|
||||
models = ["gpt-5"]
|
||||
|
||||
result = worker._filter_completed_dates(available_dates, models)
|
||||
|
||||
# Should exclude completed date
|
||||
assert result == ["2025-10-02", "2025-10-03"]
|
||||
|
||||
def test_add_job_warnings(self, clean_db):
|
||||
"""Test adding warnings to job via worker."""
|
||||
from api.simulation_worker import SimulationWorker
|
||||
from api.job_manager import JobManager
|
||||
from api.database import initialize_database
|
||||
import json
|
||||
|
||||
db_path = clean_db
|
||||
initialize_database(db_path)
|
||||
job_manager = JobManager(db_path=db_path)
|
||||
|
||||
# Create job
|
||||
job_id = job_manager.create_job(
|
||||
config_path="config.json",
|
||||
date_range=["2025-10-01"],
|
||||
models=["gpt-5"]
|
||||
)
|
||||
|
||||
worker = SimulationWorker(job_id=job_id, db_path=db_path)
|
||||
|
||||
# Add warnings
|
||||
warnings = ["Warning 1", "Warning 2"]
|
||||
worker._add_job_warnings(warnings)
|
||||
|
||||
# Verify warnings were stored
|
||||
job = job_manager.get_job(job_id)
|
||||
assert job["warnings"] is not None
|
||||
stored_warnings = json.loads(job["warnings"])
|
||||
assert stored_warnings == warnings
|
||||
|
||||
def test_prepare_data_no_missing_data(self, clean_db, monkeypatch):
|
||||
"""Test prepare_data when all data is available."""
|
||||
from api.simulation_worker import SimulationWorker
|
||||
from api.job_manager import JobManager
|
||||
from api.database import initialize_database
|
||||
|
||||
db_path = clean_db
|
||||
initialize_database(db_path)
|
||||
job_manager = JobManager(db_path=db_path)
|
||||
|
||||
# Create job
|
||||
job_id = job_manager.create_job(
|
||||
config_path="config.json",
|
||||
date_range=["2025-10-01"],
|
||||
models=["gpt-5"]
|
||||
)
|
||||
|
||||
worker = SimulationWorker(job_id=job_id, db_path=db_path)
|
||||
|
||||
# Mock PriceDataManager
|
||||
mock_price_manager = Mock()
|
||||
mock_price_manager.get_missing_coverage.return_value = {} # No missing data
|
||||
        mock_price_manager.get_available_trading_dates.return_value = ["2025-10-01"]

        # Patch PriceDataManager import where it's used
        def mock_pdm_init(db_path):
            return mock_price_manager

        monkeypatch.setattr("api.price_data_manager.PriceDataManager", mock_pdm_init)

        # Mock get_completed_model_dates
        worker.job_manager.get_completed_model_dates = Mock(return_value={})

        # Execute
        available_dates, warnings = worker._prepare_data(
            requested_dates=["2025-10-01"],
            models=["gpt-5"],
            config_path="config.json"
        )

        # Verify results
        assert available_dates == ["2025-10-01"]
        assert len(warnings) == 0

        # Verify status was updated to running
        job = job_manager.get_job(job_id)
        assert job["status"] == "running"

    def test_prepare_data_with_download(self, clean_db, monkeypatch):
        """Test _prepare_data when data needs downloading."""
        from api.simulation_worker import SimulationWorker
        from api.job_manager import JobManager
        from api.database import initialize_database

        db_path = clean_db
        initialize_database(db_path)
        job_manager = JobManager(db_path=db_path)

        job_id = job_manager.create_job(
            config_path="config.json",
            date_range=["2025-10-01"],
            models=["gpt-5"]
        )

        worker = SimulationWorker(job_id=job_id, db_path=db_path)

        # Mock PriceDataManager
        mock_price_manager = Mock()
        mock_price_manager.get_missing_coverage.return_value = {"AAPL": {"2025-10-01"}}
        mock_price_manager.download_missing_data_prioritized.return_value = {
            "downloaded": ["AAPL"],
            "failed": [],
            "rate_limited": False
        }
        mock_price_manager.get_available_trading_dates.return_value = ["2025-10-01"]

        def mock_pdm_init(db_path):
            return mock_price_manager

        monkeypatch.setattr("api.price_data_manager.PriceDataManager", mock_pdm_init)
        worker.job_manager.get_completed_model_dates = Mock(return_value={})

        # Execute
        available_dates, warnings = worker._prepare_data(
            requested_dates=["2025-10-01"],
            models=["gpt-5"],
            config_path="config.json"
        )

        # Verify download was called
        mock_price_manager.download_missing_data_prioritized.assert_called_once()

        # Verify status transitions
        job = job_manager.get_job(job_id)
        assert job["status"] == "running"


# Coverage target: 90%+ for api/simulation_worker.py
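The patching pattern these tests rely on replaces the `PriceDataManager` class with a factory function that returns a preconfigured `Mock`, so the code under test gets the mock whenever it "constructs" a manager. A standalone sketch of that idea, outside pytest (the factory and return values here are illustrative, not the project's real API):

```python
from unittest.mock import Mock

# Preconfigured mock standing in for a PriceDataManager instance.
mock_price_manager = Mock()
mock_price_manager.get_available_trading_dates.return_value = ["2025-10-01"]

def mock_pdm_init(db_path):
    # Factory with the same call signature as the real constructor;
    # monkeypatch.setattr(...) would install this in place of the class.
    return mock_price_manager

# Any code that would do `PriceDataManager(db_path)` now gets the mock back.
manager = mock_pdm_init("ignored.db")
assert manager.get_available_trading_dates() == ["2025-10-01"]
```

Because the factory closes over a single mock instance, the test can both configure return values up front and assert on calls afterwards.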
228 tools/config_merger.py Normal file
@@ -0,0 +1,228 @@
"""Configuration merging and validation for AI-Trader."""

import json
import sys
from pathlib import Path
from typing import Dict, Any, Optional
from datetime import datetime


class ConfigValidationError(Exception):
    """Raised when config validation fails."""
    pass


def load_config(path: str) -> Dict[str, Any]:
    """
    Load and parse JSON config file.

    Args:
        path: Path to JSON config file

    Returns:
        Parsed config dictionary

    Raises:
        ConfigValidationError: If file not found or invalid JSON
    """
    config_path = Path(path)

    if not config_path.exists():
        raise ConfigValidationError(f"Config file not found: {path}")

    try:
        with open(config_path, 'r') as f:
            return json.load(f)
    except json.JSONDecodeError as e:
        raise ConfigValidationError(f"Invalid JSON in {path}: {e}")


def merge_configs(default: Dict[str, Any], custom: Dict[str, Any]) -> Dict[str, Any]:
    """
    Merge custom config into default config (root-level override).

    Custom config sections completely replace default sections.
    Does not mutate input dictionaries.

    Args:
        default: Default configuration dict
        custom: Custom configuration dict (overrides)

    Returns:
        Merged configuration dict
    """
    merged = dict(default)  # Shallow copy

    for key, value in custom.items():
        merged[key] = value

    return merged


def validate_config(config: Dict[str, Any]) -> None:
    """
    Validate configuration structure and values.

    Args:
        config: Configuration dictionary to validate

    Raises:
        ConfigValidationError: If validation fails with detailed message
    """
    # Required top-level fields
    required_fields = ["agent_type", "models", "agent_config", "log_config"]
    for field in required_fields:
        if field not in config:
            raise ConfigValidationError(f"Missing required field: '{field}'")

    # Validate models
    models = config["models"]
    if not isinstance(models, list) or len(models) == 0:
        raise ConfigValidationError("'models' must be a non-empty array")

    # Check at least one enabled model
    enabled_models = [m for m in models if m.get("enabled", False)]
    if not enabled_models:
        raise ConfigValidationError("At least one model must be enabled")

    # Check required model fields
    for i, model in enumerate(models):
        required_model_fields = ["name", "basemodel", "signature", "enabled"]
        for field in required_model_fields:
            if field not in model:
                raise ConfigValidationError(
                    f"Model {i} missing required field: '{field}'"
                )

    # Check for duplicate signatures
    signatures = [m["signature"] for m in models]
    if len(signatures) != len(set(signatures)):
        duplicates = [s for s in signatures if signatures.count(s) > 1]
        raise ConfigValidationError(
            f"Duplicate model signature: {duplicates[0]}"
        )

    # Validate agent_config
    agent_config = config["agent_config"]

    if "max_steps" in agent_config:
        if agent_config["max_steps"] <= 0:
            raise ConfigValidationError("max_steps must be > 0")

    if "max_retries" in agent_config:
        if agent_config["max_retries"] < 0:
            raise ConfigValidationError("max_retries must be >= 0")

    if "initial_cash" in agent_config:
        if agent_config["initial_cash"] <= 0:
            raise ConfigValidationError("initial_cash must be > 0")

    # Validate date_range if present (optional)
    if "date_range" in config:
        date_range = config["date_range"]

        if "init_date" in date_range:
            try:
                init_dt = datetime.strptime(date_range["init_date"], "%Y-%m-%d")
            except ValueError:
                raise ConfigValidationError(
                    f"Invalid date format for init_date: {date_range['init_date']}. "
                    "Expected YYYY-MM-DD"
                )

        if "end_date" in date_range:
            try:
                end_dt = datetime.strptime(date_range["end_date"], "%Y-%m-%d")
            except ValueError:
                raise ConfigValidationError(
                    f"Invalid date format for end_date: {date_range['end_date']}. "
                    "Expected YYYY-MM-DD"
                )

        # Check init <= end
        if "init_date" in date_range and "end_date" in date_range:
            if init_dt > end_dt:
                raise ConfigValidationError(
                    f"init_date must be <= end_date "
                    f"(got {date_range['init_date']} > {date_range['end_date']})"
                )


# File path constants (can be overridden for testing)
DEFAULT_CONFIG_PATH = "configs/default_config.json"
CUSTOM_CONFIG_PATH = "user-configs/config.json"
OUTPUT_CONFIG_PATH = "/tmp/runtime_config.json"


def format_error_message(error: str, location: str, file: str) -> str:
    """Format validation error for display."""
    border = "━" * 60
    return f"""
❌ CONFIG VALIDATION FAILED
{border}

Error:    {error}
Location: {location}
File:     {file}

Merged config written to: {OUTPUT_CONFIG_PATH} (for debugging)

Container will exit. Fix config and restart.
"""


def merge_and_validate() -> None:
    """
    Main entry point for config merging and validation.

    Loads the default config, optionally merges the custom config,
    validates the result, and writes it to the output path.

    Exits with code 1 on any error.
    """
    try:
        # Load default config
        print(f"📄 Loading default config from {DEFAULT_CONFIG_PATH}")
        default_config = load_config(DEFAULT_CONFIG_PATH)

        # Load custom config if it exists
        custom_config = {}
        if Path(CUSTOM_CONFIG_PATH).exists():
            print(f"📝 Loading custom config from {CUSTOM_CONFIG_PATH}")
            custom_config = load_config(CUSTOM_CONFIG_PATH)
        else:
            print(f"ℹ️ No custom config found at {CUSTOM_CONFIG_PATH}, using defaults")

        # Merge configs
        print("🔧 Merging configurations...")
        merged_config = merge_configs(default_config, custom_config)

        # Write merged config (for debugging even if validation fails)
        output_path = Path(OUTPUT_CONFIG_PATH)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w') as f:
            json.dump(merged_config, f, indent=2)

        # Validate merged config
        print("✅ Validating merged configuration...")
        validate_config(merged_config)

        print("✅ Configuration validated successfully")
        print(f"📦 Merged config written to {OUTPUT_CONFIG_PATH}")

    except ConfigValidationError as e:
        # Determine which file caused the error
        error_file = CUSTOM_CONFIG_PATH if Path(CUSTOM_CONFIG_PATH).exists() else DEFAULT_CONFIG_PATH

        error_msg = format_error_message(
            error=str(e),
            location="Root level",
            file=error_file
        )

        print(error_msg, file=sys.stderr)
        sys.exit(1)

    except Exception as e:
        print(f"❌ Unexpected error during config processing: {e}", file=sys.stderr)
        sys.exit(1)
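The root-level override in `merge_configs` is a shallow dict union: each top-level key in the custom config replaces the entire corresponding default section, with no deep merge of nested keys. A minimal standalone sketch of those semantics (the config keys here are illustrative):

```python
# Root-level override: a top-level key in the custom config completely
# replaces the corresponding default section.
default = {
    "agent_type": "trader",
    "agent_config": {"max_steps": 30, "initial_cash": 10000.0},
}
custom = {
    "agent_config": {"max_steps": 50},  # replaces the whole section
}

# Same result as merge_configs(default, custom): shallow, last-wins.
merged = {**default, **custom}

# Nested default keys are NOT deep-merged: initial_cash is dropped.
assert merged["agent_config"] == {"max_steps": 50}
assert merged["agent_type"] == "trader"
```

This is why a custom config that overrides `agent_config` must restate every field it still wants, such as `initial_cash` above.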