Files
AI-Trader/tools/price_tools.py
Bill 1bfcdd78b8 feat: complete v0.3.0 database migration and configuration
Final phase of v0.3.0 implementation - all core features complete.

Price Tools Migration:
- Update get_open_prices() to query price_data table
- Update get_yesterday_open_and_close_price() to query database
- Remove merged.jsonl file I/O (replaced with SQLite queries)
- Maintain backward-compatible function signatures
- Add db_path parameter (default: data/jobs.db)

Configuration:
- Add AUTO_DOWNLOAD_PRICE_DATA to .env.example (default: true)
- Add MAX_SIMULATION_DAYS to .env.example (default: 30)
- Document new configuration options

Documentation:
- Comprehensive CHANGELOG updates for v0.3.0
- Document all breaking changes (API format, data storage, config)
- Document new features (on-demand downloads, date ranges, database)
- Document migration path (scripts/migrate_price_data.py)
- Clear upgrade instructions

Breaking Changes (v0.3.0):
1. API request format: date_range -> start_date/end_date
2. Data storage: merged.jsonl -> price_data table
3. Config variables: removed RUNTIME_ENV_PATH, MCP ports, WEB_HTTP_PORT
4. Added AUTO_DOWNLOAD_PRICE_DATA, MAX_SIMULATION_DAYS

Migration Steps:
1. Run: python scripts/migrate_price_data.py
2. Update API clients to use new date format
3. Update .env with new variables
4. Remove old config variables

Status: v0.3.0 implementation complete
Ready for: Testing, deployment, and release
2025-10-31 16:44:46 -04:00

324 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
from dotenv import load_dotenv
load_dotenv()
import json
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import sys
# Add the project root to the Python path so this file can be run directly from a subdirectory
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if project_root not in sys.path:
sys.path.insert(0, project_root)
from tools.general_tools import get_config_value
from api.database import get_db_connection
# Ticker symbols iterated by get_yesterday_profit and used as the default
# universe by the script entry point at the bottom of this module.
all_nasdaq_100_symbols = [
    "NVDA", "MSFT", "AAPL", "GOOG", "GOOGL", "AMZN", "META", "AVGO", "TSLA",
    "NFLX", "PLTR", "COST", "ASML", "AMD", "CSCO", "AZN", "TMUS", "MU", "LIN",
    "PEP", "SHOP", "APP", "INTU", "AMAT", "LRCX", "PDD", "QCOM", "ARM", "INTC",
    "BKNG", "AMGN", "TXN", "ISRG", "GILD", "KLAC", "PANW", "ADBE", "HON",
    "CRWD", "CEG", "ADI", "ADP", "DASH", "CMCSA", "VRTX", "MELI", "SBUX",
    "CDNS", "ORLY", "SNPS", "MSTR", "MDLZ", "ABNB", "MRVL", "CTAS", "TRI",
    "MAR", "MNST", "CSX", "ADSK", "PYPL", "FTNT", "AEP", "WDAY", "REGN", "ROP",
    "NXPI", "DDOG", "AXON", "ROST", "IDXX", "EA", "PCAR", "FAST", "EXC", "TTWO",
    "XEL", "ZS", "PAYX", "WBD", "BKR", "CPRT", "CCEP", "FANG", "TEAM", "CHTR",
    "KDP", "MCHP", "GEHC", "VRSK", "CTSH", "CSGP", "KHC", "ODFL", "DXCM", "TTD",
    "ON", "BIIB", "LULU", "CDW", "GFS",
]
def get_yesterday_date(today_date: str) -> str:
    """Return the most recent weekday strictly before ``today_date``.

    Only Saturdays and Sundays are skipped; exchange holidays are not
    taken into account.

    Args:
        today_date: Current date as a ``YYYY-MM-DD`` string.

    Returns:
        The previous weekday as a ``YYYY-MM-DD`` string.
    """
    one_day = timedelta(days=1)
    candidate = datetime.strptime(today_date, "%Y-%m-%d") - one_day
    # weekday() is 5 for Saturday and 6 for Sunday: keep stepping back
    # until the candidate lands on Monday..Friday.
    while candidate.weekday() in (5, 6):
        candidate = candidate - one_day
    return candidate.strftime("%Y-%m-%d")
