Compare commits


12 Commits

Author SHA1 Message Date
1347e3939f docs: add web UI feature to v0.4.0 roadmap
Add comprehensive web dashboard interface to planned features for v0.4.0.

Web UI Features:
- Job management dashboard
  * View/monitor active, pending, and completed jobs
  * Start new simulations with form-based configuration
  * Real-time job progress monitoring
  * Cancel running jobs

- Results visualization
  * Performance charts (P&L over time, cumulative returns)
  * Position history timeline
  * Model comparison views
  * Trade log explorer with filtering

- Configuration management
  * Model configuration editor
  * Date range selection with calendar picker
  * Price data coverage visualization

- Technical implementation
  * Modern frontend framework (React, Vue.js, or Svelte)
  * Real-time updates via WebSocket or Server-Sent Events
  * Responsive design for mobile access
  * Chart library for visualizations
  * Single container deployment alongside API

The web UI will provide an accessible interface for users who prefer
graphical interaction over API calls, while maintaining the same
functionality available through the REST API.

Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-31 17:22:31 -04:00
4b25ae96c2 docs: simplify roadmap to focus on v0.4.0 only
Remove all future releases (v0.5.0-v0.7.0) and infrastructure/enhancement
sections from roadmap. Focus exclusively on v0.4.0 planned features.

v0.4.0 - Enhanced Simulation Management remains with:
- Resume/continue API for advancing from last completed date
- Position history tracking and analysis
- Advanced performance metrics (Sharpe, Sortino, drawdown, win rate)
- Price data management endpoints

Removed sections:
- v0.5.0 Real-Time Trading Support
- v0.6.0 Multi-Strategy & Portfolio Management
- v0.7.0 Alternative Data & Advanced Features
- Future Enhancements (infrastructure, data, UI, AI/ML, integration, testing)

Keep roadmap focused on near-term deliverables with clear scope.

Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-31 17:20:48 -04:00
5606df1f51 docs: add comprehensive roadmap for future development
Create ROADMAP.md documenting planned features across multiple releases.

Planned releases:
- v0.4.0: Enhanced simulation management
  * Resume/continue API for advancing from last completed date
  * Position history tracking and analysis
  * Advanced performance metrics (Sharpe, Sortino, drawdown)
  * Price data management endpoints

- v0.5.0: Real-time trading support
  * Live market data integration
  * Real-time simulation mode
  * Scheduled automation
  * WebSocket price feeds

- v0.6.0: Multi-strategy & portfolio management
  * Strategy composition and ensembles
  * Advanced risk controls
  * Portfolio-level optimization
  * Dynamic allocation

- v0.7.0: Alternative data & advanced features
  * News and sentiment analysis
  * Market regime detection
  * Custom indicators
  * Event-driven strategies

Future enhancements:
- Kubernetes deployment and cloud provider support
- Alternative databases (PostgreSQL, TimescaleDB)
- Web UI dashboard with real-time visualization
- Model training and reinforcement learning
- Webhook notifications and plugin system
- Performance and chaos testing

Key feature: Resume API in v0.4.0
- POST /simulate/resume - Continue from last completed date
- POST /simulate/continue - Extend existing simulations
- Automatic detection of completion state per model
- Support for daily incremental updates

Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-31 17:18:55 -04:00
02c8a48b37 docs: improve CHANGELOG to reflect actual v0.2.0 baseline
Clarify that v0.3.0 is the first version with REST API functionality,
and remove misleading "API Request Format Changed" entries that implied
the API existed in v0.2.0.

Key improvements:
- Remove "API Request Format Changed" from Changed section (API is new)
- Remove "Model Selection" and "API Interface" items (API design, not changes)
- Clarify batch mode removal context (v0.2.0 had batch, v0.3.0 adds API)
- Update test counts to reflect new tests (175 total, up from 102)
- Add coverage details for new test files (date_utils, price_data_manager)
- Update test execution time estimate (~12 seconds for full suite)

Breaking changes now correctly identify what changed from v0.2.0:
- Batch execution replaced with REST API (new capability)
- Price data storage moved from JSONL to SQLite (migration required)
- Configuration variables added/removed for new features

v0.2.0 was Docker-focused with batch execution; v0.3.0 adds REST API, on-demand downloads, and database storage.

Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-31 17:15:50 -04:00
c3ea358a12 test: add comprehensive test suite for v0.3.0 on-demand price downloads
Add 64 new tests covering date utilities, price data management, and
on-demand download workflows with 100% coverage for date_utils and 85%
coverage for price_data_manager.

New test files:
- tests/unit/test_date_utils.py (22 tests)
  * Date range expansion and validation
  * Max simulation days configuration
  * Chronological ordering and boundary checks
  * 100% coverage of api/date_utils.py

- tests/unit/test_price_data_manager.py (33 tests)
  * Initialization and configuration
  * Symbol date retrieval and coverage detection
  * Priority-based download ordering
  * Rate limit and error handling
  * Data storage and coverage tracking
  * 85% coverage of api/price_data_manager.py

- tests/integration/test_on_demand_downloads.py (10 tests)
  * End-to-end download workflows
  * Rate limit handling with graceful degradation
  * Coverage tracking and gap detection
  * Data validation and filtering

Code improvements:
- Add DownloadError exception class for non-rate-limit failures
- Update all ValueError raises to DownloadError for consistency
- Add API key validation at download start
- Improve response validation to check for Meta Data
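The response-validation behavior described above can be sketched as follows. This is an illustrative reconstruction, not the project's actual code: `parse_daily_response` is a hypothetical helper name, and the assumption that Alpha Vantage signals rate limiting via a top-level `"Note"` key is based on that API's documented behavior, not on this repository.

```python
class DownloadError(Exception):
    """Non-rate-limit download failure (bad API key, malformed response, etc.)."""

def parse_daily_response(payload):
    """Validate an Alpha Vantage daily-series response before storing it."""
    if "Note" in payload:           # Alpha Vantage reports rate limiting via a "Note" field
        return None                 # caller degrades gracefully and retries later
    if "Meta Data" not in payload:  # matches the "check for Meta Data" improvement above
        raise DownloadError(f"unexpected response keys: {list(payload)[:3]}")
    return payload["Time Series (Daily)"]
```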

Test coverage:
- 64 tests passing (54 unit + 10 integration)
- api/date_utils.py: 100% coverage
- api/price_data_manager.py: 85% coverage
- Validates priority-first download strategy
- Confirms graceful rate limit handling
- Verifies database storage and retrieval

Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-31 17:13:03 -04:00
1bfcdd78b8 feat: complete v0.3.0 database migration and configuration
Final phase of v0.3.0 implementation - all core features complete.

Price Tools Migration:
- Update get_open_prices() to query price_data table
- Update get_yesterday_open_and_close_price() to query database
- Remove merged.jsonl file I/O (replaced with SQLite queries)
- Maintain backward-compatible function signatures
- Add db_path parameter (default: data/jobs.db)
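A minimal sketch of how `get_open_prices()` might query the new `price_data` table, keeping the backward-compatible signature with a `db_path` parameter. The column names (`symbol`, `date`, `open`) are assumptions inferred from the commit notes, not the repository's verified schema.

```python
import sqlite3

def get_open_prices(symbols, date, db_path="data/jobs.db"):
    """Return {symbol: open_price} for the given date from the price_data table."""
    conn = sqlite3.connect(db_path)
    try:
        placeholders = ",".join("?" for _ in symbols)
        rows = conn.execute(
            f"SELECT symbol, open FROM price_data "
            f"WHERE date = ? AND symbol IN ({placeholders})",
            [date, *symbols],
        ).fetchall()
        # Symbols with no row for that date are simply absent from the result.
        return dict(rows)
    finally:
        conn.close()
```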

Configuration:
- Add AUTO_DOWNLOAD_PRICE_DATA to .env.example (default: true)
- Add MAX_SIMULATION_DAYS to .env.example (default: 30)
- Document new configuration options

Documentation:
- Comprehensive CHANGELOG updates for v0.3.0
- Document all breaking changes (API format, data storage, config)
- Document new features (on-demand downloads, date ranges, database)
- Document migration path (scripts/migrate_price_data.py)
- Clear upgrade instructions

Breaking Changes (v0.3.0):
1. API request format: date_range -> start_date/end_date
2. Data storage: merged.jsonl -> price_data table
3. Config variables: removed RUNTIME_ENV_PATH, MCP ports, WEB_HTTP_PORT
4. Added AUTO_DOWNLOAD_PRICE_DATA, MAX_SIMULATION_DAYS

Migration Steps:
1. Run: python scripts/migrate_price_data.py
2. Update API clients to use new date format
3. Update .env with new variables
4. Remove old config variables
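Step 1 of the migration can be sketched as below. This is an assumption-laden illustration of what `scripts/migrate_price_data.py` plausibly does: the `price_data` schema is inferred from the commit notes, and the JSONL keys (`"1. buy price"`, `"4. sell price"`, etc.) mirror the merged.jsonl format documented elsewhere in this repository.

```python
import json
import sqlite3

def migrate_price_data(jsonl_path, db_path):
    """One-time import of Alpha Vantage records from merged.jsonl into price_data."""
    conn = sqlite3.connect(db_path)
    conn.execute(
        "CREATE TABLE IF NOT EXISTS price_data ("
        "symbol TEXT, date TEXT, open REAL, high REAL, low REAL, "
        "close REAL, volume INTEGER, PRIMARY KEY (symbol, date))"
    )
    with open(jsonl_path) as fh:
        for line in fh:
            record = json.loads(line)
            symbol = record["Meta Data"]["2. Symbol"]
            for date, bar in record["Time Series (Daily)"].items():
                conn.execute(
                    "INSERT OR REPLACE INTO price_data VALUES (?,?,?,?,?,?,?)",
                    (symbol, date,
                     float(bar["1. buy price"]), float(bar["2. high"]),
                     float(bar["3. low"]), float(bar["4. sell price"]),
                     int(bar["5. volume"])),
                )
    conn.commit()
    conn.close()
```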

Status: v0.3.0 implementation complete
Ready for: Testing, deployment, and release
2025-10-31 16:44:46 -04:00
76b946449e feat: implement date range API and on-demand downloads (WIP phase 2)
Phase 2 progress - API integration complete.

API Changes:
- Replace date_range (List[str]) with start_date/end_date (str)
- Add automatic end_date defaulting to start_date for single day
- Add date format validation
- Integrate PriceDataManager for on-demand downloads
- Add rate limit handling (trusts provider, no pre-config)
- Validate date ranges with configurable max days (MAX_SIMULATION_DAYS)

New Modules:
- api/date_utils.py - Date validation and expansion utilities
- scripts/migrate_price_data.py - Migration script for merged.jsonl

API Flow:
1. Validate date range (start <= end, max 30 days, not future)
2. Check missing price data coverage
3. Download missing data if AUTO_DOWNLOAD_PRICE_DATA=true
4. Priority-based download (maximize date completion)
5. Create job with available trading dates
6. Graceful handling of partial data (rate limits)
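Step 1 of the flow above (validation plus expansion into a date list) can be sketched as follows. The function name and exact error messages are illustrative; only the rules themselves (start <= end, configurable max days, no future dates, end_date defaulting to start_date) come from this commit.

```python
from datetime import date, timedelta

MAX_SIMULATION_DAYS = 30  # mirrors the MAX_SIMULATION_DAYS setting (default: 30)

def validate_date_range(start_date, end_date=None, today=None):
    """Validate a start/end pair and expand it into the list of requested dates."""
    today = today or date.today()
    end_date = end_date or start_date  # single-day requests omit end_date
    start = date.fromisoformat(start_date)
    end = date.fromisoformat(end_date)
    if start > end:
        raise ValueError("start_date must not be after end_date")
    if end > today:
        raise ValueError("date range must not extend into the future")
    if (end - start).days + 1 > MAX_SIMULATION_DAYS:
        raise ValueError(f"range exceeds MAX_SIMULATION_DAYS={MAX_SIMULATION_DAYS}")
    return [(start + timedelta(days=i)).isoformat()
            for i in range((end - start).days + 1)]
```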

Configuration:
- AUTO_DOWNLOAD_PRICE_DATA (default: true)
- MAX_SIMULATION_DAYS (default: 30)
- No rate limit configuration needed

Still TODO:
- Update tools/price_tools.py to read from database
- Implement simulation run tracking
- Update .env.example
- Comprehensive testing
- Documentation updates

Breaking Changes:
- API request format changed (date_range -> start_date/end_date)
- This completes v0.3.0 preparation
2025-10-31 16:40:50 -04:00
bddf4d8b72 feat: add price data management infrastructure (WIP)
Phase 1 of v0.3.0 date range and on-demand download implementation.

Database changes:
- Add price_data table (OHLCV data, replaces merged.jsonl)
- Add price_data_coverage table (track downloaded date ranges)
- Add simulation_runs table (soft delete support)
- Add simulation_run_id to positions table
- Add comprehensive indexes for new tables

New modules:
- api/price_data_manager.py - Priority-based download manager
  - Coverage gap detection
  - Smart download prioritization (maximize date completion)
  - Rate limit handling with retry logic
  - Alpha Vantage integration
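The coverage-gap detection idea can be sketched in a few lines. This assumes coverage is tracked as inclusive (start, end) ISO-date ranges per symbol, as suggested by the `price_data_coverage` table; the function name is illustrative.

```python
def find_coverage_gaps(requested_dates, covered_ranges):
    """Dates needing download: requested dates outside every covered (start, end) range.

    ISO-8601 date strings compare correctly as plain strings, so no parsing is needed.
    """
    return [
        d for d in requested_dates
        if not any(start <= d <= end for start, end in covered_ranges)
    ]
```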

Configuration:
- configs/nasdaq100_symbols.json - NASDAQ 100 constituent list

Next steps (not yet implemented):
- Migration script for merged.jsonl -> price_data
- Update API models (start_date/end_date)
- Update tools/price_tools.py to read from database
- Simulation run tracking implementation
- API integration
- Tests and documentation

This is work in progress for the v0.3.0 release.
2025-10-31 16:37:14 -04:00
8e7e80807b refactor: remove config_path from API interface
Makes config_path an internal server detail rather than an API parameter.

Changes:
- Remove config_path from SimulateTriggerRequest
- Add config_path parameter to create_app() with default
- Store in app.state.config_path for internal use
- Update trigger endpoint to use internal config path
- Change missing config error from 400 to 500 (server error)

API calls now only need to specify date_range (and optionally models):
  POST /simulate/trigger
  {"date_range": ["2025-01-16"]}

The server uses configs/default_config.json by default.
This simplifies the API and hides implementation details from clients.
2025-10-31 15:18:56 -04:00
ec2a37e474 feat: use enabled field from config to determine which models run
Changed the API to respect the 'enabled' field in model configurations,
rather than requiring models to be explicitly specified in API requests.

Changes:
- Make 'models' parameter optional in POST /simulate/trigger
- If models not provided, read config and use enabled models
- If models provided, use as explicit override (for testing)
- Raise error if no enabled models found and none specified
- Update response message to show model count

Behavior:
- Default: Only runs models with "enabled": true in config
- Override: Can still specify models in request for manual testing
- Safety: Prevents accidental execution of disabled/expensive models

Example before (required):
  POST /simulate/trigger
  {"config_path": "...", "date_range": [...], "models": ["gpt-4"]}

Example after (optional):
  POST /simulate/trigger
  {"config_path": "...", "date_range": [...]}
  # Uses models where enabled: true

This makes the config file the source of truth for which models
should run, while still allowing ad-hoc overrides for testing.
2025-10-31 15:12:11 -04:00
20506a379d docs: rewrite README for API-first architecture
Complete rewrite of README.md to reflect the new REST API service
architecture and remove batch mode references.

Changes:
- Focus on REST API deployment and usage
- Updated architecture diagram showing FastAPI → Worker → Database flow
- Comprehensive API endpoint documentation with examples
- Docker-first quick start guide
- Integration examples (Windmill.dev, Python client)
- Database schema documentation
- Simplified configuration guide
- Updated project structure
- Removed batch mode references
- Removed web UI mentions

The new README positions AI-Trader as an API service for autonomous
trading simulations, not a standalone batch application.

Key additions:
- Complete API reference (/trigger, /status, /results, /health)
- Integration patterns for external orchestration
- Database querying examples
- Testing and validation procedures
- Production deployment guidance
2025-10-31 14:57:29 -04:00
246dbd1b34 refactor: remove unused web UI port configuration
The web UI (docs/index.html, portfolio.html) exists but is not served
in API mode. Removing the port configuration to eliminate confusion.

Changes:
- Remove port 8888 mapping from docker-compose.yml
- Remove WEB_HTTP_PORT from .env.example
- Update Dockerfile EXPOSE to only port 8080
- Update CHANGELOG.md to document removal

Technical details:
- Web UI static files remain in docs/ folder (legacy from batch mode)
- These were designed for JSONL file format, not the new SQLite database
- No web server was ever started in entrypoint.sh for API mode
- Port 8888 was exposed but nothing listened on it

Result:
- Cleaner configuration (1 fewer port mapping)
- Only REST API (8080) is exposed
- Eliminates user confusion about non-functional web UI
2025-10-31 14:54:10 -04:00
17 changed files with 2892 additions and 1282 deletions

.env.example

@@ -21,13 +21,19 @@ JINA_API_KEY=your_jina_key_here # https://jina.ai/
# Used for Windmill integration and external API access
API_PORT=8080
# Web Interface Host Port (exposed on host machine)
# Container always uses 8888 internally
WEB_HTTP_PORT=8888
# Agent Configuration
AGENT_MAX_STEP=30
# Simulation Configuration
# Maximum number of days allowed in a single simulation range
# Prevents accidentally requesting very large date ranges
MAX_SIMULATION_DAYS=30
# Price Data Configuration
# Automatically download missing price data from Alpha Vantage when needed
# If disabled, all price data must be pre-populated in the database
AUTO_DOWNLOAD_PRICE_DATA=true
# Data Volume Configuration
# Base directory for all persistent data (will contain data/, logs/, configs/ subdirectories)
# Use relative paths (./volumes) or absolute paths (/home/user/ai-trader-volumes)

CHANGELOG.md

@@ -9,6 +9,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.3.0] - 2025-10-31
### Added - Price Data Management & On-Demand Downloads
- **SQLite Price Data Storage** - Replaced JSONL files with relational database
- `price_data` table for OHLCV data (replaces merged.jsonl)
- `price_data_coverage` table for tracking downloaded date ranges
- `simulation_runs` table for soft-delete position tracking
- Comprehensive indexes for query performance
- **On-Demand Price Data Downloads** - Automatic gap filling via Alpha Vantage
- Priority-based download strategy (maximize date completion)
- Graceful rate limit handling (no pre-configured limits needed)
- Smart coverage gap detection
- Configurable via `AUTO_DOWNLOAD_PRICE_DATA` (default: true)
- **Date Range API** - Simplified date specification
- Single date: `{"start_date": "2025-01-20"}`
- Date range: `{"start_date": "2025-01-20", "end_date": "2025-01-24"}`
- Automatic validation (chronological order, max range, not future)
- Configurable max days via `MAX_SIMULATION_DAYS` (default: 30)
- **Migration Tooling** - Script to import existing merged.jsonl data
- `scripts/migrate_price_data.py` for one-time data migration
- Automatic coverage tracking during migration
### Added - API Service Transformation
- **REST API Service** - Complete FastAPI implementation for external orchestration
- `POST /simulate/trigger` - Trigger simulation jobs with config, date range, and models
@@ -28,13 +48,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- ModelDayExecutor - Single model-day execution engine
- SimulationWorker - Job orchestration with date-sequential, model-parallel execution
- **Comprehensive Test Suite**
- 102 unit and integration tests (85% coverage)
- 175 unit and integration tests
- 19 database tests (98% coverage)
- 23 job manager tests (98% coverage)
- 10 model executor tests (84% coverage)
- 20 API endpoint tests (81% coverage)
- 20 Pydantic model tests (100% coverage)
- 10 runtime manager tests (89% coverage)
- 22 date utilities tests (100% coverage)
- 33 price data manager tests (85% coverage)
- 10 on-demand download integration tests
- 8 existing integration tests
- **Docker Deployment** - Persistent REST API service
- API-only deployment (batch mode removed for simplicity)
- Single docker-compose service (ai-trader)
@@ -55,14 +79,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- **Architecture** - Transformed from batch-only to API-first service with database persistence
- **Data Storage** - Migrated from JSONL files to SQLite relational database
- **Deployment** - Simplified to single API-only Docker service
- Price data now stored in `price_data` table instead of `merged.jsonl`
- Tools/price_tools.py updated to query database
- Position data remains in database (already migrated in earlier versions)
- **Deployment** - Simplified to single API-only Docker service (REST API is new in v0.3.0)
- **Configuration** - Simplified environment variable configuration
- Added configurable API_PORT for host port mapping (default: 8080, customizable for port conflicts)
- Removed `RUNTIME_ENV_PATH` (API dynamically manages runtime configs via RuntimeConfigManager)
- Removed MCP service port configuration (MATH_HTTP_PORT, SEARCH_HTTP_PORT, TRADE_HTTP_PORT, GETPRICE_HTTP_PORT)
- **Added:** `AUTO_DOWNLOAD_PRICE_DATA` (default: true) - Enable on-demand downloads
- **Added:** `MAX_SIMULATION_DAYS` (default: 30) - Maximum date range size
- **Added:** `API_PORT` for host port mapping (default: 8080, customizable for port conflicts)
- **Removed:** `RUNTIME_ENV_PATH` (API dynamically manages runtime configs)
- **Removed:** MCP service ports (MATH_HTTP_PORT, SEARCH_HTTP_PORT, TRADE_HTTP_PORT, GETPRICE_HTTP_PORT)
- **Removed:** `WEB_HTTP_PORT` (web UI not implemented)
- MCP services use fixed internal ports (8000-8003) and are no longer exposed to host
- Container always uses port 8080 internally for API (hardcoded in entrypoint.sh)
- Only API port (8080) and web dashboard (8888) are exposed to host
- Container always uses port 8080 internally for API
- Only API port (8080) is exposed to host
- Reduces configuration complexity and attack surface
- **Requirements** - Added fastapi>=0.120.0, uvicorn[standard]>=0.27.0, pydantic>=2.0.0
- **Docker Compose** - Single service (ai-trader) instead of dual-mode
@@ -80,15 +110,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **Automatic Status Transitions** - Job status updates based on model-day completion
### Performance & Quality
- **Code Coverage** - 85% overall (84.63% measured)
- **Test Suite** - 175 tests, all passing
- Unit tests: 155 tests
- Integration tests: 18 tests
- API tests: 20+ tests
- **Code Coverage** - High coverage for new modules
- Date utilities: 100%
- Price data manager: 85%
- Database layer: 98%
- Job manager: 98%
- Pydantic models: 100%
- Runtime manager: 89%
- Model executor: 84%
- FastAPI app: 81%
- **Test Execution** - 102 tests in ~2.5 seconds
- **Zero Test Failures** - All tests passing (threading tests excluded)
- **Test Execution** - Fast test suite (~12 seconds for full suite)
### Integration Ready
- **Windmill.dev** - HTTP-based integration with polling support
@@ -98,9 +133,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Breaking Changes
- **Batch Mode Removed** - All simulations now run through REST API
- Simplifies deployment and eliminates dual-mode complexity
- Focus on API-first architecture for external orchestration
- Migration: Use POST /simulate/trigger endpoint instead of batch execution
- v0.2.0 used sequential batch execution via Docker entrypoint
- v0.3.0 introduces REST API for external orchestration
- Migration: Use `POST /simulate/trigger` endpoint instead of direct script execution
- **Data Storage Format Changed** - Price data moved from JSONL to SQLite
- Run `python scripts/migrate_price_data.py` to migrate existing merged.jsonl data
- `merged.jsonl` no longer used (replaced by `price_data` table)
- Automatic on-demand downloads eliminate need for manual data fetching
- **Configuration Variables Changed**
- Added: `AUTO_DOWNLOAD_PRICE_DATA`, `MAX_SIMULATION_DAYS`, `API_PORT`
- Removed: `RUNTIME_ENV_PATH`, MCP service ports, `WEB_HTTP_PORT`
- MCP services now use fixed internal ports (not exposed to host)
## [0.2.0] - 2025-10-31

Dockerfile

@@ -33,8 +33,8 @@ RUN mkdir -p data logs data/agent_data
# Make entrypoint executable
RUN chmod +x entrypoint.sh
# Expose MCP service ports, API server, and web dashboard
EXPOSE 8000 8001 8002 8003 8080 8888
# Expose API server port (MCP services are internal only)
EXPOSE 8080
# Set Python to run unbuffered for real-time logs
ENV PYTHONUNBUFFERED=1

README.md (978 lines changed)

File diff suppressed because it is too large.


@@ -1,584 +0,0 @@
<div align="center">

# 🚀 AI-Trader: Which LLM Rules the Market?

### *Let AI show its skills in the financial markets*

[![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)](https://python.org)
[![License](https://img.shields.io/badge/License-MIT-green.svg)](LICENSE)

**An AI stock-trading agent system in which multiple large language models make fully autonomous decisions and compete against each other in the NASDAQ-100 stock pool**

## 🏆 Current Tournament Leaderboard

[*Click to view*](https://hkuds.github.io/AI-Trader/)

<div align="center">

### 🥇 **Tournament Standings (Last Update 2025/10/29)**

| 🏆 Rank | 🤖 AI Model | 📈 Total Earnings |
|---------|-------------|----------------|
| **🥇 1st** | **DeepSeek** | 🚀 +16.46% |
| 🥈 2nd | MiniMax-M2 | 📊 +12.03% |
| 🥉 3rd | GPT-5 | 📊 +9.98% |
| 4th | Claude-3.7 | 📊 +9.80% |
| 5th | Qwen3-max | 📊 +7.96% |
| Baseline | QQQ | 📊 +5.39% |
| 6th | Gemini-2.5-flash | 📊 +0.48% |

### 📊 **Live Performance Dashboard**

![rank](assets/rank.png)

*Daily tracking of AI model performance trading the NASDAQ-100*

</div>

---

## 📝 This Week's Update Plan

We are excited to announce the following updates going live this week:

- ⚡ **Hourly Trading Support** - Upgrade to hour-level trading precision
- 🚀 **Service Deployment & Parallel Execution** - Deploy a production service + parallel model execution
- 🎨 **Enhanced Frontend Dashboard** - Add detailed trade log visualization (full trading process display)

Stay tuned for these exciting improvements! 🎉

---

> 🎯 **Core Feature**: 100% autonomous AI decision-making, zero human intervention, pure tool-driven architecture

[🚀 Quick Start](#-quick-start) • [📈 Performance Analysis](#-performance-analysis) • [🛠️ Configuration Guide](#-configuration-guide)

</div>
---
## 🌟 Project Introduction

> **AI-Trader lets five different AI models, each with a distinct investment strategy, make fully autonomous decisions and compete in the same market to see which can earn the most trading the NASDAQ-100.**

### 🎯 Core Features

- 🤖 **Fully Autonomous Decisions**: AI agents analyze, decide, and execute 100% independently, with zero human intervention
- 🛠️ **Pure Tool-Driven Architecture**: Built on the MCP toolchain; the AI completes all trading operations via standardized tool calls
- 🏆 **Multi-Model Arena**: Deploy multiple AI models (GPT, Claude, Qwen, etc.) for competitive trading
- 📊 **Real-Time Performance Analysis**: Complete trade records, position monitoring, and P&L analysis
- 🔍 **Smart Market Intelligence**: Integrated Jina search for real-time market news and financial reports
- ⚡ **MCP Toolchain Integration**: Modular tool ecosystem based on the Model Context Protocol
- 🔌 **Extensible Strategy Framework**: Supports third-party strategies and custom AI agent integration
- ⏰ **Historical Replay**: Time-window replay with automatic filtering of future information

---

### 🎮 Trading Environment

Each AI model trades NASDAQ-100 stocks in a controlled environment with $10,000 starting capital, using real market data and historical replay.

- 💰 **Initial Capital**: $10,000 starting balance
- 📈 **Trading Universe**: NASDAQ-100 constituents (100 top tech stocks)
- ⏰ **Trading Hours**: Weekday market hours, with historical simulation support
- 📊 **Data Integration**: Alpha Vantage API combined with Jina AI market intelligence
- 🔄 **Time Management**: Historical-period replay with automatic filtering of future information

---

### 🧠 Intelligent Trading Capabilities

AI agents operate fully autonomously: researching the market, making trading decisions, and continuously refining strategies without human involvement.

- 📰 **Autonomous Market Research**: Intelligent retrieval and filtering of market news, analyst reports, and financial data
- 💡 **Independent Decision Engine**: Multi-dimensional analysis drives fully autonomous buy/sell execution
- 📝 **Comprehensive Trade Records**: Automatic logging of trade rationale, execution details, and portfolio changes
- 🔄 **Adaptive Strategy Evolution**: Algorithms that self-optimize based on market performance feedback

---

### 🏁 Competition Rules

All AI models compete under identical conditions, with the same capital, data access, tools, and evaluation metrics, ensuring a fair comparison.

- 💰 **Starting Capital**: $10,000 initial investment
- 📊 **Data Access**: Unified market data and information sources
- ⏰ **Runtime**: Synchronized trading windows
- 📈 **Performance Metrics**: Standard evaluation criteria for all models
- 🛠️ **Tool Access**: All participants use the same MCP toolchain

🎯 **Goal**: Determine which AI model achieves superior investment returns through purely autonomous operation

### 🚫 Zero Human Intervention

AI agents run fully autonomously, making all trading decisions and strategy adjustments without any human programming, guidance, or intervention.

- ❌ **No Pre-Programming**: Zero preset trading strategies or algorithmic rules
- ❌ **No Human Input**: Relies entirely on intrinsic AI reasoning
- ❌ **No Manual Overrides**: Human intervention is strictly prohibited during trading
- ✅ **Pure Tool Execution**: All operations run exclusively through standardized tool calls
- ✅ **Adaptive Learning**: Independent strategy optimization based on market performance feedback
---
## ⏰ Historical Replay Architecture

The core innovation of AI-Trader Bench is its **fully replayable** trading environment, ensuring that AI agent performance evaluation on historical market data is scientifically rigorous and reproducible.

### 🔄 Time Control Framework

#### 📅 Flexible Time Settings

```json
{
  "date_range": {
    "init_date": "2025-01-01",  // any start date
    "end_date": "2025-01-31"    // any end date
  }
}
```

---

### 🛡️ Look-Ahead Prevention

The AI can only access data up to the current simulation time. No future information is allowed.

- 📊 **Price Data Boundary**: Market data access limited to the simulation timestamp and historical records
- 📰 **News Timeline Enforcement**: Real-time filtering prevents access to news and announcements dated in the future
- 📈 **Financial Report Timeline**: Information restricted to data officially released as of the simulated current date
- 🔍 **Historical Intelligence Scope**: Market analysis constrained to temporally appropriate data availability

### 🎯 Replay Advantages

#### 🔬 Empirical Research Framework

- 📊 **Market Efficiency Studies**: Evaluate AI performance under different market conditions and volatility regimes
- 🧠 **Decision Consistency Analysis**: Examine the temporal stability and behavioral patterns of AI trading logic
- 📈 **Risk Management Assessment**: Validate the effectiveness of AI-driven risk mitigation strategies

#### 🎯 Fair Competition Framework

- 🏆 **Equal Information Access**: All AI models run on identical historical datasets
- 📊 **Standardized Evaluation**: Performance metrics computed from unified data sources
- 🔍 **Full Reproducibility**: Complete experimental transparency with verifiable results
---
## 📁 Project Architecture

```
AI-Trader Bench/
├── 🤖 Core System
│   ├── main.py                      # 🎯 Main program entry point
│   ├── agent/base_agent/            # 🧠 AI agent core
│   └── configs/                     # ⚙️ Configuration files
├── 🛠️ MCP Toolchain
│   ├── agent_tools/
│   │   ├── tool_trade.py            # 💰 Trade execution
│   │   ├── tool_get_price_local.py  # 📊 Price lookup
│   │   ├── tool_jina_search.py      # 🔍 Information search
│   │   └── tool_math.py             # 🧮 Math calculations
│   └── tools/                       # 🔧 Helper tools
├── 📊 Data System
│   ├── data/
│   │   ├── daily_prices_*.json      # 📈 Stock price data
│   │   ├── merged.jsonl             # 🔄 Unified data format
│   │   └── agent_data/              # 📝 AI trade records
│   └── calculate_performance.py     # 📈 Performance analysis
├── 🎨 Frontend
│   └── frontend/                    # 🌐 Web dashboard
└── 📋 Configuration & Docs
    ├── configs/                     # ⚙️ System configuration
    ├── prompts/                     # 💬 AI prompts
    └── calc_perf.sh                 # 🚀 Performance calculation script
```

### 🔧 Core Components

#### 🎯 Main Program (`main.py`)

- **Multi-Model Concurrency**: Runs multiple AI models trading simultaneously
- **Configuration Management**: Supports JSON config files and environment variables
- **Date Management**: Flexible trading calendar and date-range settings
- **Error Handling**: Robust exception handling and retry mechanisms

#### 🛠️ MCP Toolchain

| Tool | Function | API |
|------|----------|-----|
| **Trade Tool** | Buy/sell stocks, position management | `buy()`, `sell()` |
| **Price Tool** | Real-time and historical price lookup | `get_price_local()` |
| **Search Tool** | Market information search | `get_information()` |
| **Math Tool** | Financial calculations and analysis | Basic math operations |

#### 📊 Data System

- **📈 Price Data**: Complete OHLCV data for NASDAQ-100 constituents
- **📝 Trade Records**: Detailed trade history for each AI model
- **📊 Performance Metrics**: Sharpe ratio, maximum drawdown, annualized return, and more
- **🔄 Data Sync**: Automated data fetching and update mechanisms
## 🚀 Quick Start

### 📋 Prerequisites

- **Python 3.10+**
- **API keys**: OpenAI, Alpha Vantage, Jina AI

### ⚡ One-Step Installation

```bash
# 1. Clone the project
git clone https://github.com/HKUDS/AI-Trader.git
cd AI-Trader

# 2. Install dependencies
pip install -r requirements.txt

# 3. Configure environment variables
cp .env.example .env
# Edit the .env file and fill in your API keys
```

### 🔑 Environment Configuration

Create a `.env` file and configure the following variables:

```bash
# 🤖 AI model API configuration
OPENAI_API_BASE=https://your-openai-proxy.com/v1
OPENAI_API_KEY=your_openai_key

# 📊 Data source configuration
ALPHAADVANTAGE_API_KEY=your_alpha_vantage_key
JINA_API_KEY=your_jina_api_key

# ⚙️ System configuration
RUNTIME_ENV_PATH=./runtime_env.json  # absolute path recommended

# 🌐 Service port configuration
MATH_HTTP_PORT=8000
SEARCH_HTTP_PORT=8001
TRADE_HTTP_PORT=8002
GETPRICE_HTTP_PORT=8003

# 🧠 AI agent configuration
AGENT_MAX_STEP=30  # maximum reasoning steps
```

### 📦 Dependencies

```bash
# Install production dependencies
pip install -r requirements.txt

# Or install core dependencies manually
pip install langchain langchain-openai langchain-mcp-adapters fastmcp python-dotenv requests numpy pandas
```
## 🎮 Running Guide

### 📊 Step 1: Data Preparation (`./fresh_data.sh`)

```bash
# 📈 Fetch NASDAQ-100 stock data
cd data
python get_daily_price.py

# 🔄 Merge data into the unified format
python merge_jsonl.py
```

### 🛠️ Step 2: Start MCP Services

```bash
cd ./agent_tools
python start_mcp_services.py
```

### 🚀 Step 3: Launch the AI Arena

```bash
# 🎯 Run the main program and let the AIs start trading
python main.py

# 🎯 Or use a custom configuration
python main.py configs/my_config.json
```

### ⏰ Time Configuration Example

#### 📅 Create a Custom Time Configuration

```json
{
  "agent_type": "BaseAgent",
  "date_range": {
    "init_date": "2024-01-01",  // backtest start date
    "end_date": "2024-03-31"    // backtest end date
  },
  "models": [
    {
      "name": "claude-3.7-sonnet",
      "basemodel": "anthropic/claude-3.7-sonnet",
      "signature": "claude-3.7-sonnet",
      "enabled": true
    }
  ]
}
```

### 📈 Launch the Web Interface

```bash
cd docs
python3 -m http.server 8000
# Visit http://localhost:8000
```
## 📈 Performance Analysis

### 🏆 Competition Rules

| Rule | Setting | Description |
|------|---------|-------------|
| **💰 Initial Capital** | $10,000 | Starting capital for each AI model |
| **📈 Trading Universe** | NASDAQ-100 | 100 top tech stocks |
| **⏰ Trading Hours** | Weekdays | Monday through Friday |
| **💲 Price Basis** | Open price | Trades execute at the day's open |
| **📝 Record Format** | JSONL | Complete trade history records |
## ⚙️ Configuration Guide

### 📋 Configuration File Structure

```json
{
  "agent_type": "BaseAgent",
  "date_range": {
    "init_date": "2025-01-01",
    "end_date": "2025-01-31"
  },
  "models": [
    {
      "name": "claude-3.7-sonnet",
      "basemodel": "anthropic/claude-3.7-sonnet",
      "signature": "claude-3.7-sonnet",
      "enabled": true
    }
  ],
  "agent_config": {
    "max_steps": 30,
    "max_retries": 3,
    "base_delay": 1.0,
    "initial_cash": 10000.0
  },
  "log_config": {
    "log_path": "./data/agent_data"
  }
}
```

### 🔧 Configuration Parameters

| Parameter | Description | Default |
|-----------|-------------|---------|
| `agent_type` | AI agent type | "BaseAgent" |
| `max_steps` | Maximum reasoning steps | 30 |
| `max_retries` | Maximum retry attempts | 3 |
| `base_delay` | Delay between operations (seconds) | 1.0 |
| `initial_cash` | Initial capital | $10,000 |
### 📊 Data Formats

#### 💰 Position Records (position.jsonl)

```json
{
  "date": "2025-01-20",
  "id": 1,
  "this_action": {
    "action": "buy",
    "symbol": "AAPL",
    "amount": 10
  },
  "positions": {
    "AAPL": 10,
    "MSFT": 0,
    "CASH": 9737.6
  }
}
```

#### 📈 Price Data (merged.jsonl)

```json
{
  "Meta Data": {
    "2. Symbol": "AAPL",
    "3. Last Refreshed": "2025-01-20"
  },
  "Time Series (Daily)": {
    "2025-01-20": {
      "1. buy price": "255.8850",
      "2. high": "264.3750",
      "3. low": "255.6300",
      "4. sell price": "262.2400",
      "5. volume": "90483029"
    }
  }
}
```
### 📁 File Structure

```
data/agent_data/
├── claude-3.7-sonnet/
│   ├── position/
│   │   └── position.jsonl      # 📝 Position records
│   └── log/
│       └── 2025-01-20/
│           └── log.jsonl       # 📊 Trade logs
├── gpt-4o/
│   └── ...
└── qwen3-max/
    └── ...
```
## 🔌 Third-Party Strategy Integration

AI-Trader Bench uses a modular design that makes it easy to integrate third-party strategies and custom AI agents.

### 🛠️ Integration Methods

#### 1. Custom AI Agent

```python
# Create a new AI agent class
class CustomAgent(BaseAgent):
    def __init__(self, model_name, **kwargs):
        super().__init__(model_name, **kwargs)
        # Add custom logic here
```

#### 2. Register the New Agent

```python
# Register in main.py
AGENT_REGISTRY = {
    "BaseAgent": {
        "module": "agent.base_agent.base_agent",
        "class": "BaseAgent"
    },
    "CustomAgent": {  # new entry
        "module": "agent.custom.custom_agent",
        "class": "CustomAgent"
    },
}
```

#### 3. Configuration File Setup

```json
{
  "agent_type": "CustomAgent",
  "models": [
    {
      "name": "your-custom-model",
      "basemodel": "your/model/path",
      "signature": "custom-signature",
      "enabled": true
    }
  ]
}
```

### 🔧 Extending the Toolchain

#### Adding a Custom Tool

```python
# Create a new MCP tool (FastMCP registers plain functions via @mcp.tool())
from fastmcp import FastMCP

mcp = FastMCP("custom-tools")

@mcp.tool()
def custom_tool(params: dict) -> dict:
    """Implement custom tool logic here."""
    result = {"ok": True, **params}
    return result
```
## 🚀 Roadmap

### 🌟 Future Plans

- [ ] **🇨🇳 A-Share Support** - Expand to the Chinese stock market
- [ ] **📊 Post-Close Statistics** - Automated return analysis
- [ ] **🔌 Strategy Marketplace** - Platform for sharing third-party strategies
- [ ] **🎨 Polished Frontend** - Modern web dashboard
- [ ] **₿ Cryptocurrency** - Support for digital asset trading
- [ ] **📈 More Strategies** - Technical analysis and quantitative strategies
- [ ] **⏰ Advanced Replay** - Minute-level time precision and real-time replay
- [ ] **🔍 Smart Filtering** - More precise detection and filtering of future information

## 🤝 Contributing

We welcome contributions of all kinds, especially AI trading strategies and agent implementations.

### 🧠 AI Strategy Contributions

- **🎯 Trading Strategies**: Contribute your AI trading strategy implementations
- **🤖 Custom Agents**: Implement new AI agent types
- **📊 Analysis Tools**: Add new market analysis tools
- **🔍 Data Sources**: Integrate new data sources and APIs

### 🐛 Bug Reports

- Use GitHub Issues to report bugs
- Provide detailed reproduction steps
- Include system environment information

### 💡 Feature Requests

- Propose new feature ideas in Issues
- Describe the use case in detail
- Discuss implementation approaches

### 🔧 Code Contributions

1. Fork the project
2. Create a feature branch
3. Implement your strategy or feature
4. Add test cases
5. Open a Pull Request

### 📚 Documentation Improvements

- Improve the README
- Add code comments
- Write tutorials
- Contribute strategy documentation

### 🏆 Strategy Sharing

- **📈 Technical Analysis Strategies**: AI strategies based on technical indicators
- **📊 Quantitative Strategies**: Multi-factor models and quantitative analysis
- **🔍 Fundamental Strategies**: Analysis strategies based on financial data
- **🌐 Macro Strategies**: Strategies based on macroeconomic data
## 📞 Support & Community

- **💬 Discussions**: [GitHub Discussions](https://github.com/HKUDS/AI-Trader/discussions)
- **🐛 Issues**: [GitHub Issues](https://github.com/HKUDS/AI-Trader/issues)

## 📄 License

This project is released under the [MIT License](LICENSE).

## 🙏 Acknowledgments

Thanks to the following open-source projects and services:

- [LangChain](https://github.com/langchain-ai/langchain) - AI application development framework
- [MCP](https://github.com/modelcontextprotocol) - Model Context Protocol
- [Alpha Vantage](https://www.alphavantage.co/) - Financial data API
- [Jina AI](https://jina.ai/) - Information search service

## Disclaimer

The materials provided by the AI-Trader project are for research purposes only and do not constitute investment advice. Investors should seek independent professional advice before making any investment decisions. Past performance is not necessarily indicative of future results. Note that the value of investments may go down as well as up, and no guarantee is given. All content in the AI-Trader project is for research purposes only and does not constitute an investment recommendation for any securities or industries mentioned. Investing involves risk. Seek professional advice if needed.

---

<div align="center">

**🌟 If this project helps you, please give us a Star!**

[![GitHub stars](https://img.shields.io/github/stars/HKUDS/AI-Trader?style=social)](https://github.com/HKUDS/AI-Trader)
[![GitHub forks](https://img.shields.io/github/forks/HKUDS/AI-Trader?style=social)](https://github.com/HKUDS/AI-Trader)

**🤖 Let AI make fully autonomous decisions and show its skills in the financial markets**

**🛠️ Pure tool-driven, zero human intervention: a true AI trading arena** 🚀

</div>

ROADMAP.md (new file, 88 lines)

@@ -0,0 +1,88 @@
# AI-Trader Roadmap
This document outlines planned features and improvements for the AI-Trader project.
## Release Planning
### v0.4.0 - Enhanced Simulation Management (Planned)
**Focus:** Improved simulation control, resume capabilities, and performance analysis
#### Simulation Resume & Continuation
- **Resume from Last Completed Date** - API to continue simulations without re-running completed dates
- `POST /simulate/resume` - Resume last incomplete job or start from last completed date
- `POST /simulate/continue` - Extend existing simulation with new date range
- Query parameters to specify which model(s) to continue
- Automatic detection of last completed date per model
- Validation to prevent overlapping simulations
- Support for extending date ranges forward in time
- Use cases:
- Daily simulation updates (add today's date to existing run)
- Recovering from failed jobs (resume from interruption point)
- Incremental backtesting (extend historical analysis)
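The "automatic detection of last completed date per model" could reduce to taking the latest completed date from stored results; a minimal sketch under an assumed data shape (a mapping of model signature to completed dates):

```python
from datetime import datetime, timedelta
from typing import Dict, List, Optional

def last_completed_date(completed: Dict[str, List[str]], model: str) -> Optional[str]:
    """Return the most recent completed date for a model, or None.

    `completed` maps model signature -> list of YYYY-MM-DD dates that
    finished successfully (this shape is an assumption for the sketch).
    """
    dates = completed.get(model, [])
    # ISO dates sort lexicographically, so max() gives the latest date.
    return max(dates) if dates else None

def resume_start_date(completed: Dict[str, List[str]], model: str, fallback: str) -> str:
    """Pick the day after the last completed date, or the fallback start."""
    last = last_completed_date(completed, model)
    if last is None:
        return fallback
    nxt = datetime.strptime(last, "%Y-%m-%d") + timedelta(days=1)
    return nxt.strftime("%Y-%m-%d")
```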
#### Position History & Analysis
- **Position History Tracking** - Track position changes over time
- Query endpoint: `GET /positions/history?model=<name>&start_date=<date>&end_date=<date>`
- Timeline view of all trades and position changes
- Calculate holding periods and turnover rates
- Support for position snapshots at specific dates
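Holding periods and turnover rates can be derived directly from the trade timeline; a rough sketch, assuming dates arrive as YYYY-MM-DD strings and turnover is measured as traded value over average portfolio value:

```python
from datetime import datetime

def holding_period_days(buy_date: str, sell_date: str) -> int:
    """Calendar days between entry and exit (YYYY-MM-DD strings)."""
    fmt = "%Y-%m-%d"
    return (datetime.strptime(sell_date, fmt) - datetime.strptime(buy_date, fmt)).days

def turnover_rate(traded_value: float, avg_portfolio_value: float) -> float:
    """Simple turnover: total traded value over average portfolio value."""
    if avg_portfolio_value <= 0:
        raise ValueError("avg_portfolio_value must be positive")
    return traded_value / avg_portfolio_value
```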
#### Performance Metrics
- **Advanced Performance Analytics** - Calculate standard trading metrics
- Sharpe ratio, Sortino ratio, maximum drawdown
- Win rate, average win/loss, profit factor
- Volatility and beta calculations
- Risk-adjusted returns
- Comparison across models
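Minimal reference formulas for some of the metrics listed above (the 252-trading-day annualization factor and zero risk-free rate are assumptions; a real implementation would also handle empty inputs and the risk-free rate):

```python
import math
from typing import List

def sharpe_ratio(daily_returns: List[float], periods: int = 252) -> float:
    """Annualized Sharpe ratio (risk-free rate assumed zero)."""
    n = len(daily_returns)
    mean = sum(daily_returns) / n
    var = sum((r - mean) ** 2 for r in daily_returns) / (n - 1)
    return mean / math.sqrt(var) * math.sqrt(periods)

def max_drawdown(equity: List[float]) -> float:
    """Largest peak-to-trough decline, as a fraction of the peak."""
    peak, worst = equity[0], 0.0
    for v in equity:
        peak = max(peak, v)
        worst = max(worst, (peak - v) / peak)
    return worst

def win_rate(trade_pnls: List[float]) -> float:
    """Fraction of trades with positive P&L."""
    return sum(1 for p in trade_pnls if p > 0) / len(trade_pnls)
```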
#### Data Management
- **Price Data Management API** - Endpoints for price data operations
- `GET /data/coverage` - Check date ranges available per symbol
- `POST /data/download` - Trigger manual price data downloads
- `GET /data/status` - Check download progress and rate limits
- `DELETE /data/range` - Remove price data for specific date ranges
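A `GET /data/coverage` response would presumably collapse the stored dates into contiguous ranges per symbol; a minimal sketch of that collapsing step (the helper name and response shape are assumptions):

```python
from datetime import datetime, timedelta
from typing import List, Tuple

def collapse_to_ranges(dates: List[str]) -> List[Tuple[str, str]]:
    """Collapse YYYY-MM-DD dates into contiguous (start, end) ranges."""
    if not dates:
        return []
    fmt = "%Y-%m-%d"
    ordered = sorted(set(dates))
    ranges = []
    start = prev = ordered[0]
    for d in ordered[1:]:
        gap = datetime.strptime(d, fmt) - datetime.strptime(prev, fmt)
        if gap > timedelta(days=1):  # gap found: close the current range
            ranges.append((start, prev))
            start = d
        prev = d
    ranges.append((start, prev))
    return ranges
```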
#### Web UI
- **Dashboard Interface** - Web-based monitoring and control interface
- Job management dashboard
- View active, pending, and completed jobs
- Start new simulations with form-based configuration
- Monitor job progress in real-time
- Cancel running jobs
- Results visualization
- Performance charts (P&L over time, cumulative returns)
- Position history timeline
- Model comparison views
- Trade log explorer with filtering
- Configuration management
- Model configuration editor
- Date range selection with calendar picker
- Price data coverage visualization
- Technical implementation
- Modern frontend framework (React, Vue.js, or Svelte)
- Real-time updates via WebSocket or SSE
- Responsive design for mobile access
- Chart library (Plotly.js, Chart.js, or Recharts)
- Served alongside API (single container deployment)
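For the real-time updates, Server-Sent Events are just a line-oriented text framing over HTTP; a minimal sketch of formatting one progress event (the event name and field names are assumptions):

```python
import json

def format_sse(event: str, data: dict) -> str:
    """Serialize one Server-Sent Events message (event + data lines,
    terminated by a blank line per the SSE text framing)."""
    return f"event: {event}\ndata: {json.dumps(data)}\n\n"

msg = format_sse("job_progress", {"job_id": "abc123", "completed": 5, "total": 20})
```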
## Contributing
We welcome contributions to any of these planned features! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.
To propose a new feature:
1. Open an issue with the `feature-request` label
2. Describe the use case and expected behavior
3. Discuss implementation approach with maintainers
4. Submit a PR with tests and documentation
## Version History
- **v0.1.0** - Initial release with batch execution
- **v0.2.0** - Docker deployment support
- **v0.3.0** - REST API, on-demand downloads, database storage (current)
- **v0.4.0** - Enhanced simulation management (planned)
---
Last updated: 2025-10-31


@@ -50,6 +50,9 @@ def initialize_database(db_path: str = "data/jobs.db") -> None:
4. holdings - Portfolio holdings per position
5. reasoning_logs - AI decision logs (optional, for detail=full)
6. tool_usage - Tool usage statistics
7. price_data - Historical OHLCV price data (replaces merged.jsonl)
8. price_data_coverage - Downloaded date range tracking per symbol
9. simulation_runs - Simulation run tracking for soft delete
Args:
db_path: Path to SQLite database file
@@ -108,8 +111,10 @@ def initialize_database(db_path: str = "data/jobs.db") -> None:
daily_return_pct REAL,
cumulative_profit REAL,
cumulative_return_pct REAL,
simulation_run_id TEXT,
created_at TEXT NOT NULL,
FOREIGN KEY (job_id) REFERENCES jobs(job_id) ON DELETE CASCADE
FOREIGN KEY (job_id) REFERENCES jobs(job_id) ON DELETE CASCADE,
FOREIGN KEY (simulation_run_id) REFERENCES simulation_runs(run_id) ON DELETE SET NULL
)
""")
@@ -154,6 +159,50 @@ def initialize_database(db_path: str = "data/jobs.db") -> None:
)
""")
# Table 7: Price Data - OHLCV price data (replaces merged.jsonl)
cursor.execute("""
CREATE TABLE IF NOT EXISTS price_data (
id INTEGER PRIMARY KEY AUTOINCREMENT,
symbol TEXT NOT NULL,
date TEXT NOT NULL,
open REAL NOT NULL,
high REAL NOT NULL,
low REAL NOT NULL,
close REAL NOT NULL,
volume INTEGER NOT NULL,
created_at TEXT NOT NULL,
UNIQUE(symbol, date)
)
""")
# Table 8: Price Data Coverage - Track downloaded date ranges per symbol
cursor.execute("""
CREATE TABLE IF NOT EXISTS price_data_coverage (
id INTEGER PRIMARY KEY AUTOINCREMENT,
symbol TEXT NOT NULL,
start_date TEXT NOT NULL,
end_date TEXT NOT NULL,
downloaded_at TEXT NOT NULL,
source TEXT DEFAULT 'alpha_vantage',
UNIQUE(symbol, start_date, end_date)
)
""")
# Table 9: Simulation Runs - Track simulation runs for soft delete
cursor.execute("""
CREATE TABLE IF NOT EXISTS simulation_runs (
run_id TEXT PRIMARY KEY,
job_id TEXT NOT NULL,
model TEXT NOT NULL,
start_date TEXT NOT NULL,
end_date TEXT NOT NULL,
status TEXT NOT NULL CHECK(status IN ('active', 'superseded')),
created_at TEXT NOT NULL,
superseded_at TEXT,
FOREIGN KEY (job_id) REFERENCES jobs(job_id) ON DELETE CASCADE
)
""")
# Create indexes for performance
_create_indexes(cursor)
@@ -222,6 +271,41 @@ def _create_indexes(cursor: sqlite3.Cursor) -> None:
ON tool_usage(job_id, date, model)
""")
# Price data table indexes
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_price_data_symbol_date ON price_data(symbol, date)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_price_data_date ON price_data(date)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_price_data_symbol ON price_data(symbol)
""")
# Price data coverage table indexes
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_coverage_symbol ON price_data_coverage(symbol)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_coverage_dates ON price_data_coverage(start_date, end_date)
""")
# Simulation runs table indexes
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_runs_job_model ON simulation_runs(job_id, model)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_runs_status ON simulation_runs(status)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_runs_dates ON simulation_runs(start_date, end_date)
""")
# Positions table - add index for simulation_run_id
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_positions_run_id ON positions(simulation_run_id)
""")
def drop_all_tables(db_path: str = "data/jobs.db") -> None:
"""
@@ -240,8 +324,11 @@ def drop_all_tables(db_path: str = "data/jobs.db") -> None:
'reasoning_logs',
'holdings',
'positions',
'simulation_runs',
'job_details',
'jobs'
'jobs',
'price_data_coverage',
'price_data'
]
for table in tables:
@@ -296,7 +383,8 @@ def get_database_stats(db_path: str = "data/jobs.db") -> dict:
stats["database_size_mb"] = 0
# Get row counts for each table
tables = ['jobs', 'job_details', 'positions', 'holdings', 'reasoning_logs', 'tool_usage']
tables = ['jobs', 'job_details', 'positions', 'holdings', 'reasoning_logs', 'tool_usage',
'price_data', 'price_data_coverage', 'simulation_runs']
for table in tables:
cursor.execute(f"SELECT COUNT(*) FROM {table}")

api/date_utils.py (new file)

@@ -0,0 +1,93 @@
"""
Date range utilities for simulation date management.
This module provides:
- Date range expansion
- Date range validation
- Trading day detection
"""
import os
from datetime import datetime, timedelta
from typing import List
def expand_date_range(start_date: str, end_date: str) -> List[str]:
"""
Expand date range into list of all dates (inclusive).
Args:
start_date: Start date (YYYY-MM-DD)
end_date: End date (YYYY-MM-DD)
Returns:
Sorted list of dates in range
Raises:
ValueError: If dates are invalid or start > end
"""
start = datetime.strptime(start_date, "%Y-%m-%d")
end = datetime.strptime(end_date, "%Y-%m-%d")
if start > end:
raise ValueError(f"start_date ({start_date}) must be <= end_date ({end_date})")
dates = []
current = start
while current <= end:
dates.append(current.strftime("%Y-%m-%d"))
current += timedelta(days=1)
return dates
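A quick sanity check of `expand_date_range`'s inclusive, month-crossing behavior (the function body is reproduced so the snippet runs standalone):

```python
from datetime import datetime, timedelta
from typing import List

def expand_date_range(start_date: str, end_date: str) -> List[str]:
    # Same logic as above: every calendar date in [start, end], inclusive.
    start = datetime.strptime(start_date, "%Y-%m-%d")
    end = datetime.strptime(end_date, "%Y-%m-%d")
    if start > end:
        raise ValueError(f"start_date ({start_date}) must be <= end_date ({end_date})")
    dates = []
    current = start
    while current <= end:
        dates.append(current.strftime("%Y-%m-%d"))
        current += timedelta(days=1)
    return dates

print(expand_date_range("2025-01-30", "2025-02-02"))
```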
def validate_date_range(
start_date: str,
end_date: str,
max_days: int = 30
) -> None:
"""
Validate date range for simulation.
Args:
start_date: Start date (YYYY-MM-DD)
end_date: End date (YYYY-MM-DD)
max_days: Maximum allowed days in range
Raises:
ValueError: If validation fails
"""
# Parse dates
try:
start = datetime.strptime(start_date, "%Y-%m-%d")
end = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError as e:
raise ValueError(f"Invalid date format: {e}")
# Check order
if start > end:
raise ValueError(f"start_date ({start_date}) must be <= end_date ({end_date})")
# Check range size
days = (end - start).days + 1
if days > max_days:
raise ValueError(
f"Date range too large: {days} days (max: {max_days}). "
f"Reduce range or increase MAX_SIMULATION_DAYS."
)
# Check not in future
today = datetime.now().date()
if end.date() > today:
raise ValueError(f"end_date ({end_date}) cannot be in the future")
def get_max_simulation_days() -> int:
"""
Get maximum simulation days from environment.
Returns:
Maximum days allowed in simulation range
"""
return int(os.getenv("MAX_SIMULATION_DAYS", "30"))


@@ -9,6 +9,7 @@ Provides endpoints for:
"""
import logging
import os
from typing import Optional, List, Dict, Any
from datetime import datetime
from pathlib import Path
@@ -20,6 +21,8 @@ from pydantic import BaseModel, Field, field_validator
from api.job_manager import JobManager
from api.simulation_worker import SimulationWorker
from api.database import get_db_connection
from api.price_data_manager import PriceDataManager
from api.date_utils import validate_date_range, expand_date_range, get_max_simulation_days
import threading
import time
@@ -29,21 +32,29 @@ logger = logging.getLogger(__name__)
# Pydantic models for request/response validation
class SimulateTriggerRequest(BaseModel):
"""Request body for POST /simulate/trigger."""
config_path: str = Field(..., description="Path to configuration file")
date_range: List[str] = Field(..., min_length=1, description="List of trading dates (YYYY-MM-DD)")
models: List[str] = Field(..., min_length=1, description="List of model signatures to simulate")
start_date: str = Field(..., description="Start date for simulation (YYYY-MM-DD)")
end_date: Optional[str] = Field(None, description="End date for simulation (YYYY-MM-DD). If not provided, simulates single day.")
models: Optional[List[str]] = Field(
None,
description="Optional: List of model signatures to simulate. If not provided, uses enabled models from config."
)
@field_validator("date_range")
@field_validator("start_date", "end_date")
@classmethod
def validate_date_range(cls, v):
def validate_date_format(cls, v):
"""Validate date format."""
for date in v:
try:
datetime.strptime(date, "%Y-%m-%d")
except ValueError:
raise ValueError(f"Invalid date format: {date}. Expected YYYY-MM-DD")
if v is None:
return v
try:
datetime.strptime(v, "%Y-%m-%d")
except ValueError:
raise ValueError(f"Invalid date format: {v}. Expected YYYY-MM-DD")
return v
def get_end_date(self) -> str:
"""Get end date, defaulting to start_date if not provided."""
return self.end_date or self.start_date
class SimulateTriggerResponse(BaseModel):
"""Response body for POST /simulate/trigger."""
@@ -83,12 +94,16 @@ class HealthResponse(BaseModel):
timestamp: str
def create_app(db_path: str = "data/jobs.db") -> FastAPI:
def create_app(
db_path: str = "data/jobs.db",
config_path: str = "configs/default_config.json"
) -> FastAPI:
"""
Create FastAPI application instance.
Args:
db_path: Path to SQLite database
config_path: Path to default configuration file
Returns:
Configured FastAPI app
@@ -99,27 +114,121 @@ def create_app(db_path: str = "data/jobs.db") -> FastAPI:
version="1.0.0"
)
# Store db_path in app state
# Store paths in app state
app.state.db_path = db_path
app.state.config_path = config_path
@app.post("/simulate/trigger", response_model=SimulateTriggerResponse, status_code=200)
async def trigger_simulation(request: SimulateTriggerRequest):
"""
Trigger a new simulation job.
Creates a job with specified config, dates, and models.
Job runs asynchronously in background thread.
Validates date range, downloads missing price data if needed,
and creates job with available trading dates.
Raises:
HTTPException 400: If another job is already running or config invalid
HTTPException 422: If request validation fails
HTTPException 400: Validation errors, running job, or invalid dates
HTTPException 503: Price data download failed
"""
try:
# Use config path from app state
config_path = app.state.config_path
# Validate config path exists
if not Path(request.config_path).exists():
if not Path(config_path).exists():
raise HTTPException(
status_code=500,
detail=f"Server configuration file not found: {config_path}"
)
# Get end date (defaults to start_date for single day)
end_date = request.get_end_date()
# Validate date range
max_days = get_max_simulation_days()
validate_date_range(request.start_date, end_date, max_days=max_days)
# Determine which models to run
import json
with open(config_path, 'r') as f:
config = json.load(f)
if request.models is not None:
# Use models from request (explicit override)
models_to_run = request.models
else:
# Use enabled models from config
models_to_run = [
model["signature"]
for model in config.get("models", [])
if model.get("enabled", False)
]
if not models_to_run:
raise HTTPException(
status_code=400,
detail="No enabled models found in config. Either enable models in config or specify them in request."
)
# Check price data and download if needed
auto_download = os.getenv("AUTO_DOWNLOAD_PRICE_DATA", "true").lower() == "true"
price_manager = PriceDataManager(db_path=app.state.db_path)
# Check what's missing
missing_coverage = price_manager.get_missing_coverage(
request.start_date,
end_date
)
download_info = None
# Download missing data if enabled
if any(missing_coverage.values()):
if not auto_download:
raise HTTPException(
status_code=400,
detail=f"Missing price data for {len(missing_coverage)} symbols and auto-download is disabled. "
f"Enable AUTO_DOWNLOAD_PRICE_DATA or pre-populate data."
)
logger.info(f"Downloading missing price data for {len(missing_coverage)} symbols")
requested_dates = set(expand_date_range(request.start_date, end_date))
download_result = price_manager.download_missing_data_prioritized(
missing_coverage,
requested_dates
)
if not download_result["success"]:
raise HTTPException(
status_code=503,
detail="Failed to download any price data. Check ALPHAADVANTAGE_API_KEY."
)
download_info = {
"symbols_downloaded": len(download_result["downloaded"]),
"symbols_failed": len(download_result["failed"]),
"rate_limited": download_result["rate_limited"]
}
logger.info(
f"Downloaded {len(download_result['downloaded'])} symbols, "
f"{len(download_result['failed'])} failed, rate_limited={download_result['rate_limited']}"
)
# Get available trading dates (after potential download)
available_dates = price_manager.get_available_trading_dates(
request.start_date,
end_date
)
if not available_dates:
raise HTTPException(
status_code=400,
detail=f"Config path does not exist: {request.config_path}"
detail=f"No trading dates with complete price data in range "
f"{request.start_date} to {end_date}. "
f"All symbols must have data for a date to be tradeable."
)
job_manager = JobManager(db_path=app.state.db_path)
@@ -131,11 +240,11 @@ def create_app(db_path: str = "data/jobs.db") -> FastAPI:
detail="Another simulation job is already running or pending. Please wait for it to complete."
)
# Create job
# Create job with available dates
job_id = job_manager.create_job(
config_path=request.config_path,
date_range=request.date_range,
models=request.models
config_path=config_path,
date_range=available_dates,
models=models_to_run
)
# Start worker in background thread (only if not in test mode)
@@ -147,15 +256,27 @@ def create_app(db_path: str = "data/jobs.db") -> FastAPI:
thread = threading.Thread(target=run_worker, daemon=True)
thread.start()
logger.info(f"Triggered simulation job {job_id}")
logger.info(f"Triggered simulation job {job_id} with {len(available_dates)} dates")
return SimulateTriggerResponse(
# Build response message
message = f"Simulation job created with {len(available_dates)} trading dates"
if download_info and download_info["rate_limited"]:
message += " (rate limit reached - partial data)"
response = SimulateTriggerResponse(
job_id=job_id,
status="pending",
total_model_days=len(request.date_range) * len(request.models),
message=f"Simulation job {job_id} created and started"
total_model_days=len(available_dates) * len(models_to_run),
message=message
)
# Add download info if we downloaded
if download_info:
# Note: Need to add download_info field to response model
logger.info(f"Download info: {download_info}")
return response
except HTTPException:
raise
except ValueError as e:

api/price_data_manager.py (new file)

@@ -0,0 +1,546 @@
"""
Price data management for on-demand downloads and coverage tracking.
This module provides:
- Coverage gap detection
- Priority-based download ordering
- Rate limit handling with retry logic
- Price data storage and retrieval
"""
import logging
import json
import os
import time
import requests
from pathlib import Path
from typing import List, Dict, Set, Tuple, Optional, Callable, Any
from datetime import datetime, timedelta
from collections import defaultdict
from api.database import get_db_connection
logger = logging.getLogger(__name__)
class RateLimitError(Exception):
"""Raised when API rate limit is hit."""
pass
class DownloadError(Exception):
"""Raised when download fails for non-rate-limit reasons."""
pass
class PriceDataManager:
"""
Manages price data availability, downloads, and coverage tracking.
Responsibilities:
- Check which dates/symbols have price data
- Download missing data from Alpha Vantage
- Track downloaded date ranges per symbol
- Prioritize downloads to maximize date completion
- Handle rate limiting gracefully
"""
def __init__(
self,
db_path: str = "data/jobs.db",
symbols_config: str = "configs/nasdaq100_symbols.json",
api_key: Optional[str] = None
):
"""
Initialize PriceDataManager.
Args:
db_path: Path to SQLite database
symbols_config: Path to NASDAQ 100 symbols configuration
api_key: Alpha Vantage API key (defaults to env var)
"""
self.db_path = db_path
self.symbols_config = symbols_config
self.api_key = api_key or os.getenv("ALPHAADVANTAGE_API_KEY")
# Load symbols list
self.symbols = self._load_symbols()
logger.info(f"Initialized PriceDataManager with {len(self.symbols)} symbols")
def _load_symbols(self) -> List[str]:
"""Load NASDAQ 100 symbols from config file."""
config_path = Path(self.symbols_config)
if not config_path.exists():
logger.warning(f"Symbols config not found: {config_path}. Using default list.")
# Fallback to a minimal list
return ["AAPL", "MSFT", "GOOGL", "AMZN", "NVDA"]
with open(config_path, 'r') as f:
config = json.load(f)
return config.get("symbols", [])
def get_available_dates(self) -> Set[str]:
"""
Get all dates that have price data in database.
Returns:
Set of dates (YYYY-MM-DD) with data
"""
conn = get_db_connection(self.db_path)
cursor = conn.cursor()
cursor.execute("SELECT DISTINCT date FROM price_data ORDER BY date")
dates = {row[0] for row in cursor.fetchall()}
conn.close()
return dates
def get_symbol_dates(self, symbol: str) -> Set[str]:
"""
Get all dates that have data for a specific symbol.
Args:
symbol: Stock symbol
Returns:
Set of dates with data for this symbol
"""
conn = get_db_connection(self.db_path)
cursor = conn.cursor()
cursor.execute(
"SELECT date FROM price_data WHERE symbol = ? ORDER BY date",
(symbol,)
)
dates = {row[0] for row in cursor.fetchall()}
conn.close()
return dates
def get_missing_coverage(
self,
start_date: str,
end_date: str
) -> Dict[str, Set[str]]:
"""
Identify which symbols are missing data for which dates in range.
Args:
start_date: Start date (YYYY-MM-DD)
end_date: End date (YYYY-MM-DD)
Returns:
Dict mapping symbol to set of missing dates
Example: {"AAPL": {"2025-01-20", "2025-01-21"}, "MSFT": set()}
"""
# Generate all dates in range
requested_dates = self._expand_date_range(start_date, end_date)
missing = {}
for symbol in self.symbols:
symbol_dates = self.get_symbol_dates(symbol)
missing_dates = requested_dates - symbol_dates
if missing_dates:
missing[symbol] = missing_dates
return missing
def _expand_date_range(self, start_date: str, end_date: str) -> Set[str]:
"""
Expand date range into set of all dates.
Args:
start_date: Start date (YYYY-MM-DD)
end_date: End date (YYYY-MM-DD)
Returns:
Set of all dates in range (inclusive)
"""
start = datetime.strptime(start_date, "%Y-%m-%d")
end = datetime.strptime(end_date, "%Y-%m-%d")
dates = set()
current = start
while current <= end:
dates.add(current.strftime("%Y-%m-%d"))
current += timedelta(days=1)
return dates
def prioritize_downloads(
self,
missing_coverage: Dict[str, Set[str]],
requested_dates: Set[str]
) -> List[str]:
"""
Prioritize symbol downloads to maximize date completion.
Strategy: Download symbols that complete the most requested dates first.
Args:
missing_coverage: Dict of symbol -> missing dates
requested_dates: Set of dates we want to simulate
Returns:
List of symbols in priority order (highest impact first)
"""
# Calculate impact score for each symbol
impacts = []
for symbol, missing_dates in missing_coverage.items():
# Impact = number of requested dates this symbol would complete
impact = len(missing_dates & requested_dates)
if impact > 0:
impacts.append((symbol, impact))
# Sort by impact (descending)
impacts.sort(key=lambda x: x[1], reverse=True)
# Return symbols in priority order
prioritized = [symbol for symbol, _ in impacts]
logger.info(f"Prioritized {len(prioritized)} symbols for download")
if prioritized:
logger.debug(f"Top 5 symbols: {prioritized[:5]}")
return prioritized
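The impact scoring above can be exercised with a tiny standalone example (the logic is reproduced here without the logging so it runs on its own):

```python
from typing import Dict, List, Set

def prioritize(missing: Dict[str, Set[str]], requested: Set[str]) -> List[str]:
    """Order symbols so downloads that fill the most requested dates come first."""
    # Impact = number of requested dates a symbol's download would fill.
    impacts = [(sym, len(dates & requested)) for sym, dates in missing.items()]
    impacts = [(s, n) for s, n in impacts if n > 0]
    impacts.sort(key=lambda x: x[1], reverse=True)
    return [s for s, _ in impacts]

missing = {
    "AAPL": {"2025-01-20", "2025-01-21"},
    "MSFT": {"2025-01-20"},
    "NVDA": {"2024-12-01"},  # outside the requested window: zero impact
}
order = prioritize(missing, {"2025-01-20", "2025-01-21"})
```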
def download_missing_data_prioritized(
self,
missing_coverage: Dict[str, Set[str]],
requested_dates: Set[str],
progress_callback: Optional[Callable] = None
) -> Dict[str, Any]:
"""
Download data in priority order until rate limited.
Args:
missing_coverage: Dict of symbol -> missing dates
requested_dates: Set of dates being requested
progress_callback: Optional callback for progress updates
Returns:
{
"success": True/False,
"downloaded": ["AAPL", "MSFT", ...],
"failed": ["GOOGL", ...],
"rate_limited": True/False,
"dates_completed": ["2025-01-20", ...],
"partial_dates": {"2025-01-21": 75}
}
"""
if not self.api_key:
raise ValueError("ALPHAADVANTAGE_API_KEY not configured")
# Prioritize downloads
prioritized_symbols = self.prioritize_downloads(missing_coverage, requested_dates)
if not prioritized_symbols:
logger.info("No downloads needed - all data available")
return {
"success": True,
"downloaded": [],
"failed": [],
"rate_limited": False,
"dates_completed": sorted(requested_dates),
"partial_dates": {}
}
logger.info(f"Starting priority download of {len(prioritized_symbols)} symbols")
downloaded = []
failed = []
rate_limited = False
# Download in priority order
for i, symbol in enumerate(prioritized_symbols):
try:
# Progress callback
if progress_callback:
progress_callback({
"current": i + 1,
"total": len(prioritized_symbols),
"symbol": symbol,
"phase": "downloading"
})
# Download symbol data
logger.info(f"Downloading {symbol} ({i+1}/{len(prioritized_symbols)})")
data = self._download_symbol(symbol)
# Store in database
stored_dates = self._store_symbol_data(symbol, data, requested_dates)
# Update coverage tracking
if stored_dates:
self._update_coverage(symbol, min(stored_dates), max(stored_dates))
downloaded.append(symbol)
logger.info(f"✓ Downloaded {symbol} - {len(stored_dates)} dates stored")
except RateLimitError as e:
# Hit rate limit - stop downloading
logger.warning(f"Rate limit hit after {len(downloaded)} downloads: {e}")
rate_limited = True
failed = prioritized_symbols[i:] # Rest are undownloaded
break
except Exception as e:
# Other error - log and continue
logger.error(f"Failed to download {symbol}: {e}")
failed.append(symbol)
continue
# Analyze coverage
coverage_analysis = self._analyze_coverage(requested_dates)
result = {
"success": len(downloaded) > 0 or len(requested_dates) == len(coverage_analysis["completed_dates"]),
"downloaded": downloaded,
"failed": failed,
"rate_limited": rate_limited,
"dates_completed": coverage_analysis["completed_dates"],
"partial_dates": coverage_analysis["partial_dates"]
}
logger.info(
f"Download complete: {len(downloaded)} symbols downloaded, "
f"{len(failed)} failed/skipped, rate_limited={rate_limited}"
)
return result
def _download_symbol(self, symbol: str, retries: int = 3) -> Dict:
"""
Download full price history for a symbol.
Args:
symbol: Stock symbol
retries: Number of retry attempts for transient errors
Returns:
JSON response from Alpha Vantage
Raises:
RateLimitError: If rate limit is hit
DownloadError: If download fails after retries
"""
if not self.api_key:
raise DownloadError("API key not configured")
for attempt in range(retries):
try:
response = requests.get(
"https://www.alphavantage.co/query",
params={
"function": "TIME_SERIES_DAILY",
"symbol": symbol,
"outputsize": "full", # Get full history
"apikey": self.api_key
},
timeout=30
)
if response.status_code == 200:
data = response.json()
# Check for API error messages
if "Error Message" in data:
raise DownloadError(f"API error: {data['Error Message']}")
# Check for rate limit in response body
if "Note" in data:
note = data["Note"]
if "call frequency" in note.lower() or "rate limit" in note.lower():
raise RateLimitError(note)
# Other notes are warnings, continue
logger.warning(f"{symbol}: {note}")
if "Information" in data:
info = data["Information"]
if "premium" in info.lower() or "limit" in info.lower():
raise RateLimitError(info)
# Validate response has time series data
if "Time Series (Daily)" not in data or "Meta Data" not in data:
raise DownloadError(f"Invalid response format for {symbol}")
return data
elif response.status_code == 429:
raise RateLimitError("HTTP 429: Too Many Requests")
elif response.status_code >= 500:
# Server error - retry with backoff
if attempt < retries - 1:
wait_time = (2 ** attempt)
logger.warning(f"Server error {response.status_code}. Retrying in {wait_time}s...")
time.sleep(wait_time)
continue
raise DownloadError(f"Server error: {response.status_code}")
else:
raise DownloadError(f"HTTP {response.status_code}: {response.text[:200]}")
except RateLimitError:
raise # Don't retry rate limits
except DownloadError:
raise # Don't retry download errors
except requests.RequestException as e:
if attempt < retries - 1:
logger.warning(f"Request failed: {e}. Retrying...")
time.sleep(2)
continue
raise DownloadError(f"Request failed after {retries} attempts: {e}")
raise DownloadError(f"Failed to download {symbol} after {retries} attempts")
def _store_symbol_data(
self,
symbol: str,
data: Dict,
requested_dates: Set[str]
) -> List[str]:
"""
Store downloaded price data in database.
Args:
symbol: Stock symbol
data: Alpha Vantage API response
requested_dates: Only store dates in this set
Returns:
List of dates actually stored
"""
time_series = data.get("Time Series (Daily)", {})
if not time_series:
logger.warning(f"No time series data for {symbol}")
return []
conn = get_db_connection(self.db_path)
cursor = conn.cursor()
stored_dates = []
created_at = datetime.utcnow().isoformat() + "Z"
for date, ohlcv in time_series.items():
# Only store requested dates
if date not in requested_dates:
continue
try:
cursor.execute("""
INSERT OR REPLACE INTO price_data
(symbol, date, open, high, low, close, volume, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", (
symbol,
date,
float(ohlcv.get("1. open", 0)),
float(ohlcv.get("2. high", 0)),
float(ohlcv.get("3. low", 0)),
float(ohlcv.get("4. close", 0)),
int(ohlcv.get("5. volume", 0)),
created_at
))
stored_dates.append(date)
except Exception as e:
logger.error(f"Failed to store {symbol} {date}: {e}")
continue
conn.commit()
conn.close()
return stored_dates
def _update_coverage(self, symbol: str, start_date: str, end_date: str) -> None:
"""
Update coverage tracking for a symbol.
Args:
symbol: Stock symbol
start_date: Start of date range downloaded
end_date: End of date range downloaded
"""
conn = get_db_connection(self.db_path)
cursor = conn.cursor()
downloaded_at = datetime.utcnow().isoformat() + "Z"
cursor.execute("""
INSERT OR REPLACE INTO price_data_coverage
(symbol, start_date, end_date, downloaded_at, source)
VALUES (?, ?, ?, ?, 'alpha_vantage')
""", (symbol, start_date, end_date, downloaded_at))
conn.commit()
conn.close()
def _analyze_coverage(self, requested_dates: Set[str]) -> Dict[str, Any]:
"""
Analyze which requested dates have complete/partial coverage.
Args:
requested_dates: Set of dates requested
Returns:
{
"completed_dates": ["2025-01-20", ...], # All symbols available
"partial_dates": {"2025-01-21": 75, ...} # Date -> symbol count
}
"""
conn = get_db_connection(self.db_path)
cursor = conn.cursor()
total_symbols = len(self.symbols)
completed_dates = []
partial_dates = {}
for date in sorted(requested_dates):
# Count symbols available for this date
cursor.execute(
"SELECT COUNT(DISTINCT symbol) FROM price_data WHERE date = ?",
(date,)
)
count = cursor.fetchone()[0]
if count == total_symbols:
completed_dates.append(date)
elif count > 0:
partial_dates[date] = count
conn.close()
return {
"completed_dates": completed_dates,
"partial_dates": partial_dates
}
def get_available_trading_dates(
self,
start_date: str,
end_date: str
) -> List[str]:
"""
Get trading dates with complete data in range.
Args:
start_date: Start date (YYYY-MM-DD)
end_date: End date (YYYY-MM-DD)
Returns:
Sorted list of dates with complete data (all symbols)
"""
requested_dates = self._expand_date_range(start_date, end_date)
analysis = self._analyze_coverage(requested_dates)
return sorted(analysis["completed_dates"])


@@ -0,0 +1,18 @@
{
"symbols": [
"NVDA", "MSFT", "AAPL", "GOOG", "GOOGL", "AMZN", "META", "AVGO", "TSLA",
"NFLX", "PLTR", "COST", "ASML", "AMD", "CSCO", "AZN", "TMUS", "MU", "LIN",
"PEP", "SHOP", "APP", "INTU", "AMAT", "LRCX", "PDD", "QCOM", "ARM", "INTC",
"BKNG", "AMGN", "TXN", "ISRG", "GILD", "KLAC", "PANW", "ADBE", "HON",
"CRWD", "CEG", "ADI", "ADP", "DASH", "CMCSA", "VRTX", "MELI", "SBUX",
"CDNS", "ORLY", "SNPS", "MSTR", "MDLZ", "ABNB", "MRVL", "CTAS", "TRI",
"MAR", "MNST", "CSX", "ADSK", "PYPL", "FTNT", "AEP", "WDAY", "REGN", "ROP",
"NXPI", "DDOG", "AXON", "ROST", "IDXX", "EA", "PCAR", "FAST", "EXC", "TTWO",
"XEL", "ZS", "PAYX", "WBD", "BKR", "CPRT", "CCEP", "FANG", "TEAM", "CHTR",
"KDP", "MCHP", "GEHC", "VRSK", "CTSH", "CSGP", "KHC", "ODFL", "DXCM", "TTD",
"ON", "BIIB", "LULU", "CDW", "GFS", "QQQ"
],
"description": "NASDAQ 100 constituent stocks plus QQQ ETF",
"last_updated": "2025-10-31",
"total_symbols": 101
}


@@ -23,8 +23,6 @@ services:
ports:
# API server port (primary interface for external access)
- "${API_PORT:-8080}:8080"
# Web dashboard
- "${WEB_HTTP_PORT:-8888}:8888"
restart: unless-stopped # Keep API server running
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]

scripts/migrate_price_data.py (new executable file)

@@ -0,0 +1,166 @@
#!/usr/bin/env python3
"""
Migration script: Import merged.jsonl price data into SQLite database.
This script:
1. Reads existing merged.jsonl file
2. Parses OHLCV data for each symbol/date
3. Inserts into price_data table
4. Tracks coverage in price_data_coverage table
Run this once to migrate from jsonl to database.
"""
import json
import sys
from pathlib import Path
from datetime import datetime
from collections import defaultdict
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from api.database import get_db_connection, initialize_database
def migrate_merged_jsonl(
jsonl_path: str = "data/merged.jsonl",
db_path: str = "data/jobs.db"
):
"""
Migrate price data from merged.jsonl to SQLite database.
Args:
jsonl_path: Path to merged.jsonl file
db_path: Path to SQLite database
"""
jsonl_file = Path(jsonl_path)
if not jsonl_file.exists():
print(f"⚠️ merged.jsonl not found at {jsonl_path}")
print(" No price data to migrate. Skipping migration.")
return
print(f"📊 Migrating price data from {jsonl_path} to {db_path}")
# Ensure database is initialized
initialize_database(db_path)
conn = get_db_connection(db_path)
cursor = conn.cursor()
# Track what we're importing
total_records = 0
symbols_processed = set()
symbol_date_ranges = defaultdict(lambda: {"min": None, "max": None})
created_at = datetime.utcnow().isoformat() + "Z"
print("Reading merged.jsonl...")
with open(jsonl_file, 'r') as f:
for line_num, line in enumerate(f, 1):
if not line.strip():
continue
try:
record = json.loads(line)
# Extract metadata
meta = record.get("Meta Data", {})
symbol = meta.get("2. Symbol")
if not symbol:
print(f"⚠️ Line {line_num}: No symbol found, skipping")
continue
symbols_processed.add(symbol)
# Extract time series data
time_series = record.get("Time Series (Daily)", {})
if not time_series:
print(f"⚠️ {symbol}: No time series data, skipping")
continue
# Insert each date's data
for date, ohlcv in time_series.items():
try:
# Parse OHLCV values
open_price = float(ohlcv.get("1. buy price") or ohlcv.get("1. open", 0))
high_price = float(ohlcv.get("2. high", 0))
low_price = float(ohlcv.get("3. low", 0))
close_price = float(ohlcv.get("4. sell price") or ohlcv.get("4. close", 0))
volume = int(ohlcv.get("5. volume", 0))
# Insert or replace price data
cursor.execute("""
INSERT OR REPLACE INTO price_data
(symbol, date, open, high, low, close, volume, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", (symbol, date, open_price, high_price, low_price, close_price, volume, created_at))
total_records += 1
# Track date range for this symbol
if symbol_date_ranges[symbol]["min"] is None or date < symbol_date_ranges[symbol]["min"]:
symbol_date_ranges[symbol]["min"] = date
if symbol_date_ranges[symbol]["max"] is None or date > symbol_date_ranges[symbol]["max"]:
symbol_date_ranges[symbol]["max"] = date
except (ValueError, KeyError) as e:
print(f"⚠️ {symbol} {date}: Failed to parse OHLCV data: {e}")
continue
# Commit every 1000 records for progress
if total_records % 1000 == 0:
conn.commit()
print(f" Imported {total_records} records...")
except json.JSONDecodeError as e:
print(f"⚠️ Line {line_num}: JSON decode error: {e}")
continue
# Final commit
conn.commit()
print(f"\n✓ Imported {total_records} price records for {len(symbols_processed)} symbols")
# Update coverage tracking
print("\nUpdating coverage tracking...")
for symbol, date_range in symbol_date_ranges.items():
if date_range["min"] and date_range["max"]:
cursor.execute("""
INSERT OR REPLACE INTO price_data_coverage
(symbol, start_date, end_date, downloaded_at, source)
VALUES (?, ?, ?, ?, 'migrated_from_jsonl')
""", (symbol, date_range["min"], date_range["max"], created_at))
conn.commit()
conn.close()
print(f"✓ Coverage tracking updated for {len(symbol_date_ranges)} symbols")
print("\n✅ Migration complete!")
print(f"\nSymbols migrated: {', '.join(sorted(symbols_processed))}")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Migrate merged.jsonl to SQLite database")
parser.add_argument(
"--jsonl",
default="data/merged.jsonl",
help="Path to merged.jsonl file (default: data/merged.jsonl)"
)
parser.add_argument(
"--db",
default="data/jobs.db",
help="Path to SQLite database (default: data/jobs.db)"
)
args = parser.parse_args()
migrate_merged_jsonl(args.jsonl, args.db)
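The migration tolerates two key spellings for open and close prices (`"1. buy price"` falling back to `"1. open"`, and `"4. sell price"` falling back to `"4. close"`). That parsing step can be isolated as a standalone helper (`parse_ohlcv` is a hypothetical name for illustration):

```python
def parse_ohlcv(ohlcv: dict) -> tuple[float, float, float, float, int]:
    """Parse one day's OHLCV record, accepting either key variant."""
    # Fall back from the renamed keys to the standard Alpha Vantage ones.
    open_price = float(ohlcv.get("1. buy price") or ohlcv.get("1. open", 0))
    high_price = float(ohlcv.get("2. high", 0))
    low_price = float(ohlcv.get("3. low", 0))
    close_price = float(ohlcv.get("4. sell price") or ohlcv.get("4. close", 0))
    volume = int(ohlcv.get("5. volume", 0))
    return open_price, high_price, low_price, close_price, volume
```

Malformed values still raise `ValueError`, which the script catches per date and reports without aborting the whole migration.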


@@ -0,0 +1,453 @@
"""
Integration tests for on-demand price data downloads.
Tests the complete flow from missing coverage detection through download
and storage, including priority-based download strategy and rate limit handling.
"""
import pytest
import os
import tempfile
import json
from unittest.mock import patch, Mock
from datetime import datetime
from api.price_data_manager import PriceDataManager, RateLimitError, DownloadError
from api.database import initialize_database, get_db_connection
from api.date_utils import expand_date_range
@pytest.fixture
def temp_db():
"""Create temporary database for testing."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.db', delete=False) as f:
db_path = f.name
initialize_database(db_path)
yield db_path
# Cleanup
if os.path.exists(db_path):
os.unlink(db_path)
@pytest.fixture
def temp_symbols_config():
"""Create temporary symbols config with small symbol set."""
symbols_data = {
"symbols": ["AAPL", "MSFT", "GOOGL", "AMZN", "NVDA"],
"description": "Test symbols",
"total_symbols": 5
}
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(symbols_data, f)
config_path = f.name
yield config_path
# Cleanup
if os.path.exists(config_path):
os.unlink(config_path)
@pytest.fixture
def manager(temp_db, temp_symbols_config):
"""Create PriceDataManager instance."""
return PriceDataManager(
db_path=temp_db,
symbols_config=temp_symbols_config,
api_key="test_api_key"
)
@pytest.fixture
def mock_alpha_vantage_response():
"""Create mock Alpha Vantage API response."""
def create_response(symbol: str, dates: list):
"""Create response for given symbol and dates."""
time_series = {}
for date in dates:
time_series[date] = {
"1. open": "150.00",
"2. high": "155.00",
"3. low": "149.00",
"4. close": "154.00",
"5. volume": "1000000"
}
return {
"Meta Data": {
"1. Information": "Daily Prices",
"2. Symbol": symbol,
"3. Last Refreshed": dates[0] if dates else "2025-01-20"
},
"Time Series (Daily)": time_series
}
return create_response
class TestEndToEndDownload:
"""Test complete download workflow."""
@patch('api.price_data_manager.requests.get')
def test_download_missing_data_success(self, mock_get, manager, mock_alpha_vantage_response):
"""Test successful download of missing price data."""
# Setup: Mock API responses for each symbol
dates = ["2025-01-20", "2025-01-21"]
def mock_response_factory(url, **kwargs):
"""Return appropriate mock response based on symbol in params."""
symbol = kwargs.get('params', {}).get('symbol', 'AAPL')
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = mock_alpha_vantage_response(symbol, dates)
return mock_response
mock_get.side_effect = mock_response_factory
# Test: Request date range with no existing data
missing = manager.get_missing_coverage("2025-01-20", "2025-01-21")
# All symbols should be missing both dates
assert len(missing) == 5
for symbol in ["AAPL", "MSFT", "GOOGL", "AMZN", "NVDA"]:
assert symbol in missing
assert missing[symbol] == {"2025-01-20", "2025-01-21"}
# Download missing data
requested_dates = set(dates)
result = manager.download_missing_data_prioritized(missing, requested_dates)
# Should successfully download all symbols
assert result["success"] is True
assert len(result["downloaded"]) == 5
assert result["rate_limited"] is False
assert set(result["dates_completed"]) == requested_dates
# Verify data in database
available_dates = manager.get_available_trading_dates("2025-01-20", "2025-01-21")
assert available_dates == ["2025-01-20", "2025-01-21"]
# Verify coverage tracking
conn = get_db_connection(manager.db_path)
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) FROM price_data_coverage")
coverage_count = cursor.fetchone()[0]
assert coverage_count == 5 # One record per symbol
conn.close()
@patch('api.price_data_manager.requests.get')
def test_download_with_partial_existing_data(self, mock_get, manager, mock_alpha_vantage_response):
"""Test download when some data already exists."""
dates = ["2025-01-20", "2025-01-21", "2025-01-22"]
# Prepopulate database with some data (AAPL and MSFT for first two dates)
conn = get_db_connection(manager.db_path)
cursor = conn.cursor()
created_at = datetime.utcnow().isoformat() + "Z"
for symbol in ["AAPL", "MSFT"]:
for date in dates[:2]: # Only first two dates
cursor.execute("""
INSERT INTO price_data (symbol, date, open, high, low, close, volume, created_at)
VALUES (?, ?, 150.0, 155.0, 149.0, 154.0, 1000000, ?)
""", (symbol, date, created_at))
cursor.execute("""
INSERT INTO price_data_coverage (symbol, start_date, end_date, downloaded_at, source)
VALUES (?, ?, ?, ?, 'test')
""", (symbol, dates[0], dates[1], created_at))
conn.commit()
conn.close()
# Mock API for remaining downloads
def mock_response_factory(url, **kwargs):
symbol = kwargs.get('params', {}).get('symbol', 'GOOGL')
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = mock_alpha_vantage_response(symbol, dates)
return mock_response
mock_get.side_effect = mock_response_factory
# Check missing coverage
missing = manager.get_missing_coverage(dates[0], dates[2])
# AAPL and MSFT should be missing only date 3
# GOOGL, AMZN, NVDA should be missing all dates
assert missing["AAPL"] == {dates[2]}
assert missing["MSFT"] == {dates[2]}
assert missing["GOOGL"] == set(dates)
# Download missing data
requested_dates = set(dates)
result = manager.download_missing_data_prioritized(missing, requested_dates)
assert result["success"] is True
assert len(result["downloaded"]) == 5
# Verify all dates are now available
available_dates = manager.get_available_trading_dates(dates[0], dates[2])
assert set(available_dates) == set(dates)
@patch('api.price_data_manager.requests.get')
def test_priority_based_download_order(self, mock_get, manager, mock_alpha_vantage_response):
"""Test that downloads prioritize symbols that complete the most dates."""
dates = ["2025-01-20", "2025-01-21", "2025-01-22"]
# Prepopulate with specific pattern to create different priorities
conn = get_db_connection(manager.db_path)
cursor = conn.cursor()
created_at = datetime.utcnow().isoformat() + "Z"
# AAPL: Has date 1 only (missing 2 dates)
cursor.execute("""
INSERT INTO price_data (symbol, date, open, high, low, close, volume, created_at)
VALUES ('AAPL', ?, 150.0, 155.0, 149.0, 154.0, 1000000, ?)
""", (dates[0], created_at))
# MSFT: Has date 1 and 2 (missing 1 date)
for date in dates[:2]:
cursor.execute("""
INSERT INTO price_data (symbol, date, open, high, low, close, volume, created_at)
VALUES ('MSFT', ?, 150.0, 155.0, 149.0, 154.0, 1000000, ?)
""", (date, created_at))
# GOOGL, AMZN, NVDA: No data (missing 3 dates)
conn.commit()
conn.close()
# Track download order
download_order = []
def mock_response_factory(url, **kwargs):
symbol = kwargs.get('params', {}).get('symbol')
download_order.append(symbol)
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = mock_alpha_vantage_response(symbol, dates)
return mock_response
mock_get.side_effect = mock_response_factory
# Download missing data
missing = manager.get_missing_coverage(dates[0], dates[2])
requested_dates = set(dates)
result = manager.download_missing_data_prioritized(missing, requested_dates)
assert result["success"] is True
# Verify symbols with highest impact were downloaded first
# GOOGL, AMZN, NVDA should be first (3 dates each)
# Then AAPL (2 dates)
# Then MSFT (1 date)
first_three = set(download_order[:3])
assert first_three == {"GOOGL", "AMZN", "NVDA"}
assert download_order[3] == "AAPL"
assert download_order[4] == "MSFT"
class TestRateLimitHandling:
"""Test rate limit handling during downloads."""
@patch('api.price_data_manager.requests.get')
def test_rate_limit_stops_downloads(self, mock_get, manager, mock_alpha_vantage_response):
"""Test that rate limit error stops further downloads."""
dates = ["2025-01-20"]
# First symbol succeeds, second hits rate limit
responses = [
# AAPL succeeds (or whichever symbol is first in priority)
Mock(status_code=200, json=lambda: mock_alpha_vantage_response("AAPL", dates)),
# MSFT hits rate limit
Mock(status_code=200, json=lambda: {"Note": "Thank you for using Alpha Vantage! Our standard API call frequency is 25 calls per day."}),
]
mock_get.side_effect = responses
missing = manager.get_missing_coverage("2025-01-20", "2025-01-20")
requested_dates = {"2025-01-20"}
result = manager.download_missing_data_prioritized(missing, requested_dates)
# Partial success - one symbol downloaded
assert result["success"] is True # At least one succeeded
assert len(result["downloaded"]) >= 1
assert result["rate_limited"] is True
assert len(result["failed"]) >= 1
# Completed dates should be empty (need all symbols for complete date)
assert len(result["dates_completed"]) == 0
@patch('api.price_data_manager.requests.get')
def test_graceful_handling_of_mixed_failures(self, mock_get, manager, mock_alpha_vantage_response):
"""Test handling of mix of successes, failures, and rate limits."""
dates = ["2025-01-20"]
call_count = [0]
def response_factory(url, **kwargs):
"""Return different responses for different calls."""
call_count[0] += 1
mock_response = Mock()
if call_count[0] == 1:
# First call succeeds
mock_response.status_code = 200
mock_response.json.return_value = mock_alpha_vantage_response("AAPL", dates)
elif call_count[0] == 2:
# Second call fails with server error
mock_response.status_code = 500
mock_response.raise_for_status.side_effect = Exception("Server error")
else:
# Third call hits rate limit
mock_response.status_code = 200
mock_response.json.return_value = {"Note": "rate limit exceeded"}
return mock_response
mock_get.side_effect = response_factory
missing = manager.get_missing_coverage("2025-01-20", "2025-01-20")
requested_dates = {"2025-01-20"}
result = manager.download_missing_data_prioritized(missing, requested_dates)
# Should have handled errors gracefully
assert "downloaded" in result
assert "failed" in result
assert len(result["downloaded"]) >= 1
class TestCoverageTracking:
"""Test coverage tracking functionality."""
@patch('api.price_data_manager.requests.get')
def test_coverage_updated_after_download(self, mock_get, manager, mock_alpha_vantage_response):
"""Test that coverage table is updated after successful download."""
dates = ["2025-01-20", "2025-01-21"]
mock_get.return_value = Mock(
status_code=200,
json=lambda: mock_alpha_vantage_response("AAPL", dates)
)
# Download for single symbol
data = manager._download_symbol("AAPL")
stored_dates = manager._store_symbol_data("AAPL", data, set(dates))
manager._update_coverage("AAPL", dates[0], dates[1])
# Verify coverage was recorded
conn = get_db_connection(manager.db_path)
cursor = conn.cursor()
cursor.execute("""
SELECT symbol, start_date, end_date, source
FROM price_data_coverage
WHERE symbol = 'AAPL'
""")
row = cursor.fetchone()
conn.close()
assert row is not None
assert row[0] == "AAPL"
assert row[1] == dates[0]
assert row[2] == dates[1]
assert row[3] == "alpha_vantage"
def test_coverage_gap_detection_accuracy(self, manager):
"""Test accuracy of coverage gap detection."""
# Populate database with specific pattern
conn = get_db_connection(manager.db_path)
cursor = conn.cursor()
created_at = datetime.utcnow().isoformat() + "Z"
test_data = [
("AAPL", "2025-01-20"),
("AAPL", "2025-01-21"),
("AAPL", "2025-01-23"), # Gap on 2025-01-22
("MSFT", "2025-01-20"),
("MSFT", "2025-01-22"), # Gap on 2025-01-21
]
for symbol, date in test_data:
cursor.execute("""
INSERT INTO price_data (symbol, date, open, high, low, close, volume, created_at)
VALUES (?, ?, 150.0, 155.0, 149.0, 154.0, 1000000, ?)
""", (symbol, date, created_at))
conn.commit()
conn.close()
# Check for gaps in range
missing = manager.get_missing_coverage("2025-01-20", "2025-01-23")
# AAPL should be missing 2025-01-22
assert "2025-01-22" in missing["AAPL"]
assert "2025-01-20" not in missing["AAPL"]
# MSFT should be missing 2025-01-21 and 2025-01-23
assert "2025-01-21" in missing["MSFT"]
assert "2025-01-23" in missing["MSFT"]
assert "2025-01-20" not in missing["MSFT"]
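The gap detection exercised above reduces to a set difference per symbol. A minimal sketch, assuming the per-symbol date sets have already been read from SQLite (`missing_coverage` here is a free function for illustration; the real method lives on `PriceDataManager`):

```python
def missing_coverage(have: dict[str, set[str]], symbols: list[str],
                     requested: set[str]) -> dict[str, set[str]]:
    """For each tracked symbol, report requested dates absent from the DB."""
    missing = {}
    for symbol in symbols:
        gaps = requested - have.get(symbol, set())
        if gaps:
            missing[symbol] = gaps
    return missing
```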
class TestDataValidation:
"""Test data validation during download and storage."""
@patch('api.price_data_manager.requests.get')
def test_invalid_response_handling(self, mock_get, manager):
"""Test handling of invalid API responses."""
# Mock response with missing required fields
mock_get.return_value = Mock(
status_code=200,
json=lambda: {"invalid": "response"}
)
with pytest.raises(DownloadError, match="Invalid response format"):
manager._download_symbol("AAPL")
@patch('api.price_data_manager.requests.get')
def test_empty_time_series_handling(self, mock_get, manager):
"""Test handling of empty time series data (should raise error for missing data)."""
# API returns valid structure but no time series
mock_get.return_value = Mock(
status_code=200,
json=lambda: {
"Meta Data": {"2. Symbol": "AAPL"},
# Missing "Time Series (Daily)" key
}
)
with pytest.raises(DownloadError, match="Invalid response format"):
manager._download_symbol("AAPL")
def test_date_filtering_during_storage(self, manager):
"""Test that only requested dates are stored."""
# Create mock data with dates outside requested range
data = {
"Meta Data": {"2. Symbol": "AAPL"},
"Time Series (Daily)": {
"2025-01-15": {"1. open": "145.00", "2. high": "150.00", "3. low": "144.00", "4. close": "149.00", "5. volume": "1000000"},
"2025-01-20": {"1. open": "150.00", "2. high": "155.00", "3. low": "149.00", "4. close": "154.00", "5. volume": "1000000"},
"2025-01-21": {"1. open": "154.00", "2. high": "156.00", "3. low": "153.00", "4. close": "155.00", "5. volume": "1100000"},
"2025-01-25": {"1. open": "156.00", "2. high": "158.00", "3. low": "155.00", "4. close": "157.00", "5. volume": "1200000"},
}
}
# Request only specific dates
requested_dates = {"2025-01-20", "2025-01-21"}
stored_dates = manager._store_symbol_data("AAPL", data, requested_dates)
# Only requested dates should be stored
assert set(stored_dates) == requested_dates
# Verify in database
conn = get_db_connection(manager.db_path)
cursor = conn.cursor()
cursor.execute("SELECT date FROM price_data WHERE symbol = 'AAPL' ORDER BY date")
db_dates = [row[0] for row in cursor.fetchall()]
conn.close()
assert db_dates == ["2025-01-20", "2025-01-21"]


@@ -0,0 +1,149 @@
"""
Unit tests for api/date_utils.py
Tests date range expansion, validation, and utility functions.
"""
import pytest
from datetime import datetime, timedelta
from api.date_utils import (
expand_date_range,
validate_date_range,
get_max_simulation_days
)
class TestExpandDateRange:
"""Test expand_date_range function."""
def test_single_day(self):
"""Test single day range (start == end)."""
result = expand_date_range("2025-01-20", "2025-01-20")
assert result == ["2025-01-20"]
def test_multi_day_range(self):
"""Test multiple day range."""
result = expand_date_range("2025-01-20", "2025-01-22")
assert result == ["2025-01-20", "2025-01-21", "2025-01-22"]
def test_week_range(self):
"""Test week-long range."""
result = expand_date_range("2025-01-20", "2025-01-26")
assert len(result) == 7
assert result[0] == "2025-01-20"
assert result[-1] == "2025-01-26"
def test_chronological_order(self):
"""Test dates are in chronological order."""
result = expand_date_range("2025-01-20", "2025-01-25")
for i in range(len(result) - 1):
assert result[i] < result[i + 1]
def test_invalid_order(self):
"""Test error when start > end."""
with pytest.raises(ValueError, match="must be <= end_date"):
expand_date_range("2025-01-25", "2025-01-20")
def test_invalid_date_format(self):
"""Test error with invalid date format."""
with pytest.raises(ValueError):
expand_date_range("01-20-2025", "01-21-2025")
def test_month_boundary(self):
"""Test range spanning month boundary."""
result = expand_date_range("2025-01-30", "2025-02-02")
assert result == ["2025-01-30", "2025-01-31", "2025-02-01", "2025-02-02"]
def test_year_boundary(self):
"""Test range spanning year boundary."""
result = expand_date_range("2024-12-30", "2025-01-02")
assert len(result) == 4
assert "2024-12-31" in result
assert "2025-01-01" in result
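A plausible `expand_date_range` consistent with these tests (a sketch, not necessarily the module's exact implementation; `date.fromisoformat` rejects non-ISO inputs like `"01-20-2025"` with `ValueError`, which matches `test_invalid_date_format`):

```python
from datetime import date, timedelta

def expand_date_range(start_date: str, end_date: str) -> list[str]:
    """Return the inclusive list of ISO dates from start_date to end_date."""
    start = date.fromisoformat(start_date)
    end = date.fromisoformat(end_date)
    if start > end:
        raise ValueError("start_date must be <= end_date")
    return [(start + timedelta(days=i)).isoformat()
            for i in range((end - start).days + 1)]
```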
class TestValidateDateRange:
"""Test validate_date_range function."""
def test_valid_single_day(self):
"""Test valid single day range."""
# Should not raise
validate_date_range("2025-01-20", "2025-01-20", max_days=30)
def test_valid_multi_day(self):
"""Test valid multi-day range."""
# Should not raise
validate_date_range("2025-01-20", "2025-01-25", max_days=30)
def test_max_days_boundary(self):
"""Test exactly at max days limit."""
# 30 days total (inclusive)
start = "2025-01-01"
end = "2025-01-30"
# Should not raise
validate_date_range(start, end, max_days=30)
def test_exceeds_max_days(self):
"""Test exceeds max days limit."""
start = "2025-01-01"
end = "2025-02-01" # 32 days
with pytest.raises(ValueError, match="Date range too large: 32 days"):
validate_date_range(start, end, max_days=30)
def test_invalid_order(self):
"""Test start > end."""
with pytest.raises(ValueError, match="must be <= end_date"):
validate_date_range("2025-01-25", "2025-01-20", max_days=30)
def test_future_date_rejected(self):
"""Test future dates are rejected."""
tomorrow = (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%d")
next_week = (datetime.now() + timedelta(days=7)).strftime("%Y-%m-%d")
with pytest.raises(ValueError, match="cannot be in the future"):
validate_date_range(tomorrow, next_week, max_days=30)
def test_today_allowed(self):
"""Test today's date is allowed."""
today = datetime.now().strftime("%Y-%m-%d")
# Should not raise
validate_date_range(today, today, max_days=30)
def test_past_dates_allowed(self):
"""Test past dates are allowed."""
# Should not raise
validate_date_range("2020-01-01", "2020-01-10", max_days=30)
def test_invalid_date_format(self):
"""Test invalid date format raises error."""
with pytest.raises(ValueError, match="Invalid date format"):
validate_date_range("01-20-2025", "01-21-2025", max_days=30)
def test_custom_max_days(self):
"""Test custom max_days parameter."""
# Should raise with max_days=5
with pytest.raises(ValueError, match="Date range too large: 10 days"):
validate_date_range("2025-01-01", "2025-01-10", max_days=5)
class TestGetMaxSimulationDays:
"""Test get_max_simulation_days function."""
def test_default_value(self, monkeypatch):
"""Test default value when env var not set."""
monkeypatch.delenv("MAX_SIMULATION_DAYS", raising=False)
result = get_max_simulation_days()
assert result == 30
def test_env_var_override(self, monkeypatch):
"""Test environment variable override."""
monkeypatch.setenv("MAX_SIMULATION_DAYS", "60")
result = get_max_simulation_days()
assert result == 60
def test_env_var_string_to_int(self, monkeypatch):
"""Test env var is converted to int."""
monkeypatch.setenv("MAX_SIMULATION_DAYS", "100")
result = get_max_simulation_days()
assert isinstance(result, int)
assert result == 100
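The behavior these tests pin down is a one-line environment lookup. A minimal sketch consistent with them (the real function may differ in detail):

```python
import os

def get_max_simulation_days(default: int = 30) -> int:
    """Read MAX_SIMULATION_DAYS from the environment, defaulting to 30."""
    return int(os.environ.get("MAX_SIMULATION_DAYS", default))
```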


@@ -0,0 +1,572 @@
"""
Unit tests for api/price_data_manager.py
Tests price data management, coverage detection, download prioritization,
and rate limit handling.
"""
import pytest
import json
import os
from datetime import datetime, timedelta
from unittest.mock import Mock, patch, MagicMock, call
from pathlib import Path
import tempfile
import sqlite3
from api.price_data_manager import (
PriceDataManager,
RateLimitError,
DownloadError
)
from api.database import initialize_database, get_db_connection
@pytest.fixture
def temp_db():
"""Create temporary database for testing."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.db', delete=False) as f:
db_path = f.name
initialize_database(db_path)
yield db_path
# Cleanup
if os.path.exists(db_path):
os.unlink(db_path)
@pytest.fixture
def temp_symbols_config():
"""Create temporary symbols config for testing."""
symbols_data = {
"symbols": ["AAPL", "MSFT", "GOOGL"],
"description": "Test symbols",
"total_symbols": 3
}
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(symbols_data, f)
config_path = f.name
yield config_path
# Cleanup
if os.path.exists(config_path):
os.unlink(config_path)
@pytest.fixture
def manager(temp_db, temp_symbols_config):
"""Create PriceDataManager instance with temp database and config."""
return PriceDataManager(
db_path=temp_db,
symbols_config=temp_symbols_config,
api_key="test_api_key"
)
@pytest.fixture
def populated_db(temp_db):
"""Create database with sample price data."""
conn = get_db_connection(temp_db)
cursor = conn.cursor()
# Insert sample price data for multiple symbols and dates
test_data = [
("AAPL", "2025-01-20", 150.0, 155.0, 149.0, 154.0, 1000000),
("AAPL", "2025-01-21", 154.0, 156.0, 153.0, 155.0, 1100000),
("MSFT", "2025-01-20", 380.0, 385.0, 379.0, 383.0, 2000000),
("MSFT", "2025-01-21", 383.0, 387.0, 382.0, 386.0, 2100000),
("GOOGL", "2025-01-20", 140.0, 142.0, 139.0, 141.0, 1500000),
# Note: GOOGL missing 2025-01-21
]
created_at = datetime.utcnow().isoformat() + "Z"
for symbol, date, open_p, high, low, close, volume in test_data:
cursor.execute("""
INSERT INTO price_data (symbol, date, open, high, low, close, volume, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", (symbol, date, open_p, high, low, close, volume, created_at))
# Insert coverage data
cursor.execute("""
INSERT INTO price_data_coverage (symbol, start_date, end_date, downloaded_at, source)
VALUES
('AAPL', '2025-01-20', '2025-01-21', ?, 'test'),
('MSFT', '2025-01-20', '2025-01-21', ?, 'test'),
('GOOGL', '2025-01-20', '2025-01-20', ?, 'test')
""", (created_at, created_at, created_at))
conn.commit()
conn.close()
return temp_db
class TestPriceDataManagerInit:
"""Test PriceDataManager initialization."""
def test_init_with_defaults(self, temp_db):
"""Test initialization with default parameters."""
with patch.dict(os.environ, {"ALPHAADVANTAGE_API_KEY": "env_key"}):
manager = PriceDataManager(db_path=temp_db)
assert manager.db_path == temp_db
assert manager.api_key == "env_key"
assert manager.symbols_config == "configs/nasdaq100_symbols.json"
def test_init_with_custom_params(self, temp_db, temp_symbols_config):
"""Test initialization with custom parameters."""
manager = PriceDataManager(
db_path=temp_db,
symbols_config=temp_symbols_config,
api_key="custom_key"
)
assert manager.db_path == temp_db
assert manager.api_key == "custom_key"
assert manager.symbols_config == temp_symbols_config
def test_load_symbols_success(self, manager):
"""Test successful symbol loading from config."""
assert manager.symbols == ["AAPL", "MSFT", "GOOGL"]
def test_load_symbols_file_not_found(self, temp_db):
"""Test handling of missing symbols config file uses fallback."""
manager = PriceDataManager(
db_path=temp_db,
symbols_config="nonexistent.json",
api_key="test_key"
)
# Should use fallback symbols list
assert len(manager.symbols) > 0
assert "AAPL" in manager.symbols
def test_load_symbols_invalid_json(self, temp_db):
"""Test handling of invalid JSON in symbols config."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
f.write("invalid json{")
bad_config = f.name
try:
with pytest.raises(json.JSONDecodeError):
PriceDataManager(
db_path=temp_db,
symbols_config=bad_config,
api_key="test_key"
)
finally:
os.unlink(bad_config)
def test_missing_api_key(self, temp_db, temp_symbols_config):
"""Test initialization without API key."""
with patch.dict(os.environ, {}, clear=True):
manager = PriceDataManager(
db_path=temp_db,
symbols_config=temp_symbols_config
)
assert manager.api_key is None
class TestGetSymbolDates:
"""Test get_symbol_dates method."""
def test_get_symbol_dates_with_data(self, manager, populated_db):
"""Test retrieving dates for symbol with data."""
manager.db_path = populated_db
dates = manager.get_symbol_dates("AAPL")
assert dates == {"2025-01-20", "2025-01-21"}
def test_get_symbol_dates_no_data(self, manager):
"""Test retrieving dates for symbol without data."""
dates = manager.get_symbol_dates("TSLA")
assert dates == set()
def test_get_symbol_dates_partial_data(self, manager, populated_db):
"""Test retrieving dates for symbol with partial data."""
manager.db_path = populated_db
dates = manager.get_symbol_dates("GOOGL")
assert dates == {"2025-01-20"}
class TestGetMissingCoverage:
"""Test get_missing_coverage method."""
def test_missing_coverage_empty_db(self, manager):
"""Test missing coverage with empty database."""
missing = manager.get_missing_coverage("2025-01-20", "2025-01-21")
# All symbols should be missing all dates
assert "AAPL" in missing
assert "MSFT" in missing
assert "GOOGL" in missing
assert missing["AAPL"] == {"2025-01-20", "2025-01-21"}
def test_missing_coverage_partial_db(self, manager, populated_db):
"""Test missing coverage with partial data."""
manager.db_path = populated_db
missing = manager.get_missing_coverage("2025-01-20", "2025-01-21")
# AAPL and MSFT have all dates, GOOGL missing 2025-01-21
assert "AAPL" not in missing or len(missing["AAPL"]) == 0
assert "MSFT" not in missing or len(missing["MSFT"]) == 0
assert "GOOGL" in missing
assert missing["GOOGL"] == {"2025-01-21"}
def test_missing_coverage_complete_db(self, manager, populated_db):
"""Test missing coverage when all data available."""
manager.db_path = populated_db
missing = manager.get_missing_coverage("2025-01-20", "2025-01-20")
# All symbols have 2025-01-20
for symbol in ["AAPL", "MSFT", "GOOGL"]:
assert symbol not in missing or len(missing[symbol]) == 0
def test_missing_coverage_single_date(self, manager, populated_db):
"""Test missing coverage for single date."""
manager.db_path = populated_db
missing = manager.get_missing_coverage("2025-01-21", "2025-01-21")
# Only GOOGL missing 2025-01-21
assert "GOOGL" in missing
assert missing["GOOGL"] == {"2025-01-21"}
class TestPrioritizeDownloads:
"""Test prioritize_downloads method."""
def test_prioritize_single_symbol(self, manager):
"""Test prioritization with single symbol missing data."""
missing_coverage = {"AAPL": {"2025-01-20", "2025-01-21"}}
requested_dates = {"2025-01-20", "2025-01-21"}
prioritized = manager.prioritize_downloads(missing_coverage, requested_dates)
assert prioritized == ["AAPL"]
def test_prioritize_multiple_symbols_equal_impact(self, manager):
"""Test prioritization with equal impact symbols."""
missing_coverage = {
"AAPL": {"2025-01-20", "2025-01-21"},
"MSFT": {"2025-01-20", "2025-01-21"}
}
requested_dates = {"2025-01-20", "2025-01-21"}
prioritized = manager.prioritize_downloads(missing_coverage, requested_dates)
# Both should be included (order may vary)
assert set(prioritized) == {"AAPL", "MSFT"}
assert len(prioritized) == 2
def test_prioritize_by_impact(self, manager):
"""Test prioritization by date completion impact."""
missing_coverage = {
"AAPL": {"2025-01-20", "2025-01-21", "2025-01-22"}, # High impact (3 dates)
"MSFT": {"2025-01-20"}, # Low impact (1 date)
"GOOGL": {"2025-01-21", "2025-01-22"} # Medium impact (2 dates)
}
requested_dates = {"2025-01-20", "2025-01-21", "2025-01-22"}
prioritized = manager.prioritize_downloads(missing_coverage, requested_dates)
# AAPL should be first (highest impact)
assert prioritized[0] == "AAPL"
# GOOGL should be second
assert prioritized[1] == "GOOGL"
# MSFT should be last (lowest impact)
assert prioritized[2] == "MSFT"
def test_prioritize_excludes_irrelevant_dates(self, manager):
"""Test that symbols with no impact on requested dates are excluded."""
missing_coverage = {
"AAPL": {"2025-01-20"}, # Relevant
"MSFT": {"2025-01-25", "2025-01-26"} # Not relevant
}
requested_dates = {"2025-01-20", "2025-01-21"}
prioritized = manager.prioritize_downloads(missing_coverage, requested_dates)
# Only AAPL should be included
assert prioritized == ["AAPL"]
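The ordering these tests encode is a score-and-sort: each symbol's impact is the number of requested dates it would help complete, and zero-impact symbols are dropped. A sketch consistent with the tests (the real `PriceDataManager.prioritize_downloads` may tie-break differently):

```python
def prioritize_downloads(missing_coverage: dict[str, set[str]],
                         requested_dates: set[str]) -> list[str]:
    """Order symbols by how many requested dates each download completes."""
    # Score = overlap between a symbol's gaps and the requested range.
    impact = {symbol: len(dates & requested_dates)
              for symbol, dates in missing_coverage.items()}
    # Exclude symbols whose gaps fall entirely outside the requested range.
    return sorted((s for s, n in impact.items() if n > 0),
                  key=lambda s: impact[s], reverse=True)
```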
class TestGetAvailableTradingDates:
"""Test get_available_trading_dates method."""
def test_available_dates_empty_db(self, manager):
"""Test with empty database returns no dates."""
available = manager.get_available_trading_dates("2025-01-20", "2025-01-21")
assert available == []
def test_available_dates_complete_range(self, manager, populated_db):
"""Test with complete data for all symbols in range."""
manager.db_path = populated_db
available = manager.get_available_trading_dates("2025-01-20", "2025-01-20")
assert available == ["2025-01-20"]
def test_available_dates_partial_range(self, manager, populated_db):
"""Test with partial data (some symbols missing some dates)."""
manager.db_path = populated_db
available = manager.get_available_trading_dates("2025-01-20", "2025-01-21")
# 2025-01-20 has all symbols, 2025-01-21 missing GOOGL
assert available == ["2025-01-20"]
def test_available_dates_filters_incomplete(self, manager, populated_db):
"""Test that dates with incomplete symbol coverage are filtered."""
manager.db_path = populated_db
available = manager.get_available_trading_dates("2025-01-21", "2025-01-21")
# 2025-01-21 is missing GOOGL, so not complete
assert available == []
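These availability tests encode a completeness rule: a trading date counts only when every tracked symbol has a row for it. A minimal sketch under that assumption (the real method queries SQLite; `available_trading_dates` here is a hypothetical pure-Python equivalent over `(symbol, date)` pairs):

```python
# Hypothetical sketch: a date is "available" only when all symbols have data.
def available_trading_dates(rows, symbols, start, end):
    # rows: iterable of (symbol, date) pairs, as if read from price_data
    by_date = {}
    for sym, date in rows:
        if start <= date <= end:
            by_date.setdefault(date, set()).add(sym)
    # Keep dates whose symbol set covers every requested symbol.
    return sorted(d for d, syms in by_date.items() if syms >= set(symbols))

rows = [
    ("AAPL", "2025-01-20"), ("MSFT", "2025-01-20"), ("GOOGL", "2025-01-20"),
    ("AAPL", "2025-01-21"), ("MSFT", "2025-01-21"),  # GOOGL missing on 21st
]
print(available_trading_dates(rows, ["AAPL", "MSFT", "GOOGL"],
                              "2025-01-20", "2025-01-21"))
```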
class TestDownloadSymbol:
"""Test _download_symbol method (Alpha Vantage API calls)."""
@patch('api.price_data_manager.requests.get')
def test_download_success(self, mock_get, manager):
"""Test successful symbol download."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"Meta Data": {"2. Symbol": "AAPL"},
"Time Series (Daily)": {
"2025-01-20": {
"1. open": "150.00",
"2. high": "155.00",
"3. low": "149.00",
"4. close": "154.00",
"5. volume": "1000000"
}
}
}
mock_get.return_value = mock_response
data = manager._download_symbol("AAPL")
assert data["Meta Data"]["2. Symbol"] == "AAPL"
assert "2025-01-20" in data["Time Series (Daily)"]
mock_get.assert_called_once()
@patch('api.price_data_manager.requests.get')
def test_download_rate_limit(self, mock_get, manager):
"""Test rate limit detection."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"Note": "Thank you for using Alpha Vantage! Our standard API call frequency is 25 calls per day."
}
mock_get.return_value = mock_response
with pytest.raises(RateLimitError):
manager._download_symbol("AAPL")
@patch('api.price_data_manager.requests.get')
def test_download_http_error(self, mock_get, manager):
"""Test HTTP error handling."""
mock_response = Mock()
mock_response.status_code = 500
mock_response.raise_for_status.side_effect = Exception("Server error")
mock_get.return_value = mock_response
with pytest.raises(DownloadError):
manager._download_symbol("AAPL")
@patch('api.price_data_manager.requests.get')
def test_download_invalid_response(self, mock_get, manager):
"""Test handling of invalid API response."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {} # Missing required fields
mock_get.return_value = mock_response
with pytest.raises(DownloadError, match="Invalid response format"):
manager._download_symbol("AAPL")
def test_download_missing_api_key(self, manager):
"""Test download without API key."""
manager.api_key = None
with pytest.raises(DownloadError, match="API key not configured"):
manager._download_symbol("AAPL")
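Taken together, these tests describe the response-shape checks `_download_symbol` must perform: Alpha Vantage signals throttling with a top-level `"Note"` field, and a well-formed daily payload carries `"Meta Data"` and `"Time Series (Daily)"`. A sketch of just the validation step, with the exception classes stubbed locally (the real ones live in the `api` package):

```python
# Stubs standing in for the project's real exception types.
class RateLimitError(Exception):
    pass

class DownloadError(Exception):
    pass

def parse_alpha_vantage_daily(payload: dict) -> dict:
    """Validate an Alpha Vantage TIME_SERIES_DAILY response body."""
    # Alpha Vantage reports rate limiting via a "Note" field in the JSON.
    if "Note" in payload:
        raise RateLimitError(payload["Note"])
    # A usable payload must carry both top-level sections.
    if "Meta Data" not in payload or "Time Series (Daily)" not in payload:
        raise DownloadError("Invalid response format")
    return payload
```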
class TestStoreSymbolData:
"""Test _store_symbol_data method."""
def test_store_symbol_data_success(self, manager):
"""Test successful data storage."""
data = {
"Meta Data": {"2. Symbol": "AAPL"},
"Time Series (Daily)": {
"2025-01-20": {
"1. open": "150.00",
"2. high": "155.00",
"3. low": "149.00",
"4. close": "154.00",
"5. volume": "1000000"
},
"2025-01-21": {
"1. open": "154.00",
"2. high": "156.00",
"3. low": "153.00",
"4. close": "155.00",
"5. volume": "1100000"
}
}
}
requested_dates = {"2025-01-20", "2025-01-21"}
stored_dates = manager._store_symbol_data("AAPL", data, requested_dates)
# Returns list, not set
assert set(stored_dates) == {"2025-01-20", "2025-01-21"}
# Verify data in database
conn = get_db_connection(manager.db_path)
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) FROM price_data WHERE symbol = 'AAPL'")
count = cursor.fetchone()[0]
assert count == 2
conn.close()
def test_store_filters_by_requested_dates(self, manager):
"""Test that only requested dates are stored."""
data = {
"Meta Data": {"2. Symbol": "AAPL"},
"Time Series (Daily)": {
"2025-01-20": {
"1. open": "150.00",
"2. high": "155.00",
"3. low": "149.00",
"4. close": "154.00",
"5. volume": "1000000"
},
"2025-01-21": {
"1. open": "154.00",
"2. high": "156.00",
"3. low": "153.00",
"4. close": "155.00",
"5. volume": "1100000"
}
}
}
requested_dates = {"2025-01-20"} # Only request one date
stored_dates = manager._store_symbol_data("AAPL", data, requested_dates)
# Returns list, not set
assert set(stored_dates) == {"2025-01-20"}
# Verify only one date in database
conn = get_db_connection(manager.db_path)
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) FROM price_data WHERE symbol = 'AAPL'")
count = cursor.fetchone()[0]
assert count == 1
conn.close()
class TestUpdateCoverage:
"""Test _update_coverage method."""
def test_update_coverage_new_symbol(self, manager):
"""Test coverage tracking for new symbol."""
manager._update_coverage("AAPL", "2025-01-20", "2025-01-21")
conn = get_db_connection(manager.db_path)
cursor = conn.cursor()
cursor.execute("""
SELECT symbol, start_date, end_date, source
FROM price_data_coverage
WHERE symbol = 'AAPL'
""")
row = cursor.fetchone()
conn.close()
assert row is not None
assert row[0] == "AAPL"
assert row[1] == "2025-01-20"
assert row[2] == "2025-01-21"
assert row[3] == "alpha_vantage"
def test_update_coverage_existing_symbol(self, manager, populated_db):
"""Test coverage update for existing symbol."""
manager.db_path = populated_db
# Update with new range
manager._update_coverage("AAPL", "2025-01-22", "2025-01-23")
conn = get_db_connection(manager.db_path)
cursor = conn.cursor()
cursor.execute("""
SELECT COUNT(*) FROM price_data_coverage WHERE symbol = 'AAPL'
""")
count = cursor.fetchone()[0]
conn.close()
# Should have 2 coverage records now
assert count == 2
class TestDownloadMissingDataPrioritized:
"""Test download_missing_data_prioritized method (integration)."""
@patch.object(PriceDataManager, '_download_symbol')
@patch.object(PriceDataManager, '_store_symbol_data')
@patch.object(PriceDataManager, '_update_coverage')
def test_download_all_success(self, mock_update, mock_store, mock_download, manager):
"""Test successful download of all missing symbols."""
missing_coverage = {
"AAPL": {"2025-01-20"},
"MSFT": {"2025-01-20"}
}
requested_dates = {"2025-01-20"}
mock_download.return_value = {"Meta Data": {}, "Time Series (Daily)": {}}
mock_store.return_value = {"2025-01-20"}
result = manager.download_missing_data_prioritized(missing_coverage, requested_dates)
assert result["success"] is True
assert len(result["downloaded"]) == 2
assert result["rate_limited"] is False
assert mock_download.call_count == 2
@patch.object(PriceDataManager, '_download_symbol')
def test_download_rate_limited_mid_process(self, mock_download, manager):
"""Test graceful handling of rate limit during downloads."""
missing_coverage = {
"AAPL": {"2025-01-20"},
"MSFT": {"2025-01-20"},
"GOOGL": {"2025-01-20"}
}
requested_dates = {"2025-01-20"}
# First call succeeds, second raises rate limit
mock_download.side_effect = [
{"Meta Data": {"2. Symbol": "AAPL"}, "Time Series (Daily)": {"2025-01-20": {}}},
RateLimitError("Rate limit reached")
]
with patch.object(manager, '_store_symbol_data', return_value={"2025-01-20"}):
with patch.object(manager, '_update_coverage'):
result = manager.download_missing_data_prioritized(missing_coverage, requested_dates)
assert result["success"] is True # Partial success
assert len(result["downloaded"]) == 1
assert result["rate_limited"] is True
assert len(result["failed"]) == 2 # MSFT and GOOGL not downloaded
@patch.object(PriceDataManager, '_download_symbol')
def test_download_all_failed(self, mock_download, manager):
"""Test handling when all downloads fail."""
missing_coverage = {"AAPL": {"2025-01-20"}}
requested_dates = {"2025-01-20"}
mock_download.side_effect = DownloadError("Network error")
result = manager.download_missing_data_prioritized(missing_coverage, requested_dates)
assert result["success"] is False
assert len(result["downloaded"]) == 0
assert len(result["failed"]) == 1

@@ -12,6 +12,7 @@ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if project_root not in sys.path:
sys.path.insert(0, project_root)
from tools.general_tools import get_config_value
from api.database import get_db_connection
all_nasdaq_100_symbols = [
"NVDA", "MSFT", "AAPL", "GOOG", "GOOGL", "AMZN", "META", "AVGO", "TSLA",
@@ -47,143 +48,95 @@ def get_yesterday_date(today_date: str) -> str:
yesterday_date = yesterday_dt.strftime("%Y-%m-%d")
return yesterday_date
def get_open_prices(today_date: str, symbols: List[str], merged_path: Optional[str] = None) -> Dict[str, Optional[float]]:
"""data/merged.jsonl 中读取指定日期与标的的开盘价。
def get_open_prices(today_date: str, symbols: List[str], merged_path: Optional[str] = None, db_path: str = "data/jobs.db") -> Dict[str, Optional[float]]:
"""price_data 数据库表中读取指定日期与标的的开盘价。
Args:
today_date: Date string in YYYY-MM-DD format.
symbols: List of stock symbols to query.
merged_path: Optional custom path to merged.jsonl; defaults to data/merged.jsonl under the project root.
merged_path: Deprecated; kept for backward compatibility.
db_path: Database path; defaults to data/jobs.db.
Returns:
Dict mapping f"{symbol}_price" to the open price or None; None when the date or symbol is not found.
"""
wanted = set(symbols)
results: Dict[str, Optional[float]] = {}
if merged_path is None:
base_dir = Path(__file__).resolve().parents[1]
merged_file = base_dir / "data" / "merged.jsonl"
else:
merged_file = Path(merged_path)
try:
conn = get_db_connection(db_path)
cursor = conn.cursor()
if not merged_file.exists():
return results
# Query all requested symbols for the date
placeholders = ','.join('?' * len(symbols))
query = f"""
SELECT symbol, open
FROM price_data
WHERE date = ? AND symbol IN ({placeholders})
"""
with merged_file.open("r", encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
try:
doc = json.loads(line)
except Exception:
continue
meta = doc.get("Meta Data", {}) if isinstance(doc, dict) else {}
sym = meta.get("2. Symbol")
if sym not in wanted:
continue
series = doc.get("Time Series (Daily)", {})
if not isinstance(series, dict):
continue
bar = series.get(today_date)
if isinstance(bar, dict):
open_val = bar.get("1. buy price")
try:
results[f'{sym}_price'] = float(open_val) if open_val is not None else None
except Exception:
results[f'{sym}_price'] = None
params = [today_date] + list(symbols)
cursor.execute(query, params)
# Build results dict
for row in cursor.fetchall():
symbol = row[0]
open_price = row[1]
results[f'{symbol}_price'] = float(open_price) if open_price is not None else None
conn.close()
except Exception as e:
# Log error but return empty results to maintain compatibility
print(f"Error querying price data: {e}")
return results
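The parameterized `IN (...)` pattern above generalizes to any symbol list. A self-contained sketch against an in-memory SQLite database (the real code goes through `get_db_connection`; the schema here is trimmed to just the queried columns):

```python
import sqlite3

# Minimal sketch of the query pattern, against an in-memory database.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE price_data (symbol TEXT, date TEXT, open REAL)")
conn.executemany(
    "INSERT INTO price_data VALUES (?, ?, ?)",
    [("AAPL", "2025-01-20", 150.0), ("MSFT", "2025-01-20", 420.0)],
)

symbols = ["AAPL", "MSFT", "GOOGL"]
# One "?" placeholder per symbol keeps the query safely parameterized.
placeholders = ",".join("?" * len(symbols))
rows = conn.execute(
    f"SELECT symbol, open FROM price_data "
    f"WHERE date = ? AND symbol IN ({placeholders})",
    ["2025-01-20", *symbols],
).fetchall()
results = {f"{sym}_price": price for sym, price in rows}
print(results)
conn.close()
```

Note one subtlety of this approach: symbols with no row for the date (GOOGL here) are simply absent from the dict rather than mapped to None, so callers relying on the documented None contract should check key membership.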
def get_yesterday_open_and_close_price(today_date: str, symbols: List[str], merged_path: Optional[str] = None) -> Tuple[Dict[str, Optional[float]], Dict[str, Optional[float]]]:
"""data/merged.jsonl 中读取指定日期与股票的昨日买入价和卖出价。
def get_yesterday_open_and_close_price(today_date: str, symbols: List[str], merged_path: Optional[str] = None, db_path: str = "data/jobs.db") -> Tuple[Dict[str, Optional[float]], Dict[str, Optional[float]]]:
"""price_data 数据库表中读取指定日期与股票的昨日买入价和卖出价。
Args:
today_date: Date string in YYYY-MM-DD format, representing today's date.
symbols: List of stock symbols to query.
merged_path: Optional custom path to merged.jsonl; defaults to data/merged.jsonl under the project root.
merged_path: Deprecated; kept for backward compatibility.
db_path: Database path; defaults to data/jobs.db.
Returns:
Tuple of (buy price dict, sell price dict); values are None if the date or symbol is not found.
"""
wanted = set(symbols)
buy_results: Dict[str, Optional[float]] = {}
sell_results: Dict[str, Optional[float]] = {}
if merged_path is None:
base_dir = Path(__file__).resolve().parents[1]
merged_file = base_dir / "data" / "merged.jsonl"
else:
merged_file = Path(merged_path)
if not merged_file.exists():
return buy_results, sell_results
yesterday_date = get_yesterday_date(today_date)
with merged_file.open("r", encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
try:
doc = json.loads(line)
except Exception:
continue
meta = doc.get("Meta Data", {}) if isinstance(doc, dict) else {}
sym = meta.get("2. Symbol")
if sym not in wanted:
continue
series = doc.get("Time Series (Daily)", {})
if not isinstance(series, dict):
continue
# Try to fetch yesterday's buy and sell prices
bar = series.get(yesterday_date)
if isinstance(bar, dict):
buy_val = bar.get("1. buy price")  # buy price field
sell_val = bar.get("4. sell price")  # sell price field
try:
buy_price = float(buy_val) if buy_val is not None else None
sell_price = float(sell_val) if sell_val is not None else None
buy_results[f'{sym}_price'] = buy_price
sell_results[f'{sym}_price'] = sell_price
except Exception:
buy_results[f'{sym}_price'] = None
sell_results[f'{sym}_price'] = None
else:
# If there is no data for yesterday, look back for the most recent trading day
today_dt = datetime.strptime(today_date, "%Y-%m-%d")
yesterday_dt = today_dt - timedelta(days=1)
current_date = yesterday_dt
found_data = False
# Look back at most 5 trading days
for _ in range(5):
current_date -= timedelta(days=1)
# Skip weekends
while current_date.weekday() >= 5:
current_date -= timedelta(days=1)
check_date = current_date.strftime("%Y-%m-%d")
bar = series.get(check_date)
if isinstance(bar, dict):
buy_val = bar.get("1. buy price")
sell_val = bar.get("4. sell price")
try:
buy_price = float(buy_val) if buy_val is not None else None
sell_price = float(sell_val) if sell_val is not None else None
buy_results[f'{sym}_price'] = buy_price
sell_results[f'{sym}_price'] = sell_price
found_data = True
break
except Exception:
continue
if not found_data:
buy_results[f'{sym}_price'] = None
sell_results[f'{sym}_price'] = None
try:
conn = get_db_connection(db_path)
cursor = conn.cursor()
# Query all requested symbols for yesterday's date
placeholders = ','.join('?' * len(symbols))
query = f"""
SELECT symbol, open, close
FROM price_data
WHERE date = ? AND symbol IN ({placeholders})
"""
params = [yesterday_date] + list(symbols)
cursor.execute(query, params)
# Build results dicts
for row in cursor.fetchall():
symbol = row[0]
open_price = row[1] # Buy price (open)
close_price = row[2] # Sell price (close)
buy_results[f'{symbol}_price'] = float(open_price) if open_price is not None else None
sell_results[f'{symbol}_price'] = float(close_price) if close_price is not None else None
conn.close()
except Exception as e:
# Log error but return empty results to maintain compatibility
print(f"Error querying price data: {e}")
return buy_results, sell_results
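The removed merged.jsonl branch also carried a fallback that the database version does not reproduce: when yesterday had no bar, it stepped back up to five trading days, skipping weekends. The core date-stepping logic, extracted as a standalone sketch (`previous_trading_days` is a hypothetical helper, not project code):

```python
from datetime import datetime, timedelta

# Sketch of the removed lookback: starting from the day before yesterday,
# step back one calendar day at a time, skipping weekends, up to 5 attempts.
def previous_trading_days(today: str, attempts: int = 5):
    current = datetime.strptime(today, "%Y-%m-%d") - timedelta(days=1)
    days = []
    for _ in range(attempts):
        current -= timedelta(days=1)
        while current.weekday() >= 5:  # 5 = Saturday, 6 = Sunday
            current -= timedelta(days=1)
        days.append(current.strftime("%Y-%m-%d"))
    return days

# For a Monday, the first candidate is the prior Friday.
print(previous_trading_days("2025-01-20"))
```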