mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-02 01:27:24 -04:00
Compare commits
14 Commits
v0.3.0-alp
...
v0.3.0-alp
| Author | SHA1 | Date | |
|---|---|---|---|
| 2575e0c12a | |||
| 1347e3939f | |||
| 4b25ae96c2 | |||
| 5606df1f51 | |||
| 02c8a48b37 | |||
| c3ea358a12 | |||
| 1bfcdd78b8 | |||
| 76b946449e | |||
| bddf4d8b72 | |||
| 8e7e80807b | |||
| ec2a37e474 | |||
| 20506a379d | |||
| 246dbd1b34 | |||
| 9539d63103 |
14
.env.example
14
.env.example
@@ -21,13 +21,19 @@ JINA_API_KEY=your_jina_key_here # https://jina.ai/
|
||||
# Used for Windmill integration and external API access
|
||||
API_PORT=8080
|
||||
|
||||
# Web Interface Host Port (exposed on host machine)
|
||||
# Container always uses 8888 internally
|
||||
WEB_HTTP_PORT=8888
|
||||
|
||||
# Agent Configuration
|
||||
AGENT_MAX_STEP=30
|
||||
|
||||
# Simulation Configuration
|
||||
# Maximum number of days allowed in a single simulation range
|
||||
# Prevents accidentally requesting very large date ranges
|
||||
MAX_SIMULATION_DAYS=30
|
||||
|
||||
# Price Data Configuration
|
||||
# Automatically download missing price data from Alpha Vantage when needed
|
||||
# If disabled, all price data must be pre-populated in the database
|
||||
AUTO_DOWNLOAD_PRICE_DATA=true
|
||||
|
||||
# Data Volume Configuration
|
||||
# Base directory for all persistent data (will contain data/, logs/, configs/ subdirectories)
|
||||
# Use relative paths (./volumes) or absolute paths (/home/user/ai-trader-volumes)
|
||||
|
||||
3
.github/workflows/docker-release.yml
vendored
3
.github/workflows/docker-release.yml
vendored
@@ -67,8 +67,7 @@ jobs:
|
||||
|
||||
# Only add 'latest' tag for stable releases
|
||||
if [[ "$IS_PRERELEASE" == "false" ]]; then
|
||||
TAGS="$TAGS
|
||||
ghcr.io/$REPO_OWNER_LOWER/ai-trader:latest"
|
||||
TAGS="${TAGS}"$'\n'"ghcr.io/$REPO_OWNER_LOWER/ai-trader:latest"
|
||||
echo "Tagging as both $VERSION and latest"
|
||||
else
|
||||
echo "Pre-release detected - tagging as $VERSION only (NOT latest)"
|
||||
|
||||
69
CHANGELOG.md
69
CHANGELOG.md
@@ -9,6 +9,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## [0.3.0] - 2025-10-31
|
||||
|
||||
### Added - Price Data Management & On-Demand Downloads
|
||||
- **SQLite Price Data Storage** - Replaced JSONL files with relational database
|
||||
- `price_data` table for OHLCV data (replaces merged.jsonl)
|
||||
- `price_data_coverage` table for tracking downloaded date ranges
|
||||
- `simulation_runs` table for soft-delete position tracking
|
||||
- Comprehensive indexes for query performance
|
||||
- **On-Demand Price Data Downloads** - Automatic gap filling via Alpha Vantage
|
||||
- Priority-based download strategy (maximize date completion)
|
||||
- Graceful rate limit handling (no pre-configured limits needed)
|
||||
- Smart coverage gap detection
|
||||
- Configurable via `AUTO_DOWNLOAD_PRICE_DATA` (default: true)
|
||||
- **Date Range API** - Simplified date specification
|
||||
- Single date: `{"start_date": "2025-01-20"}`
|
||||
- Date range: `{"start_date": "2025-01-20", "end_date": "2025-01-24"}`
|
||||
- Automatic validation (chronological order, max range, not future)
|
||||
- Configurable max days via `MAX_SIMULATION_DAYS` (default: 30)
|
||||
- **Migration Tooling** - Script to import existing merged.jsonl data
|
||||
- `scripts/migrate_price_data.py` for one-time data migration
|
||||
- Automatic coverage tracking during migration
|
||||
|
||||
### Added - API Service Transformation
|
||||
- **REST API Service** - Complete FastAPI implementation for external orchestration
|
||||
- `POST /simulate/trigger` - Trigger simulation jobs with config, date range, and models
|
||||
@@ -28,13 +48,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- ModelDayExecutor - Single model-day execution engine
|
||||
- SimulationWorker - Job orchestration with date-sequential, model-parallel execution
|
||||
- **Comprehensive Test Suite**
|
||||
- 102 unit and integration tests (85% coverage)
|
||||
- 175 unit and integration tests
|
||||
- 19 database tests (98% coverage)
|
||||
- 23 job manager tests (98% coverage)
|
||||
- 10 model executor tests (84% coverage)
|
||||
- 20 API endpoint tests (81% coverage)
|
||||
- 20 Pydantic model tests (100% coverage)
|
||||
- 10 runtime manager tests (89% coverage)
|
||||
- 22 date utilities tests (100% coverage)
|
||||
- 33 price data manager tests (85% coverage)
|
||||
- 10 on-demand download integration tests
|
||||
- 8 existing integration tests
|
||||
- **Docker Deployment** - Persistent REST API service
|
||||
- API-only deployment (batch mode removed for simplicity)
|
||||
- Single docker-compose service (ai-trader)
|
||||
@@ -55,14 +79,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
### Changed
|
||||
- **Architecture** - Transformed from batch-only to API-first service with database persistence
|
||||
- **Data Storage** - Migrated from JSONL files to SQLite relational database
|
||||
- **Deployment** - Simplified to single API-only Docker service
|
||||
- Price data now stored in `price_data` table instead of `merged.jsonl`
|
||||
- Tools/price_tools.py updated to query database
|
||||
- Position data remains in database (already migrated in earlier versions)
|
||||
- **Deployment** - Simplified to single API-only Docker service (REST API is new in v0.3.0)
|
||||
- **Configuration** - Simplified environment variable configuration
|
||||
- Added configurable API_PORT for host port mapping (default: 8080, customizable for port conflicts)
|
||||
- Removed `RUNTIME_ENV_PATH` (API dynamically manages runtime configs via RuntimeConfigManager)
|
||||
- Removed MCP service port configuration (MATH_HTTP_PORT, SEARCH_HTTP_PORT, TRADE_HTTP_PORT, GETPRICE_HTTP_PORT)
|
||||
- **Added:** `AUTO_DOWNLOAD_PRICE_DATA` (default: true) - Enable on-demand downloads
|
||||
- **Added:** `MAX_SIMULATION_DAYS` (default: 30) - Maximum date range size
|
||||
- **Added:** `API_PORT` for host port mapping (default: 8080, customizable for port conflicts)
|
||||
- **Removed:** `RUNTIME_ENV_PATH` (API dynamically manages runtime configs)
|
||||
- **Removed:** MCP service ports (MATH_HTTP_PORT, SEARCH_HTTP_PORT, TRADE_HTTP_PORT, GETPRICE_HTTP_PORT)
|
||||
- **Removed:** `WEB_HTTP_PORT` (web UI not implemented)
|
||||
- MCP services use fixed internal ports (8000-8003) and are no longer exposed to host
|
||||
- Container always uses port 8080 internally for API (hardcoded in entrypoint.sh)
|
||||
- Only API port (8080) and web dashboard (8888) are exposed to host
|
||||
- Container always uses port 8080 internally for API
|
||||
- Only API port (8080) is exposed to host
|
||||
- Reduces configuration complexity and attack surface
|
||||
- **Requirements** - Added fastapi>=0.120.0, uvicorn[standard]>=0.27.0, pydantic>=2.0.0
|
||||
- **Docker Compose** - Single service (ai-trader) instead of dual-mode
|
||||
@@ -80,15 +110,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- **Automatic Status Transitions** - Job status updates based on model-day completion
|
||||
|
||||
### Performance & Quality
|
||||
- **Code Coverage** - 85% overall (84.63% measured)
|
||||
- **Test Suite** - 175 tests, all passing
|
||||
- Unit tests: 155 tests
|
||||
- Integration tests: 18 tests
|
||||
- API tests: 20+ tests
|
||||
- **Code Coverage** - High coverage for new modules
|
||||
- Date utilities: 100%
|
||||
- Price data manager: 85%
|
||||
- Database layer: 98%
|
||||
- Job manager: 98%
|
||||
- Pydantic models: 100%
|
||||
- Runtime manager: 89%
|
||||
- Model executor: 84%
|
||||
- FastAPI app: 81%
|
||||
- **Test Execution** - 102 tests in ~2.5 seconds
|
||||
- **Zero Test Failures** - All tests passing (threading tests excluded)
|
||||
- **Test Execution** - Fast test suite (~12 seconds for full suite)
|
||||
|
||||
### Integration Ready
|
||||
- **Windmill.dev** - HTTP-based integration with polling support
|
||||
@@ -98,9 +133,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Breaking Changes
|
||||
- **Batch Mode Removed** - All simulations now run through REST API
|
||||
- Simplifies deployment and eliminates dual-mode complexity
|
||||
- Focus on API-first architecture for external orchestration
|
||||
- Migration: Use POST /simulate/trigger endpoint instead of batch execution
|
||||
- v0.2.0 used sequential batch execution via Docker entrypoint
|
||||
- v0.3.0 introduces REST API for external orchestration
|
||||
- Migration: Use `POST /simulate/trigger` endpoint instead of direct script execution
|
||||
- **Data Storage Format Changed** - Price data moved from JSONL to SQLite
|
||||
- Run `python scripts/migrate_price_data.py` to migrate existing merged.jsonl data
|
||||
- `merged.jsonl` no longer used (replaced by `price_data` table)
|
||||
- Automatic on-demand downloads eliminate need for manual data fetching
|
||||
- **Configuration Variables Changed**
|
||||
- Added: `AUTO_DOWNLOAD_PRICE_DATA`, `MAX_SIMULATION_DAYS`, `API_PORT`
|
||||
- Removed: `RUNTIME_ENV_PATH`, MCP service ports, `WEB_HTTP_PORT`
|
||||
- MCP services now use fixed internal ports (not exposed to host)
|
||||
|
||||
## [0.2.0] - 2025-10-31
|
||||
|
||||
|
||||
@@ -33,8 +33,8 @@ RUN mkdir -p data logs data/agent_data
|
||||
# Make entrypoint executable
|
||||
RUN chmod +x entrypoint.sh
|
||||
|
||||
# Expose MCP service ports, API server, and web dashboard
|
||||
EXPOSE 8000 8001 8002 8003 8080 8888
|
||||
# Expose API server port (MCP services are internal only)
|
||||
EXPOSE 8080
|
||||
|
||||
# Set Python to run unbuffered for real-time logs
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
584
README_CN.md
584
README_CN.md
@@ -1,584 +0,0 @@
|
||||
<div align="center">
|
||||
|
||||
# 🚀 AI-Trader: Which LLM Rules the Market?
|
||||
### *让AI在金融市场中一展身手*
|
||||
|
||||
[](https://python.org)
|
||||
[](LICENSE)
|
||||
|
||||
|
||||
**一个AI股票交易代理系统,让多个大语言模型在纳斯达克100股票池中完全自主决策、同台竞技!**
|
||||
|
||||
## 🏆 当前锦标赛排行榜
|
||||
[*点击查看*](https://hkuds.github.io/AI-Trader/)
|
||||
|
||||
<div align="center">
|
||||
|
||||
### 🥇 **锦标赛期间:(Last Update 2025/10/29)**
|
||||
|
||||
| 🏆 Rank | 🤖 AI Model | 📈 Total Earnings |
|
||||
|---------|-------------|----------------|
|
||||
| **🥇 1st** | **DeepSeek** | 🚀 +16.46% |
|
||||
| 🥈 2nd | MiniMax-M2 | 📊 +12.03% |
|
||||
| 🥉 3rd | GPT-5 | 📊 +9.98% |
|
||||
| 4th | Claude-3.7 | 📊 +9.80% |
|
||||
| 5th | Qwen3-max | 📊 +7.96% |
|
||||
| Baseline | QQQ | 📊 +5.39% |
|
||||
| 6th | Gemini-2.5-flash | 📊 +0.48% |
|
||||
|
||||
### 📊 **实时性能仪表板**
|
||||

|
||||
|
||||
*每日追踪AI模型在纳斯达克100交易中的表现*
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
## 📝 本周更新计划
|
||||
|
||||
我们很高兴宣布以下更新将在本周内上线:
|
||||
|
||||
- ⏰ **小时级别交易支持** - 升级至小时级精度交易
|
||||
- 🚀 **服务部署与并行执行** - 部署生产服务 + 并行模型执行
|
||||
- 🎨 **增强前端仪表板** - 添加详细的交易日志可视化(完整交易过程展示)
|
||||
|
||||
敬请期待这些激动人心的改进!🎉
|
||||
|
||||
---
|
||||
|
||||
> 🎯 **核心特色**: 100% AI自主决策,零人工干预,纯工具驱动架构
|
||||
|
||||
[🚀 快速开始](#-快速开始) • [📈 性能分析](#-性能分析) • [🛠️ 配置指南](#-配置指南)
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
## 🌟 项目介绍
|
||||
|
||||
> **AI-Trader让五个不同的AI模型,每个都采用独特的投资策略,在同一个市场中完全自主决策、竞争,看谁能在纳斯达克100交易中赚得最多!**
|
||||
|
||||
### 🎯 核心特性
|
||||
|
||||
- 🤖 **完全自主决策**: AI代理100%独立分析、决策、执行,零人工干预
|
||||
- 🛠️ **纯工具驱动架构**: 基于MCP工具链,AI通过标准化工具调用完成所有交易操作
|
||||
- 🏆 **多模型竞技场**: 部署多个AI模型(GPT、Claude、Qwen等)进行竞争性交易
|
||||
- 📊 **实时性能分析**: 完整的交易记录、持仓监控和盈亏分析
|
||||
- 🔍 **智能市场情报**: 集成Jina搜索,获取实时市场新闻和财务报告
|
||||
- ⚡ **MCP工具链集成**: 基于Model Context Protocol的模块化工具生态系统
|
||||
- 🔌 **可扩展策略框架**: 支持第三方策略和自定义AI代理集成
|
||||
- ⏰ **历史回放功能**: 时间段回放功能,自动过滤未来信息
|
||||
|
||||
|
||||
---
|
||||
|
||||
### 🎮 交易环境
|
||||
每个AI模型以$10,000起始资金在受控环境中交易纳斯达克100股票,使用真实市场数据和历史回放功能。
|
||||
|
||||
- 💰 **初始资金**: $10,000美元起始余额
|
||||
- 📈 **交易范围**: 纳斯达克100成分股(100只顶级科技股)
|
||||
- ⏰ **交易时间**: 工作日市场时间,支持历史模拟
|
||||
- 📊 **数据集成**: Alpha Vantage API结合Jina AI市场情报
|
||||
- 🔄 **时间管理**: 历史期间回放,自动过滤未来信息
|
||||
|
||||
---
|
||||
|
||||
### 🧠 智能交易能力
|
||||
AI代理完全自主运行,进行市场研究、制定交易决策,并在无人干预的情况下持续优化策略。
|
||||
|
||||
- 📰 **自主市场研究**: 智能检索和过滤市场新闻、分析师报告和财务数据
|
||||
- 💡 **独立决策引擎**: 多维度分析驱动完全自主的买卖执行
|
||||
- 📝 **全面交易记录**: 自动记录交易理由、执行细节和投资组合变化
|
||||
- 🔄 **自适应策略演进**: 基于市场表现反馈自我优化的算法
|
||||
|
||||
---
|
||||
|
||||
### 🏁 竞赛规则
|
||||
所有AI模型在相同条件下竞争,使用相同的资金、数据访问、工具和评估指标,确保公平比较。
|
||||
|
||||
- 💰 **起始资金**: $10,000美元初始投资
|
||||
- 📊 **数据访问**: 统一的市场数据和信息源
|
||||
- ⏰ **运行时间**: 同步的交易时间窗口
|
||||
- 📈 **性能指标**: 所有模型的标准评估标准
|
||||
- 🛠️ **工具访问**: 所有参与者使用相同的MCP工具链
|
||||
|
||||
🎯 **目标**: 确定哪个AI模型通过纯自主操作获得卓越的投资回报!
|
||||
|
||||
### 🚫 零人工干预
|
||||
AI代理完全自主运行,在没有任何人工编程、指导或干预的情况下制定所有交易决策和策略调整。
|
||||
|
||||
- ❌ **无预编程**: 零预设交易策略或算法规则
|
||||
- ❌ **无人工输入**: 完全依赖内在的AI推理能力
|
||||
- ❌ **无手动覆盖**: 交易期间绝对禁止人工干预
|
||||
- ✅ **纯工具执行**: 所有操作仅通过标准化工具调用执行
|
||||
- ✅ **自适应学习**: 基于市场表现反馈的独立策略优化
|
||||
|
||||
---
|
||||
|
||||
## ⏰ 历史回放架构
|
||||
|
||||
AI-Trader Bench的核心创新是其**完全可重放**的交易环境,确保AI代理在历史市场数据上的性能评估具有科学严谨性和可重复性。
|
||||
|
||||
### 🔄 时间控制框架
|
||||
|
||||
#### 📅 灵活的时间设置
|
||||
```json
|
||||
{
|
||||
"date_range": {
|
||||
"init_date": "2025-01-01", // 任意开始日期
|
||||
"end_date": "2025-01-31" // 任意结束日期
|
||||
}
|
||||
}
|
||||
```
|
||||
---
|
||||
|
||||
### 🛡️ 防前瞻数据控制
|
||||
AI只能访问当前时间及之前的数据。不允许未来信息。
|
||||
|
||||
- 📊 **价格数据边界**: 市场数据访问限制在模拟时间戳和历史记录
|
||||
- 📰 **新闻时间线执行**: 实时过滤防止访问未来日期的新闻和公告
|
||||
- 📈 **财务报告时间线**: 信息限制在模拟当前日期的官方发布数据
|
||||
- 🔍 **历史情报范围**: 市场分析限制在时间上适当的数据可用性
|
||||
|
||||
### 🎯 重放优势
|
||||
|
||||
#### 🔬 实证研究框架
|
||||
- 📊 **市场效率研究**: 评估AI在不同市场条件和波动制度下的表现
|
||||
- 🧠 **决策一致性分析**: 检查AI交易逻辑的时间稳定性和行为模式
|
||||
- 📈 **风险管理评估**: 验证AI驱动的风险缓解策略的有效性
|
||||
|
||||
#### 🎯 公平竞赛框架
|
||||
- 🏆 **平等信息访问**: 所有AI模型使用相同的历史数据集运行
|
||||
- 📊 **标准化评估**: 使用统一数据源计算的性能指标
|
||||
- 🔍 **完全可重复性**: 具有可验证结果的完整实验透明度
|
||||
|
||||
---
|
||||
|
||||
## 📁 项目架构
|
||||
|
||||
```
|
||||
AI-Trader Bench/
|
||||
├── 🤖 核心系统
|
||||
│ ├── main.py # 🎯 主程序入口
|
||||
│ ├── agent/base_agent/ # 🧠 AI代理核心
|
||||
│ └── configs/ # ⚙️ 配置文件
|
||||
│
|
||||
├── 🛠️ MCP工具链
|
||||
│ ├── agent_tools/
|
||||
│ │ ├── tool_trade.py # 💰 交易执行
|
||||
│ │ ├── tool_get_price_local.py # 📊 价格查询
|
||||
│ │ ├── tool_jina_search.py # 🔍 信息搜索
|
||||
│ │ └── tool_math.py # 🧮 数学计算
|
||||
│ └── tools/ # 🔧 辅助工具
|
||||
│
|
||||
├── 📊 数据系统
|
||||
│ ├── data/
|
||||
│ │ ├── daily_prices_*.json # 📈 股票价格数据
|
||||
│ │ ├── merged.jsonl # 🔄 统一数据格式
|
||||
│ │ └── agent_data/ # 📝 AI交易记录
|
||||
│ └── calculate_performance.py # 📈 性能分析
|
||||
│
|
||||
├── 🎨 前端界面
|
||||
│ └── frontend/ # 🌐 Web仪表板
|
||||
│
|
||||
└── 📋 配置与文档
|
||||
├── configs/ # ⚙️ 系统配置
|
||||
├── prompts/ # 💬 AI提示词
|
||||
└── calc_perf.sh # 🚀 性能计算脚本
|
||||
```
|
||||
|
||||
### 🔧 核心组件详解
|
||||
|
||||
#### 🎯 主程序 (`main.py`)
|
||||
- **多模型并发**: 同时运行多个AI模型进行交易
|
||||
- **配置管理**: 支持JSON配置文件和环境变量
|
||||
- **日期管理**: 灵活的交易日历和日期范围设置
|
||||
- **错误处理**: 完善的异常处理和重试机制
|
||||
|
||||
#### 🛠️ MCP工具链
|
||||
| 工具 | 功能 | API |
|
||||
|------|------|-----|
|
||||
| **交易工具** | 买入/卖出股票,持仓管理 | `buy()`, `sell()` |
|
||||
| **价格工具** | 实时和历史价格查询 | `get_price_local()` |
|
||||
| **搜索工具** | 市场信息搜索 | `get_information()` |
|
||||
| **数学工具** | 财务计算和分析 | 基础数学运算 |
|
||||
|
||||
#### 📊 数据系统
|
||||
- **📈 价格数据**: 纳斯达克100成分股的完整OHLCV数据
|
||||
- **📝 交易记录**: 每个AI模型的详细交易历史
|
||||
- **📊 性能指标**: 夏普比率、最大回撤、年化收益等
|
||||
- **🔄 数据同步**: 自动化的数据获取和更新机制
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 📋 前置要求
|
||||
|
||||
- **Python 3.10+**
|
||||
- **API密钥**: OpenAI、Alpha Vantage、Jina AI
|
||||
|
||||
|
||||
### ⚡ 一键安装
|
||||
|
||||
```bash
|
||||
# 1. 克隆项目
|
||||
git clone https://github.com/HKUDS/AI-Trader.git
|
||||
cd AI-Trader
|
||||
|
||||
# 2. 安装依赖
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 3. 配置环境变量
|
||||
cp .env.example .env
|
||||
# 编辑 .env 文件,填入你的API密钥
|
||||
```
|
||||
|
||||
### 🔑 环境配置
|
||||
|
||||
创建 `.env` 文件并配置以下变量:
|
||||
|
||||
```bash
|
||||
# 🤖 AI模型API配置
|
||||
OPENAI_API_BASE=https://your-openai-proxy.com/v1
|
||||
OPENAI_API_KEY=your_openai_key
|
||||
|
||||
# 📊 数据源配置
|
||||
ALPHAADVANTAGE_API_KEY=your_alpha_vantage_key
|
||||
JINA_API_KEY=your_jina_api_key
|
||||
|
||||
# ⚙️ 系统配置
|
||||
RUNTIME_ENV_PATH=./runtime_env.json #推荐使用绝对路径
|
||||
|
||||
# 🌐 服务端口配置
|
||||
MATH_HTTP_PORT=8000
|
||||
SEARCH_HTTP_PORT=8001
|
||||
TRADE_HTTP_PORT=8002
|
||||
GETPRICE_HTTP_PORT=8003
|
||||
# 🧠 AI代理配置
|
||||
AGENT_MAX_STEP=30 # 最大推理步数
|
||||
```
|
||||
|
||||
### 📦 依赖包
|
||||
|
||||
```bash
|
||||
# 安装生产环境依赖
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 或手动安装核心依赖
|
||||
pip install langchain langchain-openai langchain-mcp-adapters fastmcp python-dotenv requests numpy pandas
|
||||
```
|
||||
|
||||
## 🎮 运行指南
|
||||
|
||||
### 📊 步骤1: 数据准备 (`./fresh_data.sh`)
|
||||
|
||||
|
||||
```bash
|
||||
# 📈 获取纳斯达克100股票数据
|
||||
cd data
|
||||
python get_daily_price.py
|
||||
|
||||
# 🔄 合并数据为统一格式
|
||||
python merge_jsonl.py
|
||||
```
|
||||
|
||||
### 🛠️ 步骤2: 启动MCP服务
|
||||
|
||||
```bash
|
||||
cd ./agent_tools
|
||||
python start_mcp_services.py
|
||||
```
|
||||
|
||||
### 🚀 步骤3: 启动AI竞技场
|
||||
|
||||
```bash
|
||||
# 🎯 运行主程序 - 让AI们开始交易!
|
||||
python main.py
|
||||
|
||||
# 🎯 或使用自定义配置
|
||||
python main.py configs/my_config.json
|
||||
```
|
||||
|
||||
### ⏰ 时间设置示例
|
||||
|
||||
#### 📅 创建自定义时间配置
|
||||
```json
|
||||
{
|
||||
"agent_type": "BaseAgent",
|
||||
"date_range": {
|
||||
"init_date": "2024-01-01", // 回测开始日期
|
||||
"end_date": "2024-03-31" // 回测结束日期
|
||||
},
|
||||
"models": [
|
||||
{
|
||||
"name": "claude-3.7-sonnet",
|
||||
"basemodel": "anthropic/claude-3.7-sonnet",
|
||||
"signature": "claude-3.7-sonnet",
|
||||
"enabled": true
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 📈 启动Web界面
|
||||
|
||||
```bash
|
||||
cd docs
|
||||
python3 -m http.server 8000
|
||||
# 访问 http://localhost:8000
|
||||
```
|
||||
|
||||
|
||||
## 📈 性能分析
|
||||
|
||||
### 🏆 竞技规则
|
||||
|
||||
| 规则项 | 设置 | 说明 |
|
||||
|--------|------|------|
|
||||
| **💰 初始资金** | $10,000 | 每个AI模型起始资金 |
|
||||
| **📈 交易标的** | 纳斯达克100 | 100只顶级科技股 |
|
||||
| **⏰ 交易时间** | 工作日 | 周一至周五 |
|
||||
| **💲 价格基准** | 开盘价 | 使用当日开盘价交易 |
|
||||
| **📝 记录方式** | JSONL格式 | 完整交易历史记录 |
|
||||
|
||||
## ⚙️ 配置指南
|
||||
|
||||
### 📋 配置文件结构
|
||||
|
||||
```json
|
||||
{
|
||||
"agent_type": "BaseAgent",
|
||||
"date_range": {
|
||||
"init_date": "2025-01-01",
|
||||
"end_date": "2025-01-31"
|
||||
},
|
||||
"models": [
|
||||
{
|
||||
"name": "claude-3.7-sonnet",
|
||||
"basemodel": "anthropic/claude-3.7-sonnet",
|
||||
"signature": "claude-3.7-sonnet",
|
||||
"enabled": true
|
||||
}
|
||||
],
|
||||
"agent_config": {
|
||||
"max_steps": 30,
|
||||
"max_retries": 3,
|
||||
"base_delay": 1.0,
|
||||
"initial_cash": 10000.0
|
||||
},
|
||||
"log_config": {
|
||||
"log_path": "./data/agent_data"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 🔧 配置参数说明
|
||||
|
||||
| 参数 | 说明 | 默认值 |
|
||||
|------|------|--------|
|
||||
| `agent_type` | AI代理类型 | "BaseAgent" |
|
||||
| `max_steps` | 最大推理步数 | 30 |
|
||||
| `max_retries` | 最大重试次数 | 3 |
|
||||
| `base_delay` | 操作延迟(秒) | 1.0 |
|
||||
| `initial_cash` | 初始资金 | $10,000 |
|
||||
|
||||
### 📊 数据格式
|
||||
|
||||
#### 💰 持仓记录 (position.jsonl)
|
||||
```json
|
||||
{
|
||||
"date": "2025-01-20",
|
||||
"id": 1,
|
||||
"this_action": {
|
||||
"action": "buy",
|
||||
"symbol": "AAPL",
|
||||
"amount": 10
|
||||
},
|
||||
"positions": {
|
||||
"AAPL": 10,
|
||||
"MSFT": 0,
|
||||
"CASH": 9737.6
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### 📈 价格数据 (merged.jsonl)
|
||||
```json
|
||||
{
|
||||
"Meta Data": {
|
||||
"2. Symbol": "AAPL",
|
||||
"3. Last Refreshed": "2025-01-20"
|
||||
},
|
||||
"Time Series (Daily)": {
|
||||
"2025-01-20": {
|
||||
"1. buy price": "255.8850",
|
||||
"2. high": "264.3750",
|
||||
"3. low": "255.6300",
|
||||
"4. sell price": "262.2400",
|
||||
"5. volume": "90483029"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 📁 文件结构
|
||||
|
||||
```
|
||||
data/agent_data/
|
||||
├── claude-3.7-sonnet/
|
||||
│ ├── position/
|
||||
│ │ └── position.jsonl # 📝 持仓记录
|
||||
│ └── log/
|
||||
│ └── 2025-01-20/
|
||||
│ └── log.jsonl # 📊 交易日志
|
||||
├── gpt-4o/
|
||||
│ └── ...
|
||||
└── qwen3-max/
|
||||
└── ...
|
||||
```
|
||||
|
||||
## 🔌 第三方策略集成
|
||||
|
||||
AI-Trader Bench采用模块化设计,支持轻松集成第三方策略和自定义AI代理。
|
||||
|
||||
### 🛠️ 集成方式
|
||||
|
||||
#### 1. 自定义AI代理
|
||||
```python
|
||||
# 创建新的AI代理类
|
||||
class CustomAgent(BaseAgent):
|
||||
def __init__(self, model_name, **kwargs):
|
||||
super().__init__(model_name, **kwargs)
|
||||
# 添加自定义逻辑
|
||||
```
|
||||
|
||||
#### 2. 注册新代理
|
||||
```python
|
||||
# 在 main.py 中注册
|
||||
AGENT_REGISTRY = {
|
||||
"BaseAgent": {
|
||||
"module": "agent.base_agent.base_agent",
|
||||
"class": "BaseAgent"
|
||||
},
|
||||
"CustomAgent": { # 新增
|
||||
"module": "agent.custom.custom_agent",
|
||||
"class": "CustomAgent"
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
#### 3. 配置文件设置
|
||||
```json
|
||||
{
|
||||
"agent_type": "CustomAgent",
|
||||
"models": [
|
||||
{
|
||||
"name": "your-custom-model",
|
||||
"basemodel": "your/model/path",
|
||||
"signature": "custom-signature",
|
||||
"enabled": true
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 🔧 扩展工具链
|
||||
|
||||
#### 添加自定义工具
|
||||
```python
|
||||
# 创建新的MCP工具
|
||||
@mcp.tools()
|
||||
class CustomTool:
|
||||
def __init__(self):
|
||||
self.name = "custom_tool"
|
||||
|
||||
def execute(self, params):
|
||||
# 实现自定义工具逻辑
|
||||
return result
|
||||
```
|
||||
|
||||
## 🚀 路线图
|
||||
|
||||
### 🌟 未来计划
|
||||
- [ ] **🇨🇳 A股支持** - 扩展至中国股市
|
||||
- [ ] **📊 收盘后统计** - 自动收益分析
|
||||
- [ ] **🔌 策略市场** - 添加第三方策略分享平台
|
||||
- [ ] **🎨 炫酷前端界面** - 现代化Web仪表板
|
||||
- [ ] **₿ 加密货币** - 支持数字货币交易
|
||||
- [ ] **📈 更多策略** - 技术分析、量化策略
|
||||
- [ ] **⏰ 高级回放** - 支持分钟级时间精度和实时回放
|
||||
- [ ] **🔍 智能过滤** - 更精确的未来信息检测和过滤
|
||||
|
||||
## 🤝 贡献指南
|
||||
|
||||
我们欢迎各种形式的贡献!特别是AI交易策略和代理实现。
|
||||
|
||||
### 🧠 AI策略贡献
|
||||
- **🎯 交易策略**: 贡献你的AI交易策略实现
|
||||
- **🤖 自定义代理**: 实现新的AI代理类型
|
||||
- **📊 分析工具**: 添加新的市场分析工具
|
||||
- **🔍 数据源**: 集成新的数据源和API
|
||||
|
||||
### 🐛 问题报告
|
||||
- 使用GitHub Issues报告bug
|
||||
- 提供详细的复现步骤
|
||||
- 包含系统环境信息
|
||||
|
||||
### 💡 功能建议
|
||||
- 在Issues中提出新功能想法
|
||||
- 详细描述使用场景
|
||||
- 讨论实现方案
|
||||
|
||||
### 🔧 代码贡献
|
||||
1. Fork项目
|
||||
2. 创建功能分支
|
||||
3. 实现你的策略或功能
|
||||
4. 添加测试用例
|
||||
5. 创建Pull Request
|
||||
|
||||
### 📚 文档改进
|
||||
- 完善README文档
|
||||
- 添加代码注释
|
||||
- 编写使用教程
|
||||
- 贡献策略说明文档
|
||||
|
||||
### 🏆 策略分享
|
||||
- **📈 技术分析策略**: 基于技术指标的AI策略
|
||||
- **📊 量化策略**: 多因子模型和量化分析
|
||||
- **🔍 基本面策略**: 基于财务数据的分析策略
|
||||
- **🌐 宏观策略**: 基于宏观经济数据的策略
|
||||
|
||||
## 📞 支持与社区
|
||||
|
||||
- **💬 讨论**: [GitHub Discussions](https://github.com/HKUDS/AI-Trader/discussions)
|
||||
- **🐛 问题**: [GitHub Issues](https://github.com/HKUDS/AI-Trader/issues)
|
||||
|
||||
## 📄 许可证
|
||||
|
||||
本项目采用 [MIT License](LICENSE) 开源协议。
|
||||
|
||||
## 🙏 致谢
|
||||
|
||||
感谢以下开源项目和服务:
|
||||
- [LangChain](https://github.com/langchain-ai/langchain) - AI应用开发框架
|
||||
- [MCP](https://github.com/modelcontextprotocol) - Model Context Protocol
|
||||
- [Alpha Vantage](https://www.alphavantage.co/) - 金融数据API
|
||||
- [Jina AI](https://jina.ai/) - 信息搜索服务
|
||||
|
||||
## 免责声明
|
||||
|
||||
AI-Trader项目所提供的资料仅供研究之用,并不构成任何投资建议。投资者在作出任何投资决策之前,应寻求独立专业意见。任何过往表现未必可作为未来业绩的指标。阁下应注意,投资价值可能上升亦可能下跌,且并无任何保证。AI-Trader项目的所有内容仅作研究之用,并不构成对所提及之证券/行业的任何投资推荐。投资涉及风险。如有需要,请寻求专业咨询。
|
||||
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
|
||||
**🌟 如果这个项目对你有帮助,请给我们一个Star!**
|
||||
|
||||
[](https://github.com/HKUDS/AI-Trader)
|
||||
[](https://github.com/HKUDS/AI-Trader)
|
||||
|
||||
**🤖 让AI在金融市场中完全自主决策、一展身手!**
|
||||
**🛠️ 纯工具驱动,零人工干预,真正的AI交易竞技场!** 🚀
|
||||
|
||||
</div>
|
||||
88
ROADMAP.md
Normal file
88
ROADMAP.md
Normal file
@@ -0,0 +1,88 @@
|
||||
# AI-Trader Roadmap
|
||||
|
||||
This document outlines planned features and improvements for the AI-Trader project.
|
||||
|
||||
## Release Planning
|
||||
|
||||
### v0.4.0 - Enhanced Simulation Management (Planned)
|
||||
|
||||
**Focus:** Improved simulation control, resume capabilities, and performance analysis
|
||||
|
||||
#### Simulation Resume & Continuation
|
||||
- **Resume from Last Completed Date** - API to continue simulations without re-running completed dates
|
||||
- `POST /simulate/resume` - Resume last incomplete job or start from last completed date
|
||||
- `POST /simulate/continue` - Extend existing simulation with new date range
|
||||
- Query parameters to specify which model(s) to continue
|
||||
- Automatic detection of last completed date per model
|
||||
- Validation to prevent overlapping simulations
|
||||
- Support for extending date ranges forward in time
|
||||
- Use cases:
|
||||
- Daily simulation updates (add today's date to existing run)
|
||||
- Recovering from failed jobs (resume from interruption point)
|
||||
- Incremental backtesting (extend historical analysis)
|
||||
|
||||
#### Position History & Analysis
|
||||
- **Position History Tracking** - Track position changes over time
|
||||
- Query endpoint: `GET /positions/history?model=<name>&start_date=<date>&end_date=<date>`
|
||||
- Timeline view of all trades and position changes
|
||||
- Calculate holding periods and turnover rates
|
||||
- Support for position snapshots at specific dates
|
||||
|
||||
#### Performance Metrics
|
||||
- **Advanced Performance Analytics** - Calculate standard trading metrics
|
||||
- Sharpe ratio, Sortino ratio, maximum drawdown
|
||||
- Win rate, average win/loss, profit factor
|
||||
- Volatility and beta calculations
|
||||
- Risk-adjusted returns
|
||||
- Comparison across models
|
||||
|
||||
#### Data Management
|
||||
- **Price Data Management API** - Endpoints for price data operations
|
||||
- `GET /data/coverage` - Check date ranges available per symbol
|
||||
- `POST /data/download` - Trigger manual price data downloads
|
||||
- `GET /data/status` - Check download progress and rate limits
|
||||
- `DELETE /data/range` - Remove price data for specific date ranges
|
||||
|
||||
#### Web UI
|
||||
- **Dashboard Interface** - Web-based monitoring and control interface
|
||||
- Job management dashboard
|
||||
- View active, pending, and completed jobs
|
||||
- Start new simulations with form-based configuration
|
||||
- Monitor job progress in real-time
|
||||
- Cancel running jobs
|
||||
- Results visualization
|
||||
- Performance charts (P&L over time, cumulative returns)
|
||||
- Position history timeline
|
||||
- Model comparison views
|
||||
- Trade log explorer with filtering
|
||||
- Configuration management
|
||||
- Model configuration editor
|
||||
- Date range selection with calendar picker
|
||||
- Price data coverage visualization
|
||||
- Technical implementation
|
||||
- Modern frontend framework (React, Vue.js, or Svelte)
|
||||
- Real-time updates via WebSocket or SSE
|
||||
- Responsive design for mobile access
|
||||
- Chart library (Plotly.js, Chart.js, or Recharts)
|
||||
- Served alongside API (single container deployment)
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome contributions to any of these planned features! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.
|
||||
|
||||
To propose a new feature:
|
||||
1. Open an issue with the `feature-request` label
|
||||
2. Describe the use case and expected behavior
|
||||
3. Discuss implementation approach with maintainers
|
||||
4. Submit a PR with tests and documentation
|
||||
|
||||
## Version History
|
||||
|
||||
- **v0.1.0** - Initial release with batch execution
|
||||
- **v0.2.0** - Docker deployment support
|
||||
- **v0.3.0** - REST API, on-demand downloads, database storage (current)
|
||||
- **v0.4.0** - Enhanced simulation management (planned)
|
||||
|
||||
---
|
||||
|
||||
Last updated: 2025-10-31
|
||||
112
api/database.py
112
api/database.py
@@ -50,6 +50,9 @@ def initialize_database(db_path: str = "data/jobs.db") -> None:
|
||||
4. holdings - Portfolio holdings per position
|
||||
5. reasoning_logs - AI decision logs (optional, for detail=full)
|
||||
6. tool_usage - Tool usage statistics
|
||||
7. price_data - Historical OHLCV price data (replaces merged.jsonl)
|
||||
8. price_data_coverage - Downloaded date range tracking per symbol
|
||||
9. simulation_runs - Simulation run tracking for soft delete
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database file
|
||||
@@ -108,8 +111,10 @@ def initialize_database(db_path: str = "data/jobs.db") -> None:
|
||||
daily_return_pct REAL,
|
||||
cumulative_profit REAL,
|
||||
cumulative_return_pct REAL,
|
||||
simulation_run_id TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
FOREIGN KEY (job_id) REFERENCES jobs(job_id) ON DELETE CASCADE
|
||||
FOREIGN KEY (job_id) REFERENCES jobs(job_id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (simulation_run_id) REFERENCES simulation_runs(run_id) ON DELETE SET NULL
|
||||
)
|
||||
""")
|
||||
|
||||
@@ -154,6 +159,53 @@ def initialize_database(db_path: str = "data/jobs.db") -> None:
|
||||
)
|
||||
""")
|
||||
|
||||
# Table 7: Price Data - OHLCV price data (replaces merged.jsonl)
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS price_data (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
symbol TEXT NOT NULL,
|
||||
date TEXT NOT NULL,
|
||||
open REAL NOT NULL,
|
||||
high REAL NOT NULL,
|
||||
low REAL NOT NULL,
|
||||
close REAL NOT NULL,
|
||||
volume INTEGER NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
UNIQUE(symbol, date)
|
||||
)
|
||||
""")
|
||||
|
||||
# Table 8: Price Data Coverage - Track downloaded date ranges per symbol
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS price_data_coverage (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
symbol TEXT NOT NULL,
|
||||
start_date TEXT NOT NULL,
|
||||
end_date TEXT NOT NULL,
|
||||
downloaded_at TEXT NOT NULL,
|
||||
source TEXT DEFAULT 'alpha_vantage',
|
||||
UNIQUE(symbol, start_date, end_date)
|
||||
)
|
||||
""")
|
||||
|
||||
# Table 9: Simulation Runs - Track simulation runs for soft delete
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS simulation_runs (
|
||||
run_id TEXT PRIMARY KEY,
|
||||
job_id TEXT NOT NULL,
|
||||
model TEXT NOT NULL,
|
||||
start_date TEXT NOT NULL,
|
||||
end_date TEXT NOT NULL,
|
||||
status TEXT NOT NULL CHECK(status IN ('active', 'superseded')),
|
||||
created_at TEXT NOT NULL,
|
||||
superseded_at TEXT,
|
||||
FOREIGN KEY (job_id) REFERENCES jobs(job_id) ON DELETE CASCADE
|
||||
)
|
||||
""")
|
||||
|
||||
# Run schema migrations for existing databases
|
||||
_migrate_schema(cursor)
|
||||
|
||||
# Create indexes for performance
|
||||
_create_indexes(cursor)
|
||||
|
||||
@@ -161,6 +213,21 @@ def initialize_database(db_path: str = "data/jobs.db") -> None:
|
||||
conn.close()
|
||||
|
||||
|
||||
def _migrate_schema(cursor: sqlite3.Cursor) -> None:
|
||||
"""Migrate existing database schema to latest version."""
|
||||
# Check if positions table exists and has simulation_run_id column
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='positions'")
|
||||
if cursor.fetchone():
|
||||
cursor.execute("PRAGMA table_info(positions)")
|
||||
columns = [row[1] for row in cursor.fetchall()]
|
||||
|
||||
if 'simulation_run_id' not in columns:
|
||||
# Add simulation_run_id column to existing positions table
|
||||
cursor.execute("""
|
||||
ALTER TABLE positions ADD COLUMN simulation_run_id TEXT
|
||||
""")
|
||||
|
||||
|
||||
def _create_indexes(cursor: sqlite3.Cursor) -> None:
|
||||
"""Create database indexes for query performance."""
|
||||
|
||||
@@ -222,6 +289,41 @@ def _create_indexes(cursor: sqlite3.Cursor) -> None:
|
||||
ON tool_usage(job_id, date, model)
|
||||
""")
|
||||
|
||||
# Price data table indexes
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_price_data_symbol_date ON price_data(symbol, date)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_price_data_date ON price_data(date)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_price_data_symbol ON price_data(symbol)
|
||||
""")
|
||||
|
||||
# Price data coverage table indexes
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_coverage_symbol ON price_data_coverage(symbol)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_coverage_dates ON price_data_coverage(start_date, end_date)
|
||||
""")
|
||||
|
||||
# Simulation runs table indexes
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_runs_job_model ON simulation_runs(job_id, model)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_runs_status ON simulation_runs(status)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_runs_dates ON simulation_runs(start_date, end_date)
|
||||
""")
|
||||
|
||||
# Positions table - add index for simulation_run_id
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_positions_run_id ON positions(simulation_run_id)
|
||||
""")
|
||||
|
||||
|
||||
def drop_all_tables(db_path: str = "data/jobs.db") -> None:
|
||||
"""
|
||||
@@ -240,8 +342,11 @@ def drop_all_tables(db_path: str = "data/jobs.db") -> None:
|
||||
'reasoning_logs',
|
||||
'holdings',
|
||||
'positions',
|
||||
'simulation_runs',
|
||||
'job_details',
|
||||
'jobs'
|
||||
'jobs',
|
||||
'price_data_coverage',
|
||||
'price_data'
|
||||
]
|
||||
|
||||
for table in tables:
|
||||
@@ -296,7 +401,8 @@ def get_database_stats(db_path: str = "data/jobs.db") -> dict:
|
||||
stats["database_size_mb"] = 0
|
||||
|
||||
# Get row counts for each table
|
||||
tables = ['jobs', 'job_details', 'positions', 'holdings', 'reasoning_logs', 'tool_usage']
|
||||
tables = ['jobs', 'job_details', 'positions', 'holdings', 'reasoning_logs', 'tool_usage',
|
||||
'price_data', 'price_data_coverage', 'simulation_runs']
|
||||
|
||||
for table in tables:
|
||||
cursor.execute(f"SELECT COUNT(*) FROM {table}")
|
||||
|
||||
93
api/date_utils.py
Normal file
93
api/date_utils.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
Date range utilities for simulation date management.
|
||||
|
||||
This module provides:
|
||||
- Date range expansion
|
||||
- Date range validation
|
||||
- Trading day detection
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List
|
||||
|
||||
|
||||
def expand_date_range(start_date: str, end_date: str) -> List[str]:
|
||||
"""
|
||||
Expand date range into list of all dates (inclusive).
|
||||
|
||||
Args:
|
||||
start_date: Start date (YYYY-MM-DD)
|
||||
end_date: End date (YYYY-MM-DD)
|
||||
|
||||
Returns:
|
||||
Sorted list of dates in range
|
||||
|
||||
Raises:
|
||||
ValueError: If dates are invalid or start > end
|
||||
"""
|
||||
start = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
end = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
|
||||
if start > end:
|
||||
raise ValueError(f"start_date ({start_date}) must be <= end_date ({end_date})")
|
||||
|
||||
dates = []
|
||||
current = start
|
||||
|
||||
while current <= end:
|
||||
dates.append(current.strftime("%Y-%m-%d"))
|
||||
current += timedelta(days=1)
|
||||
|
||||
return dates
|
||||
|
||||
|
||||
def validate_date_range(
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
max_days: int = 30
|
||||
) -> None:
|
||||
"""
|
||||
Validate date range for simulation.
|
||||
|
||||
Args:
|
||||
start_date: Start date (YYYY-MM-DD)
|
||||
end_date: End date (YYYY-MM-DD)
|
||||
max_days: Maximum allowed days in range
|
||||
|
||||
Raises:
|
||||
ValueError: If validation fails
|
||||
"""
|
||||
# Parse dates
|
||||
try:
|
||||
start = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
end = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid date format: {e}")
|
||||
|
||||
# Check order
|
||||
if start > end:
|
||||
raise ValueError(f"start_date ({start_date}) must be <= end_date ({end_date})")
|
||||
|
||||
# Check range size
|
||||
days = (end - start).days + 1
|
||||
if days > max_days:
|
||||
raise ValueError(
|
||||
f"Date range too large: {days} days (max: {max_days}). "
|
||||
f"Reduce range or increase MAX_SIMULATION_DAYS."
|
||||
)
|
||||
|
||||
# Check not in future
|
||||
today = datetime.now().date()
|
||||
if end.date() > today:
|
||||
raise ValueError(f"end_date ({end_date}) cannot be in the future")
|
||||
|
||||
|
||||
def get_max_simulation_days() -> int:
|
||||
"""
|
||||
Get maximum simulation days from environment.
|
||||
|
||||
Returns:
|
||||
Maximum days allowed in simulation range
|
||||
"""
|
||||
return int(os.getenv("MAX_SIMULATION_DAYS", "30"))
|
||||
173
api/main.py
173
api/main.py
@@ -9,6 +9,7 @@ Provides endpoints for:
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -20,6 +21,8 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from api.job_manager import JobManager
|
||||
from api.simulation_worker import SimulationWorker
|
||||
from api.database import get_db_connection
|
||||
from api.price_data_manager import PriceDataManager
|
||||
from api.date_utils import validate_date_range, expand_date_range, get_max_simulation_days
|
||||
import threading
|
||||
import time
|
||||
|
||||
@@ -29,21 +32,29 @@ logger = logging.getLogger(__name__)
|
||||
# Pydantic models for request/response validation
|
||||
class SimulateTriggerRequest(BaseModel):
|
||||
"""Request body for POST /simulate/trigger."""
|
||||
config_path: str = Field(..., description="Path to configuration file")
|
||||
date_range: List[str] = Field(..., min_length=1, description="List of trading dates (YYYY-MM-DD)")
|
||||
models: List[str] = Field(..., min_length=1, description="List of model signatures to simulate")
|
||||
start_date: str = Field(..., description="Start date for simulation (YYYY-MM-DD)")
|
||||
end_date: Optional[str] = Field(None, description="End date for simulation (YYYY-MM-DD). If not provided, simulates single day.")
|
||||
models: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="Optional: List of model signatures to simulate. If not provided, uses enabled models from config."
|
||||
)
|
||||
|
||||
@field_validator("date_range")
|
||||
@field_validator("start_date", "end_date")
|
||||
@classmethod
|
||||
def validate_date_range(cls, v):
|
||||
def validate_date_format(cls, v):
|
||||
"""Validate date format."""
|
||||
for date in v:
|
||||
try:
|
||||
datetime.strptime(date, "%Y-%m-%d")
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid date format: {date}. Expected YYYY-MM-DD")
|
||||
if v is None:
|
||||
return v
|
||||
try:
|
||||
datetime.strptime(v, "%Y-%m-%d")
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid date format: {v}. Expected YYYY-MM-DD")
|
||||
return v
|
||||
|
||||
def get_end_date(self) -> str:
|
||||
"""Get end date, defaulting to start_date if not provided."""
|
||||
return self.end_date or self.start_date
|
||||
|
||||
|
||||
class SimulateTriggerResponse(BaseModel):
|
||||
"""Response body for POST /simulate/trigger."""
|
||||
@@ -83,12 +94,16 @@ class HealthResponse(BaseModel):
|
||||
timestamp: str
|
||||
|
||||
|
||||
def create_app(db_path: str = "data/jobs.db") -> FastAPI:
|
||||
def create_app(
|
||||
db_path: str = "data/jobs.db",
|
||||
config_path: str = "configs/default_config.json"
|
||||
) -> FastAPI:
|
||||
"""
|
||||
Create FastAPI application instance.
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database
|
||||
config_path: Path to default configuration file
|
||||
|
||||
Returns:
|
||||
Configured FastAPI app
|
||||
@@ -99,27 +114,121 @@ def create_app(db_path: str = "data/jobs.db") -> FastAPI:
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Store db_path in app state
|
||||
# Store paths in app state
|
||||
app.state.db_path = db_path
|
||||
app.state.config_path = config_path
|
||||
|
||||
@app.post("/simulate/trigger", response_model=SimulateTriggerResponse, status_code=200)
|
||||
async def trigger_simulation(request: SimulateTriggerRequest):
|
||||
"""
|
||||
Trigger a new simulation job.
|
||||
|
||||
Creates a job with specified config, dates, and models.
|
||||
Job runs asynchronously in background thread.
|
||||
Validates date range, downloads missing price data if needed,
|
||||
and creates job with available trading dates.
|
||||
|
||||
Raises:
|
||||
HTTPException 400: If another job is already running or config invalid
|
||||
HTTPException 422: If request validation fails
|
||||
HTTPException 400: Validation errors, running job, or invalid dates
|
||||
HTTPException 503: Price data download failed
|
||||
"""
|
||||
try:
|
||||
# Use config path from app state
|
||||
config_path = app.state.config_path
|
||||
|
||||
# Validate config path exists
|
||||
if not Path(request.config_path).exists():
|
||||
if not Path(config_path).exists():
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Server configuration file not found: {config_path}"
|
||||
)
|
||||
|
||||
# Get end date (defaults to start_date for single day)
|
||||
end_date = request.get_end_date()
|
||||
|
||||
# Validate date range
|
||||
max_days = get_max_simulation_days()
|
||||
validate_date_range(request.start_date, end_date, max_days=max_days)
|
||||
|
||||
# Determine which models to run
|
||||
import json
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
|
||||
if request.models is not None:
|
||||
# Use models from request (explicit override)
|
||||
models_to_run = request.models
|
||||
else:
|
||||
# Use enabled models from config
|
||||
models_to_run = [
|
||||
model["signature"]
|
||||
for model in config.get("models", [])
|
||||
if model.get("enabled", False)
|
||||
]
|
||||
|
||||
if not models_to_run:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No enabled models found in config. Either enable models in config or specify them in request."
|
||||
)
|
||||
|
||||
# Check price data and download if needed
|
||||
auto_download = os.getenv("AUTO_DOWNLOAD_PRICE_DATA", "true").lower() == "true"
|
||||
price_manager = PriceDataManager(db_path=app.state.db_path)
|
||||
|
||||
# Check what's missing
|
||||
missing_coverage = price_manager.get_missing_coverage(
|
||||
request.start_date,
|
||||
end_date
|
||||
)
|
||||
|
||||
download_info = None
|
||||
|
||||
# Download missing data if enabled
|
||||
if any(missing_coverage.values()):
|
||||
if not auto_download:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Missing price data for {len(missing_coverage)} symbols and auto-download is disabled. "
|
||||
f"Enable AUTO_DOWNLOAD_PRICE_DATA or pre-populate data."
|
||||
)
|
||||
|
||||
logger.info(f"Downloading missing price data for {len(missing_coverage)} symbols")
|
||||
|
||||
requested_dates = set(expand_date_range(request.start_date, end_date))
|
||||
|
||||
download_result = price_manager.download_missing_data_prioritized(
|
||||
missing_coverage,
|
||||
requested_dates
|
||||
)
|
||||
|
||||
if not download_result["success"]:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Failed to download any price data. Check ALPHAADVANTAGE_API_KEY."
|
||||
)
|
||||
|
||||
download_info = {
|
||||
"symbols_downloaded": len(download_result["downloaded"]),
|
||||
"symbols_failed": len(download_result["failed"]),
|
||||
"rate_limited": download_result["rate_limited"]
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Downloaded {len(download_result['downloaded'])} symbols, "
|
||||
f"{len(download_result['failed'])} failed, rate_limited={download_result['rate_limited']}"
|
||||
)
|
||||
|
||||
# Get available trading dates (after potential download)
|
||||
available_dates = price_manager.get_available_trading_dates(
|
||||
request.start_date,
|
||||
end_date
|
||||
)
|
||||
|
||||
if not available_dates:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Config path does not exist: {request.config_path}"
|
||||
detail=f"No trading dates with complete price data in range "
|
||||
f"{request.start_date} to {end_date}. "
|
||||
f"All symbols must have data for a date to be tradeable."
|
||||
)
|
||||
|
||||
job_manager = JobManager(db_path=app.state.db_path)
|
||||
@@ -131,11 +240,11 @@ def create_app(db_path: str = "data/jobs.db") -> FastAPI:
|
||||
detail="Another simulation job is already running or pending. Please wait for it to complete."
|
||||
)
|
||||
|
||||
# Create job
|
||||
# Create job with available dates
|
||||
job_id = job_manager.create_job(
|
||||
config_path=request.config_path,
|
||||
date_range=request.date_range,
|
||||
models=request.models
|
||||
config_path=config_path,
|
||||
date_range=available_dates,
|
||||
models=models_to_run
|
||||
)
|
||||
|
||||
# Start worker in background thread (only if not in test mode)
|
||||
@@ -147,15 +256,27 @@ def create_app(db_path: str = "data/jobs.db") -> FastAPI:
|
||||
thread = threading.Thread(target=run_worker, daemon=True)
|
||||
thread.start()
|
||||
|
||||
logger.info(f"Triggered simulation job {job_id}")
|
||||
logger.info(f"Triggered simulation job {job_id} with {len(available_dates)} dates")
|
||||
|
||||
return SimulateTriggerResponse(
|
||||
# Build response message
|
||||
message = f"Simulation job created with {len(available_dates)} trading dates"
|
||||
if download_info and download_info["rate_limited"]:
|
||||
message += " (rate limit reached - partial data)"
|
||||
|
||||
response = SimulateTriggerResponse(
|
||||
job_id=job_id,
|
||||
status="pending",
|
||||
total_model_days=len(request.date_range) * len(request.models),
|
||||
message=f"Simulation job {job_id} created and started"
|
||||
total_model_days=len(available_dates) * len(models_to_run),
|
||||
message=message
|
||||
)
|
||||
|
||||
# Add download info if we downloaded
|
||||
if download_info:
|
||||
# Note: Need to add download_info field to response model
|
||||
logger.info(f"Download info: {download_info}")
|
||||
|
||||
return response
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError as e:
|
||||
|
||||
546
api/price_data_manager.py
Normal file
546
api/price_data_manager.py
Normal file
@@ -0,0 +1,546 @@
|
||||
"""
|
||||
Price data management for on-demand downloads and coverage tracking.
|
||||
|
||||
This module provides:
|
||||
- Coverage gap detection
|
||||
- Priority-based download ordering
|
||||
- Rate limit handling with retry logic
|
||||
- Price data storage and retrieval
|
||||
"""
|
||||
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Set, Tuple, Optional, Callable, Any
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
|
||||
from api.database import get_db_connection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimitError(Exception):
|
||||
"""Raised when API rate limit is hit."""
|
||||
pass
|
||||
|
||||
|
||||
class DownloadError(Exception):
|
||||
"""Raised when download fails for non-rate-limit reasons."""
|
||||
pass
|
||||
|
||||
|
||||
class PriceDataManager:
|
||||
"""
|
||||
Manages price data availability, downloads, and coverage tracking.
|
||||
|
||||
Responsibilities:
|
||||
- Check which dates/symbols have price data
|
||||
- Download missing data from Alpha Vantage
|
||||
- Track downloaded date ranges per symbol
|
||||
- Prioritize downloads to maximize date completion
|
||||
- Handle rate limiting gracefully
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_path: str = "data/jobs.db",
|
||||
symbols_config: str = "configs/nasdaq100_symbols.json",
|
||||
api_key: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Initialize PriceDataManager.
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database
|
||||
symbols_config: Path to NASDAQ 100 symbols configuration
|
||||
api_key: Alpha Vantage API key (defaults to env var)
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self.symbols_config = symbols_config
|
||||
self.api_key = api_key or os.getenv("ALPHAADVANTAGE_API_KEY")
|
||||
|
||||
# Load symbols list
|
||||
self.symbols = self._load_symbols()
|
||||
|
||||
logger.info(f"Initialized PriceDataManager with {len(self.symbols)} symbols")
|
||||
|
||||
def _load_symbols(self) -> List[str]:
|
||||
"""Load NASDAQ 100 symbols from config file."""
|
||||
config_path = Path(self.symbols_config)
|
||||
|
||||
if not config_path.exists():
|
||||
logger.warning(f"Symbols config not found: {config_path}. Using default list.")
|
||||
# Fallback to a minimal list
|
||||
return ["AAPL", "MSFT", "GOOGL", "AMZN", "NVDA"]
|
||||
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
|
||||
return config.get("symbols", [])
|
||||
|
||||
def get_available_dates(self) -> Set[str]:
|
||||
"""
|
||||
Get all dates that have price data in database.
|
||||
|
||||
Returns:
|
||||
Set of dates (YYYY-MM-DD) with data
|
||||
"""
|
||||
conn = get_db_connection(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("SELECT DISTINCT date FROM price_data ORDER BY date")
|
||||
dates = {row[0] for row in cursor.fetchall()}
|
||||
|
||||
conn.close()
|
||||
|
||||
return dates
|
||||
|
||||
def get_symbol_dates(self, symbol: str) -> Set[str]:
|
||||
"""
|
||||
Get all dates that have data for a specific symbol.
|
||||
|
||||
Args:
|
||||
symbol: Stock symbol
|
||||
|
||||
Returns:
|
||||
Set of dates with data for this symbol
|
||||
"""
|
||||
conn = get_db_connection(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"SELECT date FROM price_data WHERE symbol = ? ORDER BY date",
|
||||
(symbol,)
|
||||
)
|
||||
dates = {row[0] for row in cursor.fetchall()}
|
||||
|
||||
conn.close()
|
||||
|
||||
return dates
|
||||
|
||||
def get_missing_coverage(
|
||||
self,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
) -> Dict[str, Set[str]]:
|
||||
"""
|
||||
Identify which symbols are missing data for which dates in range.
|
||||
|
||||
Args:
|
||||
start_date: Start date (YYYY-MM-DD)
|
||||
end_date: End date (YYYY-MM-DD)
|
||||
|
||||
Returns:
|
||||
Dict mapping symbol to set of missing dates
|
||||
Example: {"AAPL": {"2025-01-20", "2025-01-21"}, "MSFT": set()}
|
||||
"""
|
||||
# Generate all dates in range
|
||||
requested_dates = self._expand_date_range(start_date, end_date)
|
||||
|
||||
missing = {}
|
||||
|
||||
for symbol in self.symbols:
|
||||
symbol_dates = self.get_symbol_dates(symbol)
|
||||
missing_dates = requested_dates - symbol_dates
|
||||
|
||||
if missing_dates:
|
||||
missing[symbol] = missing_dates
|
||||
|
||||
return missing
|
||||
|
||||
def _expand_date_range(self, start_date: str, end_date: str) -> Set[str]:
|
||||
"""
|
||||
Expand date range into set of all dates.
|
||||
|
||||
Args:
|
||||
start_date: Start date (YYYY-MM-DD)
|
||||
end_date: End date (YYYY-MM-DD)
|
||||
|
||||
Returns:
|
||||
Set of all dates in range (inclusive)
|
||||
"""
|
||||
start = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
end = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
|
||||
dates = set()
|
||||
current = start
|
||||
|
||||
while current <= end:
|
||||
dates.add(current.strftime("%Y-%m-%d"))
|
||||
current += timedelta(days=1)
|
||||
|
||||
return dates
|
||||
|
||||
def prioritize_downloads(
|
||||
self,
|
||||
missing_coverage: Dict[str, Set[str]],
|
||||
requested_dates: Set[str]
|
||||
) -> List[str]:
|
||||
"""
|
||||
Prioritize symbol downloads to maximize date completion.
|
||||
|
||||
Strategy: Download symbols that complete the most requested dates first.
|
||||
|
||||
Args:
|
||||
missing_coverage: Dict of symbol -> missing dates
|
||||
requested_dates: Set of dates we want to simulate
|
||||
|
||||
Returns:
|
||||
List of symbols in priority order (highest impact first)
|
||||
"""
|
||||
# Calculate impact score for each symbol
|
||||
impacts = []
|
||||
|
||||
for symbol, missing_dates in missing_coverage.items():
|
||||
# Impact = number of requested dates this symbol would complete
|
||||
impact = len(missing_dates & requested_dates)
|
||||
|
||||
if impact > 0:
|
||||
impacts.append((symbol, impact))
|
||||
|
||||
# Sort by impact (descending)
|
||||
impacts.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Return symbols in priority order
|
||||
prioritized = [symbol for symbol, _ in impacts]
|
||||
|
||||
logger.info(f"Prioritized {len(prioritized)} symbols for download")
|
||||
if prioritized:
|
||||
logger.debug(f"Top 5 symbols: {prioritized[:5]}")
|
||||
|
||||
return prioritized
|
||||
|
||||
def download_missing_data_prioritized(
|
||||
self,
|
||||
missing_coverage: Dict[str, Set[str]],
|
||||
requested_dates: Set[str],
|
||||
progress_callback: Optional[Callable] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Download data in priority order until rate limited.
|
||||
|
||||
Args:
|
||||
missing_coverage: Dict of symbol -> missing dates
|
||||
requested_dates: Set of dates being requested
|
||||
progress_callback: Optional callback for progress updates
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": True/False,
|
||||
"downloaded": ["AAPL", "MSFT", ...],
|
||||
"failed": ["GOOGL", ...],
|
||||
"rate_limited": True/False,
|
||||
"dates_completed": ["2025-01-20", ...],
|
||||
"partial_dates": {"2025-01-21": 75}
|
||||
}
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise ValueError("ALPHAADVANTAGE_API_KEY not configured")
|
||||
|
||||
# Prioritize downloads
|
||||
prioritized_symbols = self.prioritize_downloads(missing_coverage, requested_dates)
|
||||
|
||||
if not prioritized_symbols:
|
||||
logger.info("No downloads needed - all data available")
|
||||
return {
|
||||
"success": True,
|
||||
"downloaded": [],
|
||||
"failed": [],
|
||||
"rate_limited": False,
|
||||
"dates_completed": sorted(requested_dates),
|
||||
"partial_dates": {}
|
||||
}
|
||||
|
||||
logger.info(f"Starting priority download of {len(prioritized_symbols)} symbols")
|
||||
|
||||
downloaded = []
|
||||
failed = []
|
||||
rate_limited = False
|
||||
|
||||
# Download in priority order
|
||||
for i, symbol in enumerate(prioritized_symbols):
|
||||
try:
|
||||
# Progress callback
|
||||
if progress_callback:
|
||||
progress_callback({
|
||||
"current": i + 1,
|
||||
"total": len(prioritized_symbols),
|
||||
"symbol": symbol,
|
||||
"phase": "downloading"
|
||||
})
|
||||
|
||||
# Download symbol data
|
||||
logger.info(f"Downloading {symbol} ({i+1}/{len(prioritized_symbols)})")
|
||||
data = self._download_symbol(symbol)
|
||||
|
||||
# Store in database
|
||||
stored_dates = self._store_symbol_data(symbol, data, requested_dates)
|
||||
|
||||
# Update coverage tracking
|
||||
if stored_dates:
|
||||
self._update_coverage(symbol, min(stored_dates), max(stored_dates))
|
||||
|
||||
downloaded.append(symbol)
|
||||
logger.info(f"✓ Downloaded {symbol} - {len(stored_dates)} dates stored")
|
||||
|
||||
except RateLimitError as e:
|
||||
# Hit rate limit - stop downloading
|
||||
logger.warning(f"Rate limit hit after {len(downloaded)} downloads: {e}")
|
||||
rate_limited = True
|
||||
failed = prioritized_symbols[i:] # Rest are undownloaded
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
# Other error - log and continue
|
||||
logger.error(f"Failed to download {symbol}: {e}")
|
||||
failed.append(symbol)
|
||||
continue
|
||||
|
||||
# Analyze coverage
|
||||
coverage_analysis = self._analyze_coverage(requested_dates)
|
||||
|
||||
result = {
|
||||
"success": len(downloaded) > 0 or len(requested_dates) == len(coverage_analysis["completed_dates"]),
|
||||
"downloaded": downloaded,
|
||||
"failed": failed,
|
||||
"rate_limited": rate_limited,
|
||||
"dates_completed": coverage_analysis["completed_dates"],
|
||||
"partial_dates": coverage_analysis["partial_dates"]
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Download complete: {len(downloaded)} symbols downloaded, "
|
||||
f"{len(failed)} failed/skipped, rate_limited={rate_limited}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _download_symbol(self, symbol: str, retries: int = 3) -> Dict:
|
||||
"""
|
||||
Download full price history for a symbol.
|
||||
|
||||
Args:
|
||||
symbol: Stock symbol
|
||||
retries: Number of retry attempts for transient errors
|
||||
|
||||
Returns:
|
||||
JSON response from Alpha Vantage
|
||||
|
||||
Raises:
|
||||
RateLimitError: If rate limit is hit
|
||||
DownloadError: If download fails after retries
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise DownloadError("API key not configured")
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
response = requests.get(
|
||||
"https://www.alphavantage.co/query",
|
||||
params={
|
||||
"function": "TIME_SERIES_DAILY",
|
||||
"symbol": symbol,
|
||||
"outputsize": "full", # Get full history
|
||||
"apikey": self.api_key
|
||||
},
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
|
||||
# Check for API error messages
|
||||
if "Error Message" in data:
|
||||
raise DownloadError(f"API error: {data['Error Message']}")
|
||||
|
||||
# Check for rate limit in response body
|
||||
if "Note" in data:
|
||||
note = data["Note"]
|
||||
if "call frequency" in note.lower() or "rate limit" in note.lower():
|
||||
raise RateLimitError(note)
|
||||
# Other notes are warnings, continue
|
||||
logger.warning(f"{symbol}: {note}")
|
||||
|
||||
if "Information" in data:
|
||||
info = data["Information"]
|
||||
if "premium" in info.lower() or "limit" in info.lower():
|
||||
raise RateLimitError(info)
|
||||
|
||||
# Validate response has time series data
|
||||
if "Time Series (Daily)" not in data or "Meta Data" not in data:
|
||||
raise DownloadError(f"Invalid response format for {symbol}")
|
||||
|
||||
return data
|
||||
|
||||
elif response.status_code == 429:
|
||||
raise RateLimitError("HTTP 429: Too Many Requests")
|
||||
|
||||
elif response.status_code >= 500:
|
||||
# Server error - retry with backoff
|
||||
if attempt < retries - 1:
|
||||
wait_time = (2 ** attempt)
|
||||
logger.warning(f"Server error {response.status_code}. Retrying in {wait_time}s...")
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
raise DownloadError(f"Server error: {response.status_code}")
|
||||
|
||||
else:
|
||||
raise DownloadError(f"HTTP {response.status_code}: {response.text[:200]}")
|
||||
|
||||
except RateLimitError:
|
||||
raise # Don't retry rate limits
|
||||
except DownloadError:
|
||||
raise # Don't retry download errors
|
||||
except requests.RequestException as e:
|
||||
if attempt < retries - 1:
|
||||
logger.warning(f"Request failed: {e}. Retrying...")
|
||||
time.sleep(2)
|
||||
continue
|
||||
raise DownloadError(f"Request failed after {retries} attempts: {e}")
|
||||
|
||||
raise DownloadError(f"Failed to download {symbol} after {retries} attempts")
|
||||
|
||||
def _store_symbol_data(
|
||||
self,
|
||||
symbol: str,
|
||||
data: Dict,
|
||||
requested_dates: Set[str]
|
||||
) -> List[str]:
|
||||
"""
|
||||
Store downloaded price data in database.
|
||||
|
||||
Args:
|
||||
symbol: Stock symbol
|
||||
data: Alpha Vantage API response
|
||||
requested_dates: Only store dates in this set
|
||||
|
||||
Returns:
|
||||
List of dates actually stored
|
||||
"""
|
||||
time_series = data.get("Time Series (Daily)", {})
|
||||
|
||||
if not time_series:
|
||||
logger.warning(f"No time series data for {symbol}")
|
||||
return []
|
||||
|
||||
conn = get_db_connection(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
stored_dates = []
|
||||
created_at = datetime.utcnow().isoformat() + "Z"
|
||||
|
||||
for date, ohlcv in time_series.items():
|
||||
# Only store requested dates
|
||||
if date not in requested_dates:
|
||||
continue
|
||||
|
||||
try:
|
||||
cursor.execute("""
|
||||
INSERT OR REPLACE INTO price_data
|
||||
(symbol, date, open, high, low, close, volume, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
symbol,
|
||||
date,
|
||||
float(ohlcv.get("1. open", 0)),
|
||||
float(ohlcv.get("2. high", 0)),
|
||||
float(ohlcv.get("3. low", 0)),
|
||||
float(ohlcv.get("4. close", 0)),
|
||||
int(ohlcv.get("5. volume", 0)),
|
||||
created_at
|
||||
))
|
||||
stored_dates.append(date)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store {symbol} {date}: {e}")
|
||||
continue
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
return stored_dates
|
||||
|
||||
def _update_coverage(self, symbol: str, start_date: str, end_date: str) -> None:
|
||||
"""
|
||||
Update coverage tracking for a symbol.
|
||||
|
||||
Args:
|
||||
symbol: Stock symbol
|
||||
start_date: Start of date range downloaded
|
||||
end_date: End of date range downloaded
|
||||
"""
|
||||
conn = get_db_connection(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
downloaded_at = datetime.utcnow().isoformat() + "Z"
|
||||
|
||||
cursor.execute("""
|
||||
INSERT OR REPLACE INTO price_data_coverage
|
||||
(symbol, start_date, end_date, downloaded_at, source)
|
||||
VALUES (?, ?, ?, ?, 'alpha_vantage')
|
||||
""", (symbol, start_date, end_date, downloaded_at))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def _analyze_coverage(self, requested_dates: Set[str]) -> Dict[str, Any]:
|
||||
"""
|
||||
Analyze which requested dates have complete/partial coverage.
|
||||
|
||||
Args:
|
||||
requested_dates: Set of dates requested
|
||||
|
||||
Returns:
|
||||
{
|
||||
"completed_dates": ["2025-01-20", ...], # All symbols available
|
||||
"partial_dates": {"2025-01-21": 75, ...} # Date -> symbol count
|
||||
}
|
||||
"""
|
||||
conn = get_db_connection(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
total_symbols = len(self.symbols)
|
||||
completed_dates = []
|
||||
partial_dates = {}
|
||||
|
||||
for date in sorted(requested_dates):
|
||||
# Count symbols available for this date
|
||||
cursor.execute(
|
||||
"SELECT COUNT(DISTINCT symbol) FROM price_data WHERE date = ?",
|
||||
(date,)
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
|
||||
if count == total_symbols:
|
||||
completed_dates.append(date)
|
||||
elif count > 0:
|
||||
partial_dates[date] = count
|
||||
|
||||
conn.close()
|
||||
|
||||
return {
|
||||
"completed_dates": completed_dates,
|
||||
"partial_dates": partial_dates
|
||||
}
|
||||
|
||||
def get_available_trading_dates(
|
||||
self,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get trading dates with complete data in range.
|
||||
|
||||
Args:
|
||||
start_date: Start date (YYYY-MM-DD)
|
||||
end_date: End date (YYYY-MM-DD)
|
||||
|
||||
Returns:
|
||||
Sorted list of dates with complete data (all symbols)
|
||||
"""
|
||||
requested_dates = self._expand_date_range(start_date, end_date)
|
||||
analysis = self._analyze_coverage(requested_dates)
|
||||
|
||||
return sorted(analysis["completed_dates"])
|
||||
18
configs/nasdaq100_symbols.json
Normal file
18
configs/nasdaq100_symbols.json
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"symbols": [
|
||||
"NVDA", "MSFT", "AAPL", "GOOG", "GOOGL", "AMZN", "META", "AVGO", "TSLA",
|
||||
"NFLX", "PLTR", "COST", "ASML", "AMD", "CSCO", "AZN", "TMUS", "MU", "LIN",
|
||||
"PEP", "SHOP", "APP", "INTU", "AMAT", "LRCX", "PDD", "QCOM", "ARM", "INTC",
|
||||
"BKNG", "AMGN", "TXN", "ISRG", "GILD", "KLAC", "PANW", "ADBE", "HON",
|
||||
"CRWD", "CEG", "ADI", "ADP", "DASH", "CMCSA", "VRTX", "MELI", "SBUX",
|
||||
"CDNS", "ORLY", "SNPS", "MSTR", "MDLZ", "ABNB", "MRVL", "CTAS", "TRI",
|
||||
"MAR", "MNST", "CSX", "ADSK", "PYPL", "FTNT", "AEP", "WDAY", "REGN", "ROP",
|
||||
"NXPI", "DDOG", "AXON", "ROST", "IDXX", "EA", "PCAR", "FAST", "EXC", "TTWO",
|
||||
"XEL", "ZS", "PAYX", "WBD", "BKR", "CPRT", "CCEP", "FANG", "TEAM", "CHTR",
|
||||
"KDP", "MCHP", "GEHC", "VRSK", "CTSH", "CSGP", "KHC", "ODFL", "DXCM", "TTD",
|
||||
"ON", "BIIB", "LULU", "CDW", "GFS", "QQQ"
|
||||
],
|
||||
"description": "NASDAQ 100 constituent stocks plus QQQ ETF",
|
||||
"last_updated": "2025-10-31",
|
||||
"total_symbols": 102
|
||||
}
|
||||
@@ -23,8 +23,6 @@ services:
|
||||
ports:
|
||||
# API server port (primary interface for external access)
|
||||
- "${API_PORT:-8080}:8080"
|
||||
# Web dashboard
|
||||
- "${WEB_HTTP_PORT:-8888}:8888"
|
||||
restart: unless-stopped # Keep API server running
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
|
||||
|
||||
166
scripts/migrate_price_data.py
Executable file
166
scripts/migrate_price_data.py
Executable file
@@ -0,0 +1,166 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Migration script: Import merged.jsonl price data into SQLite database.
|
||||
|
||||
This script:
|
||||
1. Reads existing merged.jsonl file
|
||||
2. Parses OHLCV data for each symbol/date
|
||||
3. Inserts into price_data table
|
||||
4. Tracks coverage in price_data_coverage table
|
||||
|
||||
Run this once to migrate from jsonl to database.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from collections import defaultdict
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from api.database import get_db_connection, initialize_database
|
||||
|
||||
|
||||
def migrate_merged_jsonl(
    jsonl_path: str = "data/merged.jsonl",
    db_path: str = "data/jobs.db"
):
    """
    Migrate price data from merged.jsonl to SQLite database.

    Reads one Alpha Vantage-style JSON document per line, inserts every
    OHLCV row into the price_data table, and records each symbol's overall
    date range in price_data_coverage.  Safe to re-run: inserts use
    INSERT OR REPLACE.

    Args:
        jsonl_path: Path to merged.jsonl file
        db_path: Path to SQLite database
    """
    jsonl_file = Path(jsonl_path)

    if not jsonl_file.exists():
        print(f"⚠️ merged.jsonl not found at {jsonl_path}")
        print(" No price data to migrate. Skipping migration.")
        return

    print(f"📊 Migrating price data from {jsonl_path} to {db_path}")

    # Ensure database is initialized (creates tables if missing)
    initialize_database(db_path)

    conn = get_db_connection(db_path)
    cursor = conn.cursor()

    # Track what we're importing
    total_records = 0
    symbols_processed = set()
    symbol_date_ranges = defaultdict(lambda: {"min": None, "max": None})

    created_at = datetime.utcnow().isoformat() + "Z"

    print("Reading merged.jsonl...")

    # Explicit UTF-8: without it, open() uses the platform default encoding,
    # which can mis-decode this file on non-UTF-8 locales (e.g. Windows).
    with open(jsonl_file, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            if not line.strip():
                continue

            try:
                record = json.loads(line)

                # Extract metadata
                meta = record.get("Meta Data", {})
                symbol = meta.get("2. Symbol")

                if not symbol:
                    print(f"⚠️ Line {line_num}: No symbol found, skipping")
                    continue

                symbols_processed.add(symbol)

                # Extract time series data
                time_series = record.get("Time Series (Daily)", {})

                if not time_series:
                    print(f"⚠️ {symbol}: No time series data, skipping")
                    continue

                # Insert each date's data
                for date, ohlcv in time_series.items():
                    try:
                        # Parse OHLCV values.  Legacy "buy price"/"sell price"
                        # keys take precedence over the standard open/close keys.
                        open_price = float(ohlcv.get("1. buy price") or ohlcv.get("1. open", 0))
                        high_price = float(ohlcv.get("2. high", 0))
                        low_price = float(ohlcv.get("3. low", 0))
                        close_price = float(ohlcv.get("4. sell price") or ohlcv.get("4. close", 0))
                        volume = int(ohlcv.get("5. volume", 0))

                        # Insert or replace price data (idempotent on re-runs)
                        cursor.execute("""
                            INSERT OR REPLACE INTO price_data
                            (symbol, date, open, high, low, close, volume, created_at)
                            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                        """, (symbol, date, open_price, high_price, low_price, close_price, volume, created_at))

                        total_records += 1

                        # Track date range for this symbol
                        if symbol_date_ranges[symbol]["min"] is None or date < symbol_date_ranges[symbol]["min"]:
                            symbol_date_ranges[symbol]["min"] = date
                        if symbol_date_ranges[symbol]["max"] is None or date > symbol_date_ranges[symbol]["max"]:
                            symbol_date_ranges[symbol]["max"] = date

                    except (ValueError, KeyError) as e:
                        print(f"⚠️ {symbol} {date}: Failed to parse OHLCV data: {e}")
                        continue

                # Commit every 1000 records for progress
                if total_records % 1000 == 0:
                    conn.commit()
                    print(f" Imported {total_records} records...")

            except json.JSONDecodeError as e:
                print(f"⚠️ Line {line_num}: JSON decode error: {e}")
                continue

    # Final commit
    conn.commit()

    print(f"\n✓ Imported {total_records} price records for {len(symbols_processed)} symbols")

    # Update coverage tracking (one min/max range per symbol)
    print("\nUpdating coverage tracking...")

    for symbol, date_range in symbol_date_ranges.items():
        if date_range["min"] and date_range["max"]:
            cursor.execute("""
                INSERT OR REPLACE INTO price_data_coverage
                (symbol, start_date, end_date, downloaded_at, source)
                VALUES (?, ?, ?, ?, 'migrated_from_jsonl')
            """, (symbol, date_range["min"], date_range["max"], created_at))

    conn.commit()
    conn.close()

    print(f"✓ Coverage tracking updated for {len(symbol_date_ranges)} symbols")
    print("\n✅ Migration complete!")
    print(f"\nSymbols migrated: {', '.join(sorted(symbols_processed))}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    # Command-line entry point: choose source jsonl and target database.
    cli = argparse.ArgumentParser(description="Migrate merged.jsonl to SQLite database")
    cli.add_argument(
        "--jsonl",
        default="data/merged.jsonl",
        help="Path to merged.jsonl file (default: data/merged.jsonl)",
    )
    cli.add_argument(
        "--db",
        default="data/jobs.db",
        help="Path to SQLite database (default: data/jobs.db)",
    )
    opts = cli.parse_args()

    migrate_merged_jsonl(opts.jsonl, opts.db)
|
||||
453
tests/integration/test_on_demand_downloads.py
Normal file
453
tests/integration/test_on_demand_downloads.py
Normal file
@@ -0,0 +1,453 @@
|
||||
"""
|
||||
Integration tests for on-demand price data downloads.
|
||||
|
||||
Tests the complete flow from missing coverage detection through download
|
||||
and storage, including priority-based download strategy and rate limit handling.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
import json
|
||||
from unittest.mock import patch, Mock
|
||||
from datetime import datetime
|
||||
|
||||
from api.price_data_manager import PriceDataManager, RateLimitError, DownloadError
|
||||
from api.database import initialize_database, get_db_connection
|
||||
from api.date_utils import expand_date_range
|
||||
|
||||
|
||||
@pytest.fixture
def temp_db():
    """Yield the path of a freshly initialized temporary SQLite database."""
    # Only the generated filename is needed; the handle is closed immediately.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.db', delete=False) as handle:
        path = handle.name

    initialize_database(path)
    yield path

    # Remove the database file after the test finishes.
    if os.path.exists(path):
        os.unlink(path)
|
||||
|
||||
|
||||
@pytest.fixture
def temp_symbols_config():
    """Write a small symbols config to a temp JSON file and yield its path."""
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as handle:
        json.dump(
            {
                "symbols": ["AAPL", "MSFT", "GOOGL", "AMZN", "NVDA"],
                "description": "Test symbols",
                "total_symbols": 5,
            },
            handle,
        )
        path = handle.name

    yield path

    # Remove the config file after the test finishes.
    if os.path.exists(path):
        os.unlink(path)
|
||||
|
||||
|
||||
@pytest.fixture
def manager(temp_db, temp_symbols_config):
    """PriceDataManager wired to the temporary database and symbols config."""
    return PriceDataManager(
        db_path=temp_db,
        symbols_config=temp_symbols_config,
        api_key="test_api_key",
    )
|
||||
|
||||
|
||||
@pytest.fixture
def mock_alpha_vantage_response():
    """Return a factory building Alpha Vantage-shaped payloads for a symbol."""
    def create_response(symbol: str, dates: list):
        """Create response for given symbol and dates."""
        # Every date gets the same fixed OHLCV row; tests only care about presence.
        series = {
            day: {
                "1. open": "150.00",
                "2. high": "155.00",
                "3. low": "149.00",
                "4. close": "154.00",
                "5. volume": "1000000",
            }
            for day in dates
        }
        meta = {
            "1. Information": "Daily Prices",
            "2. Symbol": symbol,
            "3. Last Refreshed": dates[0] if dates else "2025-01-20",
        }
        return {"Meta Data": meta, "Time Series (Daily)": series}

    return create_response
|
||||
|
||||
|
||||
class TestEndToEndDownload:
    """Test complete download workflow.

    Exercises the full path: gap detection -> prioritized download ->
    storage -> coverage tracking, with Alpha Vantage HTTP calls mocked.
    """

    @patch('api.price_data_manager.requests.get')
    def test_download_missing_data_success(self, mock_get, manager, mock_alpha_vantage_response):
        """Test successful download of missing price data."""
        # Setup: Mock API responses for each symbol
        dates = ["2025-01-20", "2025-01-21"]

        def mock_response_factory(url, **kwargs):
            """Return appropriate mock response based on symbol in params."""
            symbol = kwargs.get('params', {}).get('symbol', 'AAPL')
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.json.return_value = mock_alpha_vantage_response(symbol, dates)
            return mock_response

        mock_get.side_effect = mock_response_factory

        # Test: Request date range with no existing data
        missing = manager.get_missing_coverage("2025-01-20", "2025-01-21")

        # All symbols should be missing both dates
        assert len(missing) == 5
        for symbol in ["AAPL", "MSFT", "GOOGL", "AMZN", "NVDA"]:
            assert symbol in missing
            assert missing[symbol] == {"2025-01-20", "2025-01-21"}

        # Download missing data
        requested_dates = set(dates)
        result = manager.download_missing_data_prioritized(missing, requested_dates)

        # Should successfully download all symbols
        assert result["success"] is True
        assert len(result["downloaded"]) == 5
        assert result["rate_limited"] is False
        assert set(result["dates_completed"]) == requested_dates

        # Verify data in database
        available_dates = manager.get_available_trading_dates("2025-01-20", "2025-01-21")
        assert available_dates == ["2025-01-20", "2025-01-21"]

        # Verify coverage tracking
        conn = get_db_connection(manager.db_path)
        cursor = conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM price_data_coverage")
        coverage_count = cursor.fetchone()[0]
        assert coverage_count == 5  # One record per symbol
        conn.close()

    @patch('api.price_data_manager.requests.get')
    def test_download_with_partial_existing_data(self, mock_get, manager, mock_alpha_vantage_response):
        """Test download when some data already exists."""
        dates = ["2025-01-20", "2025-01-21", "2025-01-22"]

        # Prepopulate database with some data (AAPL and MSFT for first two dates)
        conn = get_db_connection(manager.db_path)
        cursor = conn.cursor()
        created_at = datetime.utcnow().isoformat() + "Z"

        for symbol in ["AAPL", "MSFT"]:
            for date in dates[:2]:  # Only first two dates
                cursor.execute("""
                    INSERT INTO price_data (symbol, date, open, high, low, close, volume, created_at)
                    VALUES (?, ?, 150.0, 155.0, 149.0, 154.0, 1000000, ?)
                """, (symbol, date, created_at))

            cursor.execute("""
                INSERT INTO price_data_coverage (symbol, start_date, end_date, downloaded_at, source)
                VALUES (?, ?, ?, ?, 'test')
            """, (symbol, dates[0], dates[1], created_at))

        conn.commit()
        conn.close()

        # Mock API for remaining downloads
        def mock_response_factory(url, **kwargs):
            symbol = kwargs.get('params', {}).get('symbol', 'GOOGL')
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.json.return_value = mock_alpha_vantage_response(symbol, dates)
            return mock_response

        mock_get.side_effect = mock_response_factory

        # Check missing coverage
        missing = manager.get_missing_coverage(dates[0], dates[2])

        # AAPL and MSFT should be missing only date 3
        # GOOGL, AMZN, NVDA should be missing all dates
        assert missing["AAPL"] == {dates[2]}
        assert missing["MSFT"] == {dates[2]}
        assert missing["GOOGL"] == set(dates)

        # Download missing data
        requested_dates = set(dates)
        result = manager.download_missing_data_prioritized(missing, requested_dates)

        assert result["success"] is True
        assert len(result["downloaded"]) == 5

        # Verify all dates are now available
        available_dates = manager.get_available_trading_dates(dates[0], dates[2])
        assert set(available_dates) == set(dates)

    @patch('api.price_data_manager.requests.get')
    def test_priority_based_download_order(self, mock_get, manager, mock_alpha_vantage_response):
        """Test that downloads prioritize symbols that complete the most dates."""
        dates = ["2025-01-20", "2025-01-21", "2025-01-22"]

        # Prepopulate with specific pattern to create different priorities
        conn = get_db_connection(manager.db_path)
        cursor = conn.cursor()
        created_at = datetime.utcnow().isoformat() + "Z"

        # AAPL: Has date 1 only (missing 2 dates)
        cursor.execute("""
            INSERT INTO price_data (symbol, date, open, high, low, close, volume, created_at)
            VALUES ('AAPL', ?, 150.0, 155.0, 149.0, 154.0, 1000000, ?)
        """, (dates[0], created_at))

        # MSFT: Has date 1 and 2 (missing 1 date)
        for date in dates[:2]:
            cursor.execute("""
                INSERT INTO price_data (symbol, date, open, high, low, close, volume, created_at)
                VALUES ('MSFT', ?, 150.0, 155.0, 149.0, 154.0, 1000000, ?)
            """, (date, created_at))

        # GOOGL, AMZN, NVDA: No data (missing 3 dates)
        conn.commit()
        conn.close()

        # Track download order (the factory records each requested symbol)
        download_order = []

        def mock_response_factory(url, **kwargs):
            symbol = kwargs.get('params', {}).get('symbol')
            download_order.append(symbol)
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.json.return_value = mock_alpha_vantage_response(symbol, dates)
            return mock_response

        mock_get.side_effect = mock_response_factory

        # Download missing data
        missing = manager.get_missing_coverage(dates[0], dates[2])
        requested_dates = set(dates)
        result = manager.download_missing_data_prioritized(missing, requested_dates)

        assert result["success"] is True

        # Verify symbols with highest impact were downloaded first
        # GOOGL, AMZN, NVDA should be first (3 dates each)
        # Then AAPL (2 dates)
        # Then MSFT (1 date)
        first_three = set(download_order[:3])
        assert first_three == {"GOOGL", "AMZN", "NVDA"}
        assert download_order[3] == "AAPL"
        assert download_order[4] == "MSFT"
|
||||
|
||||
|
||||
class TestRateLimitHandling:
    """Test rate limit handling during downloads.

    Alpha Vantage signals rate limiting with HTTP 200 plus a "Note" field,
    so these tests mock that payload shape rather than an error status.
    """

    @patch('api.price_data_manager.requests.get')
    def test_rate_limit_stops_downloads(self, mock_get, manager, mock_alpha_vantage_response):
        """Test that rate limit error stops further downloads."""
        dates = ["2025-01-20"]

        # First symbol succeeds, second hits rate limit
        responses = [
            # AAPL succeeds (or whichever symbol is first in priority)
            Mock(status_code=200, json=lambda: mock_alpha_vantage_response("AAPL", dates)),
            # MSFT hits rate limit
            Mock(status_code=200, json=lambda: {"Note": "Thank you for using Alpha Vantage! Our standard API call frequency is 25 calls per day."}),
        ]

        mock_get.side_effect = responses

        missing = manager.get_missing_coverage("2025-01-20", "2025-01-20")
        requested_dates = {"2025-01-20"}

        result = manager.download_missing_data_prioritized(missing, requested_dates)

        # Partial success - one symbol downloaded
        assert result["success"] is True  # At least one succeeded
        assert len(result["downloaded"]) >= 1
        assert result["rate_limited"] is True
        assert len(result["failed"]) >= 1

        # Completed dates should be empty (need all symbols for complete date)
        assert len(result["dates_completed"]) == 0

    @patch('api.price_data_manager.requests.get')
    def test_graceful_handling_of_mixed_failures(self, mock_get, manager, mock_alpha_vantage_response):
        """Test handling of mix of successes, failures, and rate limits."""
        dates = ["2025-01-20"]

        # Single-element list so the closure can mutate the counter in place.
        call_count = [0]

        def response_factory(url, **kwargs):
            """Return different responses for different calls."""
            call_count[0] += 1
            mock_response = Mock()

            if call_count[0] == 1:
                # First call succeeds
                mock_response.status_code = 200
                mock_response.json.return_value = mock_alpha_vantage_response("AAPL", dates)
            elif call_count[0] == 2:
                # Second call fails with server error
                mock_response.status_code = 500
                mock_response.raise_for_status.side_effect = Exception("Server error")
            else:
                # Third call hits rate limit
                mock_response.status_code = 200
                mock_response.json.return_value = {"Note": "rate limit exceeded"}

            return mock_response

        mock_get.side_effect = response_factory

        missing = manager.get_missing_coverage("2025-01-20", "2025-01-20")
        requested_dates = {"2025-01-20"}

        result = manager.download_missing_data_prioritized(missing, requested_dates)

        # Should have handled errors gracefully
        assert "downloaded" in result
        assert "failed" in result
        assert len(result["downloaded"]) >= 1
|
||||
|
||||
|
||||
class TestCoverageTracking:
    """Test coverage tracking functionality.

    Verifies that price_data_coverage rows are written after downloads and
    that gap detection finds missing per-symbol dates accurately.
    """

    @patch('api.price_data_manager.requests.get')
    def test_coverage_updated_after_download(self, mock_get, manager, mock_alpha_vantage_response):
        """Test that coverage table is updated after successful download."""
        dates = ["2025-01-20", "2025-01-21"]

        mock_get.return_value = Mock(
            status_code=200,
            json=lambda: mock_alpha_vantage_response("AAPL", dates)
        )

        # Download for single symbol (drives the private pipeline directly)
        data = manager._download_symbol("AAPL")
        stored_dates = manager._store_symbol_data("AAPL", data, set(dates))
        manager._update_coverage("AAPL", dates[0], dates[1])

        # Verify coverage was recorded
        conn = get_db_connection(manager.db_path)
        cursor = conn.cursor()
        cursor.execute("""
            SELECT symbol, start_date, end_date, source
            FROM price_data_coverage
            WHERE symbol = 'AAPL'
        """)
        row = cursor.fetchone()
        conn.close()

        assert row is not None
        assert row[0] == "AAPL"
        assert row[1] == dates[0]
        assert row[2] == dates[1]
        assert row[3] == "alpha_vantage"

    def test_coverage_gap_detection_accuracy(self, manager):
        """Test accuracy of coverage gap detection."""
        # Populate database with specific pattern
        conn = get_db_connection(manager.db_path)
        cursor = conn.cursor()
        created_at = datetime.utcnow().isoformat() + "Z"

        test_data = [
            ("AAPL", "2025-01-20"),
            ("AAPL", "2025-01-21"),
            ("AAPL", "2025-01-23"),  # Gap on 2025-01-22
            ("MSFT", "2025-01-20"),
            ("MSFT", "2025-01-22"),  # Gap on 2025-01-21
        ]

        for symbol, date in test_data:
            cursor.execute("""
                INSERT INTO price_data (symbol, date, open, high, low, close, volume, created_at)
                VALUES (?, ?, 150.0, 155.0, 149.0, 154.0, 1000000, ?)
            """, (symbol, date, created_at))

        conn.commit()
        conn.close()

        # Check for gaps in range
        missing = manager.get_missing_coverage("2025-01-20", "2025-01-23")

        # AAPL should be missing 2025-01-22
        assert "2025-01-22" in missing["AAPL"]
        assert "2025-01-20" not in missing["AAPL"]

        # MSFT should be missing 2025-01-21 and 2025-01-23
        assert "2025-01-21" in missing["MSFT"]
        assert "2025-01-23" in missing["MSFT"]
        assert "2025-01-20" not in missing["MSFT"]
|
||||
|
||||
|
||||
class TestDataValidation:
    """Test data validation during download and storage."""

    @patch('api.price_data_manager.requests.get')
    def test_invalid_response_handling(self, mock_get, manager):
        """Test handling of invalid API responses."""
        # Mock response with missing required fields
        mock_get.return_value = Mock(
            status_code=200,
            json=lambda: {"invalid": "response"}
        )

        with pytest.raises(DownloadError, match="Invalid response format"):
            manager._download_symbol("AAPL")

    @patch('api.price_data_manager.requests.get')
    def test_empty_time_series_handling(self, mock_get, manager):
        """Test handling of empty time series data (should raise error for missing data)."""
        # API returns valid structure but no time series
        mock_get.return_value = Mock(
            status_code=200,
            json=lambda: {
                "Meta Data": {"2. Symbol": "AAPL"},
                # Missing "Time Series (Daily)" key
            }
        )

        with pytest.raises(DownloadError, match="Invalid response format"):
            manager._download_symbol("AAPL")

    def test_date_filtering_during_storage(self, manager):
        """Test that only requested dates are stored."""
        # Create mock data with dates outside requested range
        # (2025-01-15 and 2025-01-25 bracket the requested window)
        data = {
            "Meta Data": {"2. Symbol": "AAPL"},
            "Time Series (Daily)": {
                "2025-01-15": {"1. open": "145.00", "2. high": "150.00", "3. low": "144.00", "4. close": "149.00", "5. volume": "1000000"},
                "2025-01-20": {"1. open": "150.00", "2. high": "155.00", "3. low": "149.00", "4. close": "154.00", "5. volume": "1000000"},
                "2025-01-21": {"1. open": "154.00", "2. high": "156.00", "3. low": "153.00", "4. close": "155.00", "5. volume": "1100000"},
                "2025-01-25": {"1. open": "156.00", "2. high": "158.00", "3. low": "155.00", "4. close": "157.00", "5. volume": "1200000"},
            }
        }

        # Request only specific dates
        requested_dates = {"2025-01-20", "2025-01-21"}
        stored_dates = manager._store_symbol_data("AAPL", data, requested_dates)

        # Only requested dates should be stored
        assert set(stored_dates) == requested_dates

        # Verify in database
        conn = get_db_connection(manager.db_path)
        cursor = conn.cursor()
        cursor.execute("SELECT date FROM price_data WHERE symbol = 'AAPL' ORDER BY date")
        db_dates = [row[0] for row in cursor.fetchall()]
        conn.close()

        assert db_dates == ["2025-01-20", "2025-01-21"]
|
||||
149
tests/unit/test_date_utils.py
Normal file
149
tests/unit/test_date_utils.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
Unit tests for api/date_utils.py
|
||||
|
||||
Tests date range expansion, validation, and utility functions.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from api.date_utils import (
|
||||
expand_date_range,
|
||||
validate_date_range,
|
||||
get_max_simulation_days
|
||||
)
|
||||
|
||||
|
||||
class TestExpandDateRange:
    """Test expand_date_range function.

    The function returns every calendar date (inclusive) between start and
    end as YYYY-MM-DD strings, in chronological order.
    """

    def test_single_day(self):
        """Test single day range (start == end)."""
        result = expand_date_range("2025-01-20", "2025-01-20")
        assert result == ["2025-01-20"]

    def test_multi_day_range(self):
        """Test multiple day range."""
        result = expand_date_range("2025-01-20", "2025-01-22")
        assert result == ["2025-01-20", "2025-01-21", "2025-01-22"]

    def test_week_range(self):
        """Test week-long range."""
        result = expand_date_range("2025-01-20", "2025-01-26")
        assert len(result) == 7
        assert result[0] == "2025-01-20"
        assert result[-1] == "2025-01-26"

    def test_chronological_order(self):
        """Test dates are in chronological order."""
        result = expand_date_range("2025-01-20", "2025-01-25")
        # ISO date strings compare lexicographically in date order.
        for i in range(len(result) - 1):
            assert result[i] < result[i + 1]

    def test_invalid_order(self):
        """Test error when start > end."""
        with pytest.raises(ValueError, match="must be <= end_date"):
            expand_date_range("2025-01-25", "2025-01-20")

    def test_invalid_date_format(self):
        """Test error with invalid date format."""
        # MM-DD-YYYY input must be rejected.
        with pytest.raises(ValueError):
            expand_date_range("01-20-2025", "01-21-2025")

    def test_month_boundary(self):
        """Test range spanning month boundary."""
        result = expand_date_range("2025-01-30", "2025-02-02")
        assert result == ["2025-01-30", "2025-01-31", "2025-02-01", "2025-02-02"]

    def test_year_boundary(self):
        """Test range spanning year boundary."""
        result = expand_date_range("2024-12-30", "2025-01-02")
        assert len(result) == 4
        assert "2024-12-31" in result
        assert "2025-01-01" in result
||||
|
||||
|
||||
class TestValidateDateRange:
    """Test validate_date_range function.

    validate_date_range raises ValueError on bad input and returns None
    (no exception) when the range is acceptable.
    """

    def test_valid_single_day(self):
        """Test valid single day range."""
        # Should not raise
        validate_date_range("2025-01-20", "2025-01-20", max_days=30)

    def test_valid_multi_day(self):
        """Test valid multi-day range."""
        # Should not raise
        validate_date_range("2025-01-20", "2025-01-25", max_days=30)

    def test_max_days_boundary(self):
        """Test exactly at max days limit."""
        # 30 days total (inclusive)
        start = "2025-01-01"
        end = "2025-01-30"
        # Should not raise
        validate_date_range(start, end, max_days=30)

    def test_exceeds_max_days(self):
        """Test exceeds max days limit."""
        start = "2025-01-01"
        end = "2025-02-01"  # 32 days
        with pytest.raises(ValueError, match="Date range too large: 32 days"):
            validate_date_range(start, end, max_days=30)

    def test_invalid_order(self):
        """Test start > end."""
        with pytest.raises(ValueError, match="must be <= end_date"):
            validate_date_range("2025-01-25", "2025-01-20", max_days=30)

    def test_future_date_rejected(self):
        """Test future dates are rejected."""
        # Dates are computed relative to "now" so the test stays valid over time.
        tomorrow = (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%d")
        next_week = (datetime.now() + timedelta(days=7)).strftime("%Y-%m-%d")

        with pytest.raises(ValueError, match="cannot be in the future"):
            validate_date_range(tomorrow, next_week, max_days=30)

    def test_today_allowed(self):
        """Test today's date is allowed."""
        today = datetime.now().strftime("%Y-%m-%d")
        # Should not raise
        validate_date_range(today, today, max_days=30)

    def test_past_dates_allowed(self):
        """Test past dates are allowed."""
        # Should not raise
        validate_date_range("2020-01-01", "2020-01-10", max_days=30)

    def test_invalid_date_format(self):
        """Test invalid date format raises error."""
        with pytest.raises(ValueError, match="Invalid date format"):
            validate_date_range("01-20-2025", "01-21-2025", max_days=30)

    def test_custom_max_days(self):
        """Test custom max_days parameter."""
        # Should raise with max_days=5
        with pytest.raises(ValueError, match="Date range too large: 10 days"):
            validate_date_range("2025-01-01", "2025-01-10", max_days=5)
|
||||
|
||||
|
||||
class TestGetMaxSimulationDays:
    """Test get_max_simulation_days function."""

    def test_default_value(self, monkeypatch):
        """Falls back to 30 when MAX_SIMULATION_DAYS is unset."""
        monkeypatch.delenv("MAX_SIMULATION_DAYS", raising=False)
        assert get_max_simulation_days() == 30

    def test_env_var_override(self, monkeypatch):
        """The MAX_SIMULATION_DAYS environment variable overrides the default."""
        monkeypatch.setenv("MAX_SIMULATION_DAYS", "60")
        assert get_max_simulation_days() == 60

    def test_env_var_string_to_int(self, monkeypatch):
        """The string value from the environment is converted to an int."""
        monkeypatch.setenv("MAX_SIMULATION_DAYS", "100")
        limit = get_max_simulation_days()
        assert isinstance(limit, int)
        assert limit == 100
|
||||
572
tests/unit/test_price_data_manager.py
Normal file
572
tests/unit/test_price_data_manager.py
Normal file
@@ -0,0 +1,572 @@
|
||||
"""
|
||||
Unit tests for api/price_data_manager.py
|
||||
|
||||
Tests price data management, coverage detection, download prioritization,
|
||||
and rate limit handling.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, patch, MagicMock, call
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import sqlite3
|
||||
|
||||
from api.price_data_manager import (
|
||||
PriceDataManager,
|
||||
RateLimitError,
|
||||
DownloadError
|
||||
)
|
||||
from api.database import initialize_database, get_db_connection
|
||||
|
||||
|
||||
@pytest.fixture
def temp_db():
    """Yield the path of a freshly initialized temporary SQLite database."""
    # The NamedTemporaryFile is used only to reserve a unique filename.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.db', delete=False) as handle:
        path = handle.name

    initialize_database(path)
    yield path

    # Teardown: remove the database file.
    if os.path.exists(path):
        os.unlink(path)
|
||||
|
||||
|
||||
@pytest.fixture
def temp_symbols_config():
    """Write a three-symbol config to a temp JSON file and yield its path."""
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as handle:
        json.dump(
            {
                "symbols": ["AAPL", "MSFT", "GOOGL"],
                "description": "Test symbols",
                "total_symbols": 3,
            },
            handle,
        )
        path = handle.name

    yield path

    # Teardown: remove the config file.
    if os.path.exists(path):
        os.unlink(path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager(temp_db, temp_symbols_config):
|
||||
"""Create PriceDataManager instance with temp database and config."""
|
||||
return PriceDataManager(
|
||||
db_path=temp_db,
|
||||
symbols_config=temp_symbols_config,
|
||||
api_key="test_api_key"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def populated_db(temp_db):
|
||||
"""Create database with sample price data."""
|
||||
conn = get_db_connection(temp_db)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Insert sample price data for multiple symbols and dates
|
||||
test_data = [
|
||||
("AAPL", "2025-01-20", 150.0, 155.0, 149.0, 154.0, 1000000),
|
||||
("AAPL", "2025-01-21", 154.0, 156.0, 153.0, 155.0, 1100000),
|
||||
("MSFT", "2025-01-20", 380.0, 385.0, 379.0, 383.0, 2000000),
|
||||
("MSFT", "2025-01-21", 383.0, 387.0, 382.0, 386.0, 2100000),
|
||||
("GOOGL", "2025-01-20", 140.0, 142.0, 139.0, 141.0, 1500000),
|
||||
# Note: GOOGL missing 2025-01-21
|
||||
]
|
||||
|
||||
created_at = datetime.utcnow().isoformat() + "Z"
|
||||
|
||||
for symbol, date, open_p, high, low, close, volume in test_data:
|
||||
cursor.execute("""
|
||||
INSERT INTO price_data (symbol, date, open, high, low, close, volume, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (symbol, date, open_p, high, low, close, volume, created_at))
|
||||
|
||||
# Insert coverage data
|
||||
cursor.execute("""
|
||||
INSERT INTO price_data_coverage (symbol, start_date, end_date, downloaded_at, source)
|
||||
VALUES
|
||||
('AAPL', '2025-01-20', '2025-01-21', ?, 'test'),
|
||||
('MSFT', '2025-01-20', '2025-01-21', ?, 'test'),
|
||||
('GOOGL', '2025-01-20', '2025-01-20', ?, 'test')
|
||||
""", (created_at, created_at, created_at))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
return temp_db
|
||||
|
||||
|
||||
class TestPriceDataManagerInit:
|
||||
"""Test PriceDataManager initialization."""
|
||||
|
||||
def test_init_with_defaults(self, temp_db):
|
||||
"""Test initialization with default parameters."""
|
||||
with patch.dict(os.environ, {"ALPHAADVANTAGE_API_KEY": "env_key"}):
|
||||
manager = PriceDataManager(db_path=temp_db)
|
||||
assert manager.db_path == temp_db
|
||||
assert manager.api_key == "env_key"
|
||||
assert manager.symbols_config == "configs/nasdaq100_symbols.json"
|
||||
|
||||
def test_init_with_custom_params(self, temp_db, temp_symbols_config):
|
||||
"""Test initialization with custom parameters."""
|
||||
manager = PriceDataManager(
|
||||
db_path=temp_db,
|
||||
symbols_config=temp_symbols_config,
|
||||
api_key="custom_key"
|
||||
)
|
||||
assert manager.db_path == temp_db
|
||||
assert manager.api_key == "custom_key"
|
||||
assert manager.symbols_config == temp_symbols_config
|
||||
|
||||
def test_load_symbols_success(self, manager):
|
||||
"""Test successful symbol loading from config."""
|
||||
assert manager.symbols == ["AAPL", "MSFT", "GOOGL"]
|
||||
|
||||
def test_load_symbols_file_not_found(self, temp_db):
|
||||
"""Test handling of missing symbols config file uses fallback."""
|
||||
manager = PriceDataManager(
|
||||
db_path=temp_db,
|
||||
symbols_config="nonexistent.json",
|
||||
api_key="test_key"
|
||||
)
|
||||
# Should use fallback symbols list
|
||||
assert len(manager.symbols) > 0
|
||||
assert "AAPL" in manager.symbols
|
||||
|
||||
def test_load_symbols_invalid_json(self, temp_db):
|
||||
"""Test handling of invalid JSON in symbols config."""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||
f.write("invalid json{")
|
||||
bad_config = f.name
|
||||
|
||||
try:
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
PriceDataManager(
|
||||
db_path=temp_db,
|
||||
symbols_config=bad_config,
|
||||
api_key="test_key"
|
||||
)
|
||||
finally:
|
||||
os.unlink(bad_config)
|
||||
|
||||
def test_missing_api_key(self, temp_db, temp_symbols_config):
|
||||
"""Test initialization without API key."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
manager = PriceDataManager(
|
||||
db_path=temp_db,
|
||||
symbols_config=temp_symbols_config
|
||||
)
|
||||
assert manager.api_key is None
|
||||
|
||||
|
||||
class TestGetSymbolDates:
|
||||
"""Test get_symbol_dates method."""
|
||||
|
||||
def test_get_symbol_dates_with_data(self, manager, populated_db):
|
||||
"""Test retrieving dates for symbol with data."""
|
||||
manager.db_path = populated_db
|
||||
dates = manager.get_symbol_dates("AAPL")
|
||||
assert dates == {"2025-01-20", "2025-01-21"}
|
||||
|
||||
def test_get_symbol_dates_no_data(self, manager):
|
||||
"""Test retrieving dates for symbol without data."""
|
||||
dates = manager.get_symbol_dates("TSLA")
|
||||
assert dates == set()
|
||||
|
||||
def test_get_symbol_dates_partial_data(self, manager, populated_db):
|
||||
"""Test retrieving dates for symbol with partial data."""
|
||||
manager.db_path = populated_db
|
||||
dates = manager.get_symbol_dates("GOOGL")
|
||||
assert dates == {"2025-01-20"}
|
||||
|
||||
|
||||
class TestGetMissingCoverage:
|
||||
"""Test get_missing_coverage method."""
|
||||
|
||||
def test_missing_coverage_empty_db(self, manager):
|
||||
"""Test missing coverage with empty database."""
|
||||
missing = manager.get_missing_coverage("2025-01-20", "2025-01-21")
|
||||
|
||||
# All symbols should be missing all dates
|
||||
assert "AAPL" in missing
|
||||
assert "MSFT" in missing
|
||||
assert "GOOGL" in missing
|
||||
assert missing["AAPL"] == {"2025-01-20", "2025-01-21"}
|
||||
|
||||
def test_missing_coverage_partial_db(self, manager, populated_db):
|
||||
"""Test missing coverage with partial data."""
|
||||
manager.db_path = populated_db
|
||||
missing = manager.get_missing_coverage("2025-01-20", "2025-01-21")
|
||||
|
||||
# AAPL and MSFT have all dates, GOOGL missing 2025-01-21
|
||||
assert "AAPL" not in missing or len(missing["AAPL"]) == 0
|
||||
assert "MSFT" not in missing or len(missing["MSFT"]) == 0
|
||||
assert "GOOGL" in missing
|
||||
assert missing["GOOGL"] == {"2025-01-21"}
|
||||
|
||||
def test_missing_coverage_complete_db(self, manager, populated_db):
|
||||
"""Test missing coverage when all data available."""
|
||||
manager.db_path = populated_db
|
||||
missing = manager.get_missing_coverage("2025-01-20", "2025-01-20")
|
||||
|
||||
# All symbols have 2025-01-20
|
||||
for symbol in ["AAPL", "MSFT", "GOOGL"]:
|
||||
assert symbol not in missing or len(missing[symbol]) == 0
|
||||
|
||||
def test_missing_coverage_single_date(self, manager, populated_db):
|
||||
"""Test missing coverage for single date."""
|
||||
manager.db_path = populated_db
|
||||
missing = manager.get_missing_coverage("2025-01-21", "2025-01-21")
|
||||
|
||||
# Only GOOGL missing 2025-01-21
|
||||
assert "GOOGL" in missing
|
||||
assert missing["GOOGL"] == {"2025-01-21"}
|
||||
|
||||
|
||||
class TestPrioritizeDownloads:
|
||||
"""Test prioritize_downloads method."""
|
||||
|
||||
def test_prioritize_single_symbol(self, manager):
|
||||
"""Test prioritization with single symbol missing data."""
|
||||
missing_coverage = {"AAPL": {"2025-01-20", "2025-01-21"}}
|
||||
requested_dates = {"2025-01-20", "2025-01-21"}
|
||||
|
||||
prioritized = manager.prioritize_downloads(missing_coverage, requested_dates)
|
||||
assert prioritized == ["AAPL"]
|
||||
|
||||
def test_prioritize_multiple_symbols_equal_impact(self, manager):
|
||||
"""Test prioritization with equal impact symbols."""
|
||||
missing_coverage = {
|
||||
"AAPL": {"2025-01-20", "2025-01-21"},
|
||||
"MSFT": {"2025-01-20", "2025-01-21"}
|
||||
}
|
||||
requested_dates = {"2025-01-20", "2025-01-21"}
|
||||
|
||||
prioritized = manager.prioritize_downloads(missing_coverage, requested_dates)
|
||||
# Both should be included (order may vary)
|
||||
assert set(prioritized) == {"AAPL", "MSFT"}
|
||||
assert len(prioritized) == 2
|
||||
|
||||
def test_prioritize_by_impact(self, manager):
|
||||
"""Test prioritization by date completion impact."""
|
||||
missing_coverage = {
|
||||
"AAPL": {"2025-01-20", "2025-01-21", "2025-01-22"}, # High impact (3 dates)
|
||||
"MSFT": {"2025-01-20"}, # Low impact (1 date)
|
||||
"GOOGL": {"2025-01-21", "2025-01-22"} # Medium impact (2 dates)
|
||||
}
|
||||
requested_dates = {"2025-01-20", "2025-01-21", "2025-01-22"}
|
||||
|
||||
prioritized = manager.prioritize_downloads(missing_coverage, requested_dates)
|
||||
|
||||
# AAPL should be first (highest impact)
|
||||
assert prioritized[0] == "AAPL"
|
||||
# GOOGL should be second
|
||||
assert prioritized[1] == "GOOGL"
|
||||
# MSFT should be last (lowest impact)
|
||||
assert prioritized[2] == "MSFT"
|
||||
|
||||
def test_prioritize_excludes_irrelevant_dates(self, manager):
|
||||
"""Test that symbols with no impact on requested dates are excluded."""
|
||||
missing_coverage = {
|
||||
"AAPL": {"2025-01-20"}, # Relevant
|
||||
"MSFT": {"2025-01-25", "2025-01-26"} # Not relevant
|
||||
}
|
||||
requested_dates = {"2025-01-20", "2025-01-21"}
|
||||
|
||||
prioritized = manager.prioritize_downloads(missing_coverage, requested_dates)
|
||||
|
||||
# Only AAPL should be included
|
||||
assert prioritized == ["AAPL"]
|
||||
|
||||
|
||||
class TestGetAvailableTradingDates:
|
||||
"""Test get_available_trading_dates method."""
|
||||
|
||||
def test_available_dates_empty_db(self, manager):
|
||||
"""Test with empty database returns no dates."""
|
||||
available = manager.get_available_trading_dates("2025-01-20", "2025-01-21")
|
||||
assert available == []
|
||||
|
||||
def test_available_dates_complete_range(self, manager, populated_db):
|
||||
"""Test with complete data for all symbols in range."""
|
||||
manager.db_path = populated_db
|
||||
available = manager.get_available_trading_dates("2025-01-20", "2025-01-20")
|
||||
assert available == ["2025-01-20"]
|
||||
|
||||
def test_available_dates_partial_range(self, manager, populated_db):
|
||||
"""Test with partial data (some symbols missing some dates)."""
|
||||
manager.db_path = populated_db
|
||||
available = manager.get_available_trading_dates("2025-01-20", "2025-01-21")
|
||||
|
||||
# 2025-01-20 has all symbols, 2025-01-21 missing GOOGL
|
||||
assert available == ["2025-01-20"]
|
||||
|
||||
def test_available_dates_filters_incomplete(self, manager, populated_db):
|
||||
"""Test that dates with incomplete symbol coverage are filtered."""
|
||||
manager.db_path = populated_db
|
||||
available = manager.get_available_trading_dates("2025-01-21", "2025-01-21")
|
||||
|
||||
# 2025-01-21 is missing GOOGL, so not complete
|
||||
assert available == []
|
||||
|
||||
|
||||
class TestDownloadSymbol:
|
||||
"""Test _download_symbol method (Alpha Vantage API calls)."""
|
||||
|
||||
@patch('api.price_data_manager.requests.get')
|
||||
def test_download_success(self, mock_get, manager):
|
||||
"""Test successful symbol download."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"Meta Data": {"2. Symbol": "AAPL"},
|
||||
"Time Series (Daily)": {
|
||||
"2025-01-20": {
|
||||
"1. open": "150.00",
|
||||
"2. high": "155.00",
|
||||
"3. low": "149.00",
|
||||
"4. close": "154.00",
|
||||
"5. volume": "1000000"
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
data = manager._download_symbol("AAPL")
|
||||
|
||||
assert data["Meta Data"]["2. Symbol"] == "AAPL"
|
||||
assert "2025-01-20" in data["Time Series (Daily)"]
|
||||
mock_get.assert_called_once()
|
||||
|
||||
@patch('api.price_data_manager.requests.get')
|
||||
def test_download_rate_limit(self, mock_get, manager):
|
||||
"""Test rate limit detection."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"Note": "Thank you for using Alpha Vantage! Our standard API call frequency is 25 calls per day."
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(RateLimitError):
|
||||
manager._download_symbol("AAPL")
|
||||
|
||||
@patch('api.price_data_manager.requests.get')
|
||||
def test_download_http_error(self, mock_get, manager):
|
||||
"""Test HTTP error handling."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.raise_for_status.side_effect = Exception("Server error")
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(DownloadError):
|
||||
manager._download_symbol("AAPL")
|
||||
|
||||
@patch('api.price_data_manager.requests.get')
|
||||
def test_download_invalid_response(self, mock_get, manager):
|
||||
"""Test handling of invalid API response."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {} # Missing required fields
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(DownloadError, match="Invalid response format"):
|
||||
manager._download_symbol("AAPL")
|
||||
|
||||
def test_download_missing_api_key(self, manager):
|
||||
"""Test download without API key."""
|
||||
manager.api_key = None
|
||||
|
||||
with pytest.raises(DownloadError, match="API key not configured"):
|
||||
manager._download_symbol("AAPL")
|
||||
|
||||
|
||||
class TestStoreSymbolData:
|
||||
"""Test _store_symbol_data method."""
|
||||
|
||||
def test_store_symbol_data_success(self, manager):
|
||||
"""Test successful data storage."""
|
||||
data = {
|
||||
"Meta Data": {"2. Symbol": "AAPL"},
|
||||
"Time Series (Daily)": {
|
||||
"2025-01-20": {
|
||||
"1. open": "150.00",
|
||||
"2. high": "155.00",
|
||||
"3. low": "149.00",
|
||||
"4. close": "154.00",
|
||||
"5. volume": "1000000"
|
||||
},
|
||||
"2025-01-21": {
|
||||
"1. open": "154.00",
|
||||
"2. high": "156.00",
|
||||
"3. low": "153.00",
|
||||
"4. close": "155.00",
|
||||
"5. volume": "1100000"
|
||||
}
|
||||
}
|
||||
}
|
||||
requested_dates = {"2025-01-20", "2025-01-21"}
|
||||
|
||||
stored_dates = manager._store_symbol_data("AAPL", data, requested_dates)
|
||||
|
||||
# Returns list, not set
|
||||
assert set(stored_dates) == {"2025-01-20", "2025-01-21"}
|
||||
|
||||
# Verify data in database
|
||||
conn = get_db_connection(manager.db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) FROM price_data WHERE symbol = 'AAPL'")
|
||||
count = cursor.fetchone()[0]
|
||||
assert count == 2
|
||||
conn.close()
|
||||
|
||||
def test_store_filters_by_requested_dates(self, manager):
|
||||
"""Test that only requested dates are stored."""
|
||||
data = {
|
||||
"Meta Data": {"2. Symbol": "AAPL"},
|
||||
"Time Series (Daily)": {
|
||||
"2025-01-20": {
|
||||
"1. open": "150.00",
|
||||
"2. high": "155.00",
|
||||
"3. low": "149.00",
|
||||
"4. close": "154.00",
|
||||
"5. volume": "1000000"
|
||||
},
|
||||
"2025-01-21": {
|
||||
"1. open": "154.00",
|
||||
"2. high": "156.00",
|
||||
"3. low": "153.00",
|
||||
"4. close": "155.00",
|
||||
"5. volume": "1100000"
|
||||
}
|
||||
}
|
||||
}
|
||||
requested_dates = {"2025-01-20"} # Only request one date
|
||||
|
||||
stored_dates = manager._store_symbol_data("AAPL", data, requested_dates)
|
||||
|
||||
# Returns list, not set
|
||||
assert set(stored_dates) == {"2025-01-20"}
|
||||
|
||||
# Verify only one date in database
|
||||
conn = get_db_connection(manager.db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) FROM price_data WHERE symbol = 'AAPL'")
|
||||
count = cursor.fetchone()[0]
|
||||
assert count == 1
|
||||
conn.close()
|
||||
|
||||
|
||||
class TestUpdateCoverage:
|
||||
"""Test _update_coverage method."""
|
||||
|
||||
def test_update_coverage_new_symbol(self, manager):
|
||||
"""Test coverage tracking for new symbol."""
|
||||
manager._update_coverage("AAPL", "2025-01-20", "2025-01-21")
|
||||
|
||||
conn = get_db_connection(manager.db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT symbol, start_date, end_date, source
|
||||
FROM price_data_coverage
|
||||
WHERE symbol = 'AAPL'
|
||||
""")
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
assert row is not None
|
||||
assert row[0] == "AAPL"
|
||||
assert row[1] == "2025-01-20"
|
||||
assert row[2] == "2025-01-21"
|
||||
assert row[3] == "alpha_vantage"
|
||||
|
||||
def test_update_coverage_existing_symbol(self, manager, populated_db):
|
||||
"""Test coverage update for existing symbol."""
|
||||
manager.db_path = populated_db
|
||||
|
||||
# Update with new range
|
||||
manager._update_coverage("AAPL", "2025-01-22", "2025-01-23")
|
||||
|
||||
conn = get_db_connection(manager.db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT COUNT(*) FROM price_data_coverage WHERE symbol = 'AAPL'
|
||||
""")
|
||||
count = cursor.fetchone()[0]
|
||||
conn.close()
|
||||
|
||||
# Should have 2 coverage records now
|
||||
assert count == 2
|
||||
|
||||
|
||||
class TestDownloadMissingDataPrioritized:
|
||||
"""Test download_missing_data_prioritized method (integration)."""
|
||||
|
||||
@patch.object(PriceDataManager, '_download_symbol')
|
||||
@patch.object(PriceDataManager, '_store_symbol_data')
|
||||
@patch.object(PriceDataManager, '_update_coverage')
|
||||
def test_download_all_success(self, mock_update, mock_store, mock_download, manager):
|
||||
"""Test successful download of all missing symbols."""
|
||||
missing_coverage = {
|
||||
"AAPL": {"2025-01-20"},
|
||||
"MSFT": {"2025-01-20"}
|
||||
}
|
||||
requested_dates = {"2025-01-20"}
|
||||
|
||||
mock_download.return_value = {"Meta Data": {}, "Time Series (Daily)": {}}
|
||||
mock_store.return_value = {"2025-01-20"}
|
||||
|
||||
result = manager.download_missing_data_prioritized(missing_coverage, requested_dates)
|
||||
|
||||
assert result["success"] is True
|
||||
assert len(result["downloaded"]) == 2
|
||||
assert result["rate_limited"] is False
|
||||
assert mock_download.call_count == 2
|
||||
|
||||
@patch.object(PriceDataManager, '_download_symbol')
|
||||
def test_download_rate_limited_mid_process(self, mock_download, manager):
|
||||
"""Test graceful handling of rate limit during downloads."""
|
||||
missing_coverage = {
|
||||
"AAPL": {"2025-01-20"},
|
||||
"MSFT": {"2025-01-20"},
|
||||
"GOOGL": {"2025-01-20"}
|
||||
}
|
||||
requested_dates = {"2025-01-20"}
|
||||
|
||||
# First call succeeds, second raises rate limit
|
||||
mock_download.side_effect = [
|
||||
{"Meta Data": {"2. Symbol": "AAPL"}, "Time Series (Daily)": {"2025-01-20": {}}},
|
||||
RateLimitError("Rate limit reached")
|
||||
]
|
||||
|
||||
with patch.object(manager, '_store_symbol_data', return_value={"2025-01-20"}):
|
||||
with patch.object(manager, '_update_coverage'):
|
||||
result = manager.download_missing_data_prioritized(missing_coverage, requested_dates)
|
||||
|
||||
assert result["success"] is True # Partial success
|
||||
assert len(result["downloaded"]) == 1
|
||||
assert result["rate_limited"] is True
|
||||
assert len(result["failed"]) == 2 # MSFT and GOOGL not downloaded
|
||||
|
||||
@patch.object(PriceDataManager, '_download_symbol')
|
||||
def test_download_all_failed(self, mock_download, manager):
|
||||
"""Test handling when all downloads fail."""
|
||||
missing_coverage = {"AAPL": {"2025-01-20"}}
|
||||
requested_dates = {"2025-01-20"}
|
||||
|
||||
mock_download.side_effect = DownloadError("Network error")
|
||||
|
||||
result = manager.download_missing_data_prioritized(missing_coverage, requested_dates)
|
||||
|
||||
assert result["success"] is False
|
||||
assert len(result["downloaded"]) == 0
|
||||
assert len(result["failed"]) == 1
|
||||
@@ -12,6 +12,7 @@ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
from tools.general_tools import get_config_value
|
||||
from api.database import get_db_connection
|
||||
|
||||
all_nasdaq_100_symbols = [
|
||||
"NVDA", "MSFT", "AAPL", "GOOG", "GOOGL", "AMZN", "META", "AVGO", "TSLA",
|
||||
@@ -47,143 +48,95 @@ def get_yesterday_date(today_date: str) -> str:
|
||||
yesterday_date = yesterday_dt.strftime("%Y-%m-%d")
|
||||
return yesterday_date
|
||||
|
||||
def get_open_prices(today_date: str, symbols: List[str], merged_path: Optional[str] = None) -> Dict[str, Optional[float]]:
|
||||
"""从 data/merged.jsonl 中读取指定日期与标的的开盘价。
|
||||
def get_open_prices(today_date: str, symbols: List[str], merged_path: Optional[str] = None, db_path: str = "data/jobs.db") -> Dict[str, Optional[float]]:
|
||||
"""从 price_data 数据库表中读取指定日期与标的的开盘价。
|
||||
|
||||
Args:
|
||||
today_date: 日期字符串,格式 YYYY-MM-DD。
|
||||
symbols: 需要查询的股票代码列表。
|
||||
merged_path: 可选,自定义 merged.jsonl 路径;默认读取项目根目录下 data/merged.jsonl。
|
||||
merged_path: 已废弃,保留用于向后兼容。
|
||||
db_path: 数据库路径,默认 data/jobs.db。
|
||||
|
||||
Returns:
|
||||
{symbol_price: open_price 或 None} 的字典;若未找到对应日期或标的,则值为 None。
|
||||
"""
|
||||
wanted = set(symbols)
|
||||
results: Dict[str, Optional[float]] = {}
|
||||
|
||||
if merged_path is None:
|
||||
base_dir = Path(__file__).resolve().parents[1]
|
||||
merged_file = base_dir / "data" / "merged.jsonl"
|
||||
else:
|
||||
merged_file = Path(merged_path)
|
||||
try:
|
||||
conn = get_db_connection(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
if not merged_file.exists():
|
||||
return results
|
||||
# Query all requested symbols for the date
|
||||
placeholders = ','.join('?' * len(symbols))
|
||||
query = f"""
|
||||
SELECT symbol, open
|
||||
FROM price_data
|
||||
WHERE date = ? AND symbol IN ({placeholders})
|
||||
"""
|
||||
|
||||
with merged_file.open("r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if not line.strip():
|
||||
continue
|
||||
try:
|
||||
doc = json.loads(line)
|
||||
except Exception:
|
||||
continue
|
||||
meta = doc.get("Meta Data", {}) if isinstance(doc, dict) else {}
|
||||
sym = meta.get("2. Symbol")
|
||||
if sym not in wanted:
|
||||
continue
|
||||
series = doc.get("Time Series (Daily)", {})
|
||||
if not isinstance(series, dict):
|
||||
continue
|
||||
bar = series.get(today_date)
|
||||
if isinstance(bar, dict):
|
||||
open_val = bar.get("1. buy price")
|
||||
try:
|
||||
results[f'{sym}_price'] = float(open_val) if open_val is not None else None
|
||||
except Exception:
|
||||
results[f'{sym}_price'] = None
|
||||
params = [today_date] + list(symbols)
|
||||
cursor.execute(query, params)
|
||||
|
||||
# Build results dict
|
||||
for row in cursor.fetchall():
|
||||
symbol = row[0]
|
||||
open_price = row[1]
|
||||
results[f'{symbol}_price'] = float(open_price) if open_price is not None else None
|
||||
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
# Log error but return empty results to maintain compatibility
|
||||
print(f"Error querying price data: {e}")
|
||||
|
||||
return results
|
||||
|
||||
def get_yesterday_open_and_close_price(today_date: str, symbols: List[str], merged_path: Optional[str] = None) -> Tuple[Dict[str, Optional[float]], Dict[str, Optional[float]]]:
|
||||
"""从 data/merged.jsonl 中读取指定日期与股票的昨日买入价和卖出价。
|
||||
def get_yesterday_open_and_close_price(today_date: str, symbols: List[str], merged_path: Optional[str] = None, db_path: str = "data/jobs.db") -> Tuple[Dict[str, Optional[float]], Dict[str, Optional[float]]]:
|
||||
"""从 price_data 数据库表中读取指定日期与股票的昨日买入价和卖出价。
|
||||
|
||||
Args:
|
||||
today_date: 日期字符串,格式 YYYY-MM-DD,代表今天日期。
|
||||
symbols: 需要查询的股票代码列表。
|
||||
merged_path: 可选,自定义 merged.jsonl 路径;默认读取项目根目录下 data/merged.jsonl。
|
||||
merged_path: 已废弃,保留用于向后兼容。
|
||||
db_path: 数据库路径,默认 data/jobs.db。
|
||||
|
||||
Returns:
|
||||
(买入价字典, 卖出价字典) 的元组;若未找到对应日期或标的,则值为 None。
|
||||
"""
|
||||
wanted = set(symbols)
|
||||
buy_results: Dict[str, Optional[float]] = {}
|
||||
sell_results: Dict[str, Optional[float]] = {}
|
||||
|
||||
if merged_path is None:
|
||||
base_dir = Path(__file__).resolve().parents[1]
|
||||
merged_file = base_dir / "data" / "merged.jsonl"
|
||||
else:
|
||||
merged_file = Path(merged_path)
|
||||
|
||||
if not merged_file.exists():
|
||||
return buy_results, sell_results
|
||||
|
||||
yesterday_date = get_yesterday_date(today_date)
|
||||
|
||||
with merged_file.open("r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if not line.strip():
|
||||
continue
|
||||
try:
|
||||
doc = json.loads(line)
|
||||
except Exception:
|
||||
continue
|
||||
meta = doc.get("Meta Data", {}) if isinstance(doc, dict) else {}
|
||||
sym = meta.get("2. Symbol")
|
||||
if sym not in wanted:
|
||||
continue
|
||||
series = doc.get("Time Series (Daily)", {})
|
||||
if not isinstance(series, dict):
|
||||
continue
|
||||
|
||||
# 尝试获取昨日买入价和卖出价
|
||||
bar = series.get(yesterday_date)
|
||||
if isinstance(bar, dict):
|
||||
buy_val = bar.get("1. buy price") # 买入价字段
|
||||
sell_val = bar.get("4. sell price") # 卖出价字段
|
||||
|
||||
try:
|
||||
buy_price = float(buy_val) if buy_val is not None else None
|
||||
sell_price = float(sell_val) if sell_val is not None else None
|
||||
buy_results[f'{sym}_price'] = buy_price
|
||||
sell_results[f'{sym}_price'] = sell_price
|
||||
except Exception:
|
||||
buy_results[f'{sym}_price'] = None
|
||||
sell_results[f'{sym}_price'] = None
|
||||
else:
|
||||
# 如果昨日没有数据,尝试向前查找最近的交易日
|
||||
today_dt = datetime.strptime(today_date, "%Y-%m-%d")
|
||||
yesterday_dt = today_dt - timedelta(days=1)
|
||||
current_date = yesterday_dt
|
||||
found_data = False
|
||||
|
||||
# 最多向前查找5个交易日
|
||||
for _ in range(5):
|
||||
current_date -= timedelta(days=1)
|
||||
# 跳过周末
|
||||
while current_date.weekday() >= 5:
|
||||
current_date -= timedelta(days=1)
|
||||
|
||||
check_date = current_date.strftime("%Y-%m-%d")
|
||||
bar = series.get(check_date)
|
||||
if isinstance(bar, dict):
|
||||
buy_val = bar.get("1. buy price")
|
||||
sell_val = bar.get("4. sell price")
|
||||
|
||||
try:
|
||||
buy_price = float(buy_val) if buy_val is not None else None
|
||||
sell_price = float(sell_val) if sell_val is not None else None
|
||||
buy_results[f'{sym}_price'] = buy_price
|
||||
sell_results[f'{sym}_price'] = sell_price
|
||||
found_data = True
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not found_data:
|
||||
buy_results[f'{sym}_price'] = None
|
||||
sell_results[f'{sym}_price'] = None
|
||||
try:
|
||||
conn = get_db_connection(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Query all requested symbols for yesterday's date
|
||||
placeholders = ','.join('?' * len(symbols))
|
||||
query = f"""
|
||||
SELECT symbol, open, close
|
||||
FROM price_data
|
||||
WHERE date = ? AND symbol IN ({placeholders})
|
||||
"""
|
||||
|
||||
params = [yesterday_date] + list(symbols)
|
||||
cursor.execute(query, params)
|
||||
|
||||
# Build results dicts
|
||||
for row in cursor.fetchall():
|
||||
symbol = row[0]
|
||||
open_price = row[1] # Buy price (open)
|
||||
close_price = row[2] # Sell price (close)
|
||||
|
||||
buy_results[f'{symbol}_price'] = float(open_price) if open_price is not None else None
|
||||
sell_results[f'{symbol}_price'] = float(close_price) if close_price is not None else None
|
||||
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
# Log error but return empty results to maintain compatibility
|
||||
print(f"Error querying price data: {e}")
|
||||
|
||||
return buy_results, sell_results
|
||||
|
||||
|
||||
Reference in New Issue
Block a user