def get_open_prices(today_date: str, symbols: List[str], merged_path: Optional[str] = None, db_path: str = "data/jobs.db") -> Dict[str, Optional[float]]:
    """Read opening prices for a date and a set of symbols from ``price_data``.

    Args:
        today_date: Date to query, formatted ``YYYY-MM-DD``.
        symbols: Ticker symbols to look up.
        merged_path: Deprecated and ignored; kept for backward compatibility.
        db_path: Path to the database, ``data/jobs.db`` by default.

    Returns:
        A ``{"<symbol>_price": open_price}`` mapping with one entry per row
        returned by the query. The value is ``None`` when the stored open is
        NULL. Symbols that have no row for the date are absent from the
        mapping. If the lookup fails, the error is printed and whatever was
        collected so far (possibly nothing) is returned.
    """
    results: Dict[str, Optional[float]] = {}
    try:
        wanted = list(symbols)
        if not wanted:
            # Nothing to look up, so skip the database round trip entirely.
            return results

        # One bound parameter per symbol; values are never interpolated.
        placeholders = ",".join("?" * len(wanted))
        query = f"""
            SELECT symbol, open
            FROM price_data
            WHERE date = ? AND symbol IN ({placeholders})
        """
        conn = get_db_connection(db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(query, [today_date] + wanted)
            for row in cursor.fetchall():
                symbol = row[0]
                open_price = row[1]
                results[f'{symbol}_price'] = float(open_price) if open_price is not None else None
        finally:
            # Release the connection even when the query or the float
            # conversion raises; previously it leaked on any error.
            conn.close()
    except Exception as e:
        # Report the error but return what we have to stay compatible
        # with callers that expect a dict rather than an exception.
        print(f"Error querying price data: {e}")
    return results
def get_yesterday_open_and_close_price(today_date: str, symbols: List[str], merged_path: Optional[str] = None, db_path: str = "data/jobs.db") -> Tuple[Dict[str, Optional[float]], Dict[str, Optional[float]]]:
    """Read the previous weekday's open and close prices from ``price_data``.

    The open is treated as the buy price and the close as the sell price.

    Args:
        today_date: Current date, formatted ``YYYY-MM-DD``. The query is run
            for the date returned by ``get_yesterday_date(today_date)``.
        symbols: Ticker symbols to look up.
        merged_path: Deprecated and ignored; kept for backward compatibility.
        db_path: Path to the database, ``data/jobs.db`` by default.

    Returns:
        A ``(buy_prices, sell_prices)`` tuple of ``{"<symbol>_price": price}``
        mappings with one entry per row returned by the query. A value is
        ``None`` when the stored column is NULL. Symbols that have no row for
        the date are absent from both mappings. If the lookup fails, the
        error is printed and whatever was collected so far is returned.
    """
    buy_results: Dict[str, Optional[float]] = {}
    sell_results: Dict[str, Optional[float]] = {}
    # Resolved outside the try block so a malformed date still raises.
    yesterday_date = get_yesterday_date(today_date)
    try:
        wanted = list(symbols)
        if not wanted:
            # Nothing to look up, so skip the database round trip entirely.
            return buy_results, sell_results

        # One bound parameter per symbol; values are never interpolated.
        placeholders = ",".join("?" * len(wanted))
        query = f"""
            SELECT symbol, open, close
            FROM price_data
            WHERE date = ? AND symbol IN ({placeholders})
        """
        conn = get_db_connection(db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(query, [yesterday_date] + wanted)
            for row in cursor.fetchall():
                symbol = row[0]
                open_price = row[1]   # buy price (open)
                close_price = row[2]  # sell price (close)
                buy_results[f'{symbol}_price'] = float(open_price) if open_price is not None else None
                sell_results[f'{symbol}_price'] = float(close_price) if close_price is not None else None
        finally:
            # Release the connection even when the query or the float
            # conversion raises; previously it leaked on any error.
            conn.close()
    except Exception as e:
        # Report the error but return what we have to stay compatible
        # with callers that expect dicts rather than an exception.
        print(f"Error querying price data: {e}")
    return buy_results, sell_results
def get_yesterday_profit(today_date: str, yesterday_buy_prices: Dict[str, Optional[float]], yesterday_sell_prices: Dict[str, Optional[float]], yesterday_init_position: Dict[str, float]) -> Dict[str, float]:
    """Compute per-symbol profit as ``(close - open) * position``.

    Args:
        today_date: Current date, formatted ``YYYY-MM-DD``. Not used in the
            calculation.
        yesterday_buy_prices: ``{"<symbol>_price": open_price}`` mapping.
        yesterday_sell_prices: ``{"<symbol>_price": close_price}`` mapping.
        yesterday_init_position: ``{symbol: weight}`` mapping of holdings.

    Returns:
        A ``{symbol: profit}`` mapping with one entry for every symbol in
        ``all_nasdaq_100_symbols``. Profit is rounded to 4 decimal places and
        is ``0.0`` when either price is missing or the position is not
        positive.
    """
    profits = {}
    for ticker in all_nasdaq_100_symbols:
        price_key = f'{ticker}_price'
        opened = yesterday_buy_prices.get(price_key)
        closed = yesterday_sell_prices.get(price_key)
        held = yesterday_init_position.get(ticker, 0.0)

        # Missing price or no positive holding: no profit is attributed.
        if opened is None or closed is None or not held > 0:
            profits[ticker] = 0.0
            continue

        profits[ticker] = round((closed - opened) * held, 4)
    return profits
def get_today_init_position(today_date: str, modelname: str) -> Dict[str, float]:
    """Return the positions held at today's open.

    Reads ``data/agent_data/{modelname}/position/position.jsonl`` under the
    project root and selects, among the records dated
    ``get_yesterday_date(today_date)``, the one with the largest ``id``.

    Args:
        today_date: Current date, formatted ``YYYY-MM-DD``.
        modelname: Model name used to build the file path.

    Returns:
        The ``positions`` mapping of the selected record, or an empty dict
        when the file does not exist or no record matches.
    """
    position_file = (
        Path(__file__).resolve().parents[1]
        / "data" / "agent_data" / modelname / "position" / "position.jsonl"
    )
    if not position_file.exists():
        print(f"Position file {position_file} does not exist")
        return {}

    target_date = get_yesterday_date(today_date)
    best_id = -1
    snapshot = {}
    with position_file.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            if raw_line.strip() == "":
                continue
            try:
                record = json.loads(raw_line)
                if record.get("date") != target_date:
                    continue
                # Records without an id are treated as id 0.
                record_id = record.get("id", 0)
                if record_id > best_id:
                    best_id = record_id
                    snapshot = record.get("positions", {})
            except Exception:
                # Malformed lines are skipped.
                continue
    return snapshot
def get_latest_position(today_date: str, modelname: str) -> Tuple[Dict[str, float], int]:
    """Return the latest positions recorded for today or the previous weekday.

    Reads ``data/agent_data/{modelname}/position/position.jsonl`` under the
    project root. The record with the largest ``id`` dated ``today_date`` is
    preferred; if there is none with a non-negative id, the record with the
    largest ``id`` dated ``get_yesterday_date(today_date)`` is used instead.

    The file is scanned once, keeping the highest-id record for each date,
    rather than being re-read for the fallback date.

    NOTE(review): the fallback only looks at the previous weekday, so a gap
    longer than a weekend (e.g. a market holiday) yields no match. Confirm
    this is intended.

    Args:
        today_date: Current date, formatted ``YYYY-MM-DD``.
        modelname: Model name used to build the file path.

    Returns:
        ``(positions, max_id)`` where ``positions`` is the selected record's
        ``{symbol: weight}`` mapping (empty if nothing matched) and ``max_id``
        is that record's id (``-1`` if nothing matched or the file is missing).
    """
    base_dir = Path(__file__).resolve().parents[1]
    position_file = base_dir / "data" / "agent_data" / modelname / "position" / "position.jsonl"
    if not position_file.exists():
        return {}, -1

    # date -> (largest id seen, positions of that record)
    best_by_date: Dict[str, Tuple[int, Dict[str, float]]] = {}
    with position_file.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                doc = json.loads(line)
                record_date = doc.get("date")
                # Records without an id default to -1 and are never selected.
                record_id = doc.get("id", -1)
                best_id = best_by_date.get(record_date, (-1, {}))[0]
                if record_id > best_id:
                    best_by_date[record_date] = (record_id, doc.get("positions", {}))
            except Exception:
                # Malformed lines (bad JSON, non-dict, incomparable id) are skipped.
                continue

    today_id, today_positions = best_by_date.get(today_date, (-1, {}))
    if today_id >= 0:
        return today_positions, today_id

    # No usable record for today: fall back to the previous weekday. The
    # date is resolved lazily so it is only parsed when actually needed.
    prev_date = get_yesterday_date(today_date)
    prev_id, prev_positions = best_by_date.get(prev_date, (-1, {}))
    return prev_positions, prev_id
def add_no_trade_record(today_date: str, modelname: str):
    """Append a "no_trade" record for ``today_date`` to the position file.

    The record carries forward the positions returned by
    ``get_latest_position(today_date, modelname)`` and uses that record's id
    plus one. It is appended as one JSON line to
    ``data/agent_data/{modelname}/position/position.jsonl`` under the project
    root.

    Args:
        today_date: Current date, formatted ``YYYY-MM-DD``.
        modelname: Model name used to build the file path.

    Returns:
        None
    """
    current_position, current_action_id = get_latest_position(today_date, modelname)
    print(current_position, current_action_id)

    save_item = {
        "date": today_date,
        # get_latest_position returns -1 when nothing is recorded, so the
        # first record written gets id 0.
        "id": current_action_id + 1,
        "this_action": {"action": "no_trade", "symbol": "", "amount": 0},
        "positions": current_position,
    }

    base_dir = Path(__file__).resolve().parents[1]
    position_file = base_dir / "data" / "agent_data" / modelname / "position" / "position.jsonl"
    # Opening in append mode creates the file but not its directory; create
    # the directory so the first write for a model does not fail.
    position_file.parent.mkdir(parents=True, exist_ok=True)
    with position_file.open("a", encoding="utf-8") as f:
        f.write(json.dumps(save_item) + "\n")
    return
if __name__ == "__main__":
    # Manual smoke run: read the date and model signature from configuration,
    # call each helper once, then append a no-trade record for that date.
    run_date = get_config_value("TODAY_DATE")
    model_signature = get_config_value("SIGNATURE")
    if signature_missing := model_signature is None:
        raise ValueError("SIGNATURE environment variable is not set")
    print(run_date, model_signature)

    previous_date = get_yesterday_date(run_date)
    open_prices_today = get_open_prices(run_date, all_nasdaq_100_symbols)
    prev_open_prices, prev_close_prices = get_yesterday_open_and_close_price(
        run_date, all_nasdaq_100_symbols
    )
    opening_position = get_today_init_position(run_date, model_signature)

    newest_position, newest_action_id = get_latest_position(run_date, model_signature)
    print(newest_position, newest_action_id)

    profit_by_symbol = get_yesterday_profit(
        run_date, prev_open_prices, prev_close_prices, opening_position
    )
    add_no_trade_record(run_date, model_signature)