mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-02 17:37:24 -04:00
Final phase of v0.3.0 implementation - all core features complete. Price Tools Migration: - Update get_open_prices() to query price_data table - Update get_yesterday_open_and_close_price() to query database - Remove merged.jsonl file I/O (replaced with SQLite queries) - Maintain backward-compatible function signatures - Add db_path parameter (default: data/jobs.db) Configuration: - Add AUTO_DOWNLOAD_PRICE_DATA to .env.example (default: true) - Add MAX_SIMULATION_DAYS to .env.example (default: 30) - Document new configuration options Documentation: - Comprehensive CHANGELOG updates for v0.3.0 - Document all breaking changes (API format, data storage, config) - Document new features (on-demand downloads, date ranges, database) - Document migration path (scripts/migrate_price_data.py) - Clear upgrade instructions Breaking Changes (v0.3.0): 1. API request format: date_range -> start_date/end_date 2. Data storage: merged.jsonl -> price_data table 3. Config variables: removed RUNTIME_ENV_PATH, MCP ports, WEB_HTTP_PORT 4. Added AUTO_DOWNLOAD_PRICE_DATA, MAX_SIMULATION_DAYS Migration Steps: 1. Run: python scripts/migrate_price_data.py 2. Update API clients to use new date format 3. Update .env with new variables 4. Remove old config variables Status: v0.3.0 implementation complete Ready for: Testing, deployment, and release
324 lines
13 KiB
Python
324 lines
13 KiB
Python
import os
|
||
from dotenv import load_dotenv
|
||
load_dotenv()
|
||
import json
|
||
from datetime import datetime, timedelta
|
||
from pathlib import Path
|
||
from typing import Dict, List, Optional, Tuple
|
||
import sys
|
||
|
||
# Make the parent of this file's directory importable, so the project-level
# imports that follow resolve even when this file is executed directly from
# its sub-directory.
_this_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(_this_dir)
if project_root not in sys.path:
    sys.path.insert(0, project_root)
|
||
from tools.general_tools import get_config_value
|
||
from api.database import get_db_connection
|
||
|
||
# Ticker universe iterated by get_yesterday_profit() and used by the
# __main__ smoke run. 101 entries (both GOOG and GOOGL are listed); the
# order is significant only in that result dicts follow it.
all_nasdaq_100_symbols = [
    "NVDA", "MSFT", "AAPL", "GOOG", "GOOGL", "AMZN", "META", "AVGO", "TSLA", "NFLX",
    "PLTR", "COST", "ASML", "AMD", "CSCO", "AZN", "TMUS", "MU", "LIN", "PEP",
    "SHOP", "APP", "INTU", "AMAT", "LRCX", "PDD", "QCOM", "ARM", "INTC", "BKNG",
    "AMGN", "TXN", "ISRG", "GILD", "KLAC", "PANW", "ADBE", "HON", "CRWD", "CEG",
    "ADI", "ADP", "DASH", "CMCSA", "VRTX", "MELI", "SBUX", "CDNS", "ORLY", "SNPS",
    "MSTR", "MDLZ", "ABNB", "MRVL", "CTAS", "TRI", "MAR", "MNST", "CSX", "ADSK",
    "PYPL", "FTNT", "AEP", "WDAY", "REGN", "ROP", "NXPI", "DDOG", "AXON", "ROST",
    "IDXX", "EA", "PCAR", "FAST", "EXC", "TTWO", "XEL", "ZS", "PAYX", "WBD",
    "BKR", "CPRT", "CCEP", "FANG", "TEAM", "CHTR", "KDP", "MCHP", "GEHC", "VRSK",
    "CTSH", "CSGP", "KHC", "ODFL", "DXCM", "TTD", "ON", "BIIB", "LULU", "CDW",
    "GFS",
]
|
||
|
||
def get_yesterday_date(today_date: str) -> str:
    """Return the most recent weekday strictly before ``today_date``.

    Only Saturdays and Sundays are skipped; exchange holidays are NOT
    considered, so the result can be a weekday on which the market was closed.

    Args:
        today_date: Date string in ``YYYY-MM-DD`` format.

    Returns:
        The previous weekday as a ``YYYY-MM-DD`` string.

    Raises:
        ValueError: if ``today_date`` does not match ``YYYY-MM-DD``.
    """
    fmt = "%Y-%m-%d"
    candidate = datetime.strptime(today_date, fmt)
    # Step back at least one day, then keep stepping while the candidate
    # falls on a weekend (weekday() is 5 for Saturday, 6 for Sunday).
    while True:
        candidate -= timedelta(days=1)
        if candidate.weekday() < 5:
            return candidate.strftime(fmt)
|
||
|
||
def get_open_prices(today_date: str, symbols: List[str], merged_path: Optional[str] = None, db_path: str = "data/jobs.db") -> Dict[str, Optional[float]]:
    """Read opening prices for ``symbols`` on ``today_date`` from the ``price_data`` table.

    Args:
        today_date: Date string in ``YYYY-MM-DD`` format, matched against
            ``price_data.date``.
        symbols: Ticker symbols to look up. Any iterable of strings is
            accepted; an empty iterable (or ``None``) returns ``{}`` without
            touching the database.
        merged_path: Deprecated and ignored; kept only for backward
            compatibility with the former merged.jsonl implementation.
        db_path: Path passed to ``get_db_connection``; defaults to
            ``data/jobs.db``.

    Returns:
        A dict mapping ``"<SYMBOL>_price"`` to the open price as a float, or
        ``None`` when the stored value is NULL. Symbols with no row for the
        date are omitted from the dict (not mapped to ``None``). If the query
        fails, the error is printed and whatever was collected before the
        failure (normally nothing) is returned.
    """
    results: Dict[str, Optional[float]] = {}

    # Materialize once so generators work and the placeholder count always
    # matches the bound parameters.
    symbol_list = list(symbols or ())
    if not symbol_list:
        return results

    try:
        conn = get_db_connection(db_path)
        try:
            cursor = conn.cursor()
            # Only "?" placeholders are interpolated; all values are bound.
            placeholders = ",".join("?" * len(symbol_list))
            query = f"""
                SELECT symbol, open
                FROM price_data
                WHERE date = ? AND symbol IN ({placeholders})
            """
            cursor.execute(query, [today_date, *symbol_list])

            for row in cursor.fetchall():
                open_price = row[1]
                results[f"{row[0]}_price"] = float(open_price) if open_price is not None else None
        finally:
            # Release the connection even if the query or a float() conversion fails.
            conn.close()
    except Exception as e:
        # Best-effort lookup kept for backward compatibility: log, don't raise.
        print(f"Error querying price data: {e}")

    return results
|
||
|
||
def get_yesterday_open_and_close_price(today_date: str, symbols: List[str], merged_path: Optional[str] = None, db_path: str = "data/jobs.db") -> Tuple[Dict[str, Optional[float]], Dict[str, Optional[float]]]:
    """Read the previous weekday's open (buy) and close (sell) prices from ``price_data``.

    The previous day is computed by ``get_yesterday_date``, which skips
    weekends only; exchange holidays are not skipped.

    Args:
        today_date: Today's date as ``YYYY-MM-DD``.
        symbols: Ticker symbols to look up. Any iterable of strings is
            accepted; an empty iterable (or ``None``) returns ``({}, {})``
            without touching the database.
        merged_path: Deprecated and ignored; kept only for backward
            compatibility with the former merged.jsonl implementation.
        db_path: Path passed to ``get_db_connection``; defaults to
            ``data/jobs.db``.

    Returns:
        ``(buy_prices, sell_prices)``: each maps ``"<SYMBOL>_price"`` to a
        float, or ``None`` when the stored value is NULL. ``buy_prices`` holds
        the ``open`` column and ``sell_prices`` the ``close`` column. Symbols
        with no row for that date are omitted. If the query fails, the error
        is printed and whatever was collected before the failure (normally
        nothing) is returned.
    """
    buy_results: Dict[str, Optional[float]] = {}
    sell_results: Dict[str, Optional[float]] = {}

    # Materialize once so generators work and the placeholder count always
    # matches the bound parameters; nothing to look up means no DB round trip.
    symbol_list = list(symbols or ())
    if not symbol_list:
        return buy_results, sell_results

    yesterday_date = get_yesterday_date(today_date)

    try:
        conn = get_db_connection(db_path)
        try:
            cursor = conn.cursor()
            # Only "?" placeholders are interpolated; all values are bound.
            placeholders = ",".join("?" * len(symbol_list))
            query = f"""
                SELECT symbol, open, close
                FROM price_data
                WHERE date = ? AND symbol IN ({placeholders})
            """
            cursor.execute(query, [yesterday_date, *symbol_list])

            for row in cursor.fetchall():
                key = f"{row[0]}_price"
                open_price = row[1]   # buy price
                close_price = row[2]  # sell price
                buy_results[key] = float(open_price) if open_price is not None else None
                sell_results[key] = float(close_price) if close_price is not None else None
        finally:
            # Release the connection even if the query or a float() conversion fails.
            conn.close()
    except Exception as e:
        # Best-effort lookup kept for backward compatibility: log, don't raise.
        print(f"Error querying price data: {e}")

    return buy_results, sell_results
|
||
|
||
def get_yesterday_profit(
    today_date: str,
    yesterday_buy_prices: Dict[str, Optional[float]],
    yesterday_sell_prices: Dict[str, Optional[float]],
    yesterday_init_position: Dict[str, float],
    symbols: Optional[List[str]] = None,
) -> Dict[str, float]:
    """Compute each symbol's profit over the previous trading day's session.

    Per symbol: ``(close - open) * position``, rounded to 4 decimal places.

    Args:
        today_date: Today's date as ``YYYY-MM-DD``. Not used by the
            calculation; kept for signature compatibility.
        yesterday_buy_prices: ``{"<SYMBOL>_price": open price or None}``.
        yesterday_sell_prices: ``{"<SYMBOL>_price": close price or None}``.
        yesterday_init_position: ``{symbol: position}`` held at the start of
            the previous day; symbols absent from it count as 0.
        symbols: Symbols to evaluate. Defaults to ``all_nasdaq_100_symbols``
            so existing callers see unchanged results.

    Returns:
        ``{symbol: profit}`` for every evaluated symbol. The profit is ``0.0``
        when either price is missing/None or the position is not > 0.
    """
    universe = all_nasdaq_100_symbols if symbols is None else symbols
    profit_dict: Dict[str, float] = {}

    for symbol in universe:
        price_key = f"{symbol}_price"
        buy_price = yesterday_buy_prices.get(price_key)
        sell_price = yesterday_sell_prices.get(price_key)
        position_weight = yesterday_init_position.get(symbol, 0.0)

        if buy_price is not None and sell_price is not None and position_weight > 0:
            profit_dict[symbol] = round((sell_price - buy_price) * position_weight, 4)
        else:
            profit_dict[symbol] = 0.0

    return profit_dict
|
||
|
||
def get_today_init_position(today_date: str, modelname: str) -> Dict[str, float]:
    """Return the positions the model held at today's open.

    Reads ``data/agent_data/<modelname>/position/position.jsonl`` under the
    parent of this file's directory and considers only records dated on the
    previous weekday (see ``get_yesterday_date``). If several records share
    that date, the one with the largest ``id`` wins; a record without an
    ``id`` is treated as id 0.

    Args:
        today_date: Today's date as ``YYYY-MM-DD``.
        modelname: Model name, used as a directory component of the path.

    Returns:
        The ``positions`` mapping of the selected record, or ``{}`` when the
        file is missing (a message is printed) or no record matches. Blank or
        unparsable lines are skipped.
    """
    project_dir = Path(__file__).resolve().parents[1]
    jsonl_path = project_dir / "data" / "agent_data" / modelname / "position" / "position.jsonl"

    if not jsonl_path.exists():
        print(f"Position file {jsonl_path} does not exist")
        return {}

    target_date = get_yesterday_date(today_date)
    best_id = -1
    best_positions = {}

    with jsonl_path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            try:
                record = json.loads(raw_line)
                if record.get("date") != target_date:
                    continue
                record_id = record.get("id", 0)
                if record_id > best_id:
                    best_id = record_id
                    best_positions = record.get("positions", {})
            except Exception:
                # Skip malformed JSON or records whose id is not comparable.
                continue

    return best_positions
|
||
|
||
def get_latest_position(today_date: str, modelname: str) -> Tuple[Dict[str, float], int]:
    """Return the most recent positions recorded for ``modelname``.

    Reads ``data/agent_data/<modelname>/position/position.jsonl`` under the
    parent of this file's directory. The record with the largest ``id`` dated
    ``today_date`` is preferred. If today has none, the record with the
    largest ``id`` on the previous weekday (``get_yesterday_date``) is used.

    Args:
        today_date: Today's date as ``YYYY-MM-DD``.
        modelname: Model name, used as a directory component of the path.

    Returns:
        ``(positions, max_id)``: the selected record's ``positions`` mapping
        and its ``id``. Returns ``({}, -1)`` when the file is missing or
        neither date has a qualifying record. A record without an ``id``
        counts as ``-1`` and is therefore never selected. Blank or unparsable
        lines are skipped.
    """
    base_dir = Path(__file__).resolve().parents[1]
    position_file = base_dir / "data" / "agent_data" / modelname / "position" / "position.jsonl"

    if not position_file.exists():
        return {}, -1

    # Single pass: keep the highest-id record for every date, so falling back
    # to the previous day does not require re-reading the file.
    best_by_date: Dict[str, Tuple[int, Dict[str, float]]] = {}
    with position_file.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                doc = json.loads(line)
                date = doc.get("date")
                current_id = doc.get("id", -1)
                seen = best_by_date.get(date)
                # An id must beat -1 (or the best id seen so far for that date).
                if current_id > (seen[0] if seen is not None else -1):
                    best_by_date[date] = (current_id, doc.get("positions", {}))
            except Exception:
                # Skip malformed JSON, unhashable dates, or non-comparable ids.
                continue

    today_best = best_by_date.get(today_date)
    if today_best is not None and today_best[0] >= 0:
        return today_best[1], today_best[0]

    # No usable record for today: fall back to the previous weekday.
    prev_best = best_by_date.get(get_yesterday_date(today_date))
    if prev_best is None:
        return {}, -1
    return prev_best[1], prev_best[0]
|
||
|
||
def add_no_trade_record(today_date: str, modelname: str):
    """Append a "no_trade" record for ``today_date`` to the model's position log.

    The new record carries forward the positions returned by
    ``get_latest_position`` (today's latest record, else the previous
    weekday's). Its ``id`` is one more than that record's id, so it is 0 when
    nothing was found. The record is appended as one JSON line to
    ``data/agent_data/<modelname>/position/position.jsonl`` under the parent
    of this file's directory.

    Args:
        today_date: Today's date as ``YYYY-MM-DD``.
        modelname: Model name, used as a directory component of the path.

    Returns:
        None
    """
    positions, last_action_id = get_latest_position(today_date, modelname)
    print(positions, last_action_id)

    record = {
        "date": today_date,
        "id": last_action_id + 1,
        "this_action": {"action": "no_trade", "symbol": "", "amount": 0},
        "positions": positions,
    }

    log_path = Path(__file__).resolve().parents[1] / "data" / "agent_data" / modelname / "position" / "position.jsonl"
    with log_path.open("a", encoding="utf-8") as out:
        out.write(json.dumps(record) + "\n")
|
||
|
||
if __name__ == "__main__":
    # Manual smoke run for the configured TODAY_DATE / SIGNATURE.
    # NOTE: the final add_no_trade_record() call APPENDS a record to the
    # model's position.jsonl, so running this script modifies that file.
    today_date = get_config_value("TODAY_DATE")
    signature = get_config_value("SIGNATURE")
    if signature is None:
        raise ValueError("SIGNATURE environment variable is not set")
    # Fail fast with a clear message instead of a TypeError from
    # datetime.strptime(None, ...) inside get_yesterday_date().
    if today_date is None:
        raise ValueError("TODAY_DATE environment variable is not set")
    print(today_date, signature)

    yesterday_date = get_yesterday_date(today_date)
    today_buy_price = get_open_prices(today_date, all_nasdaq_100_symbols)
    yesterday_buy_prices, yesterday_sell_prices = get_yesterday_open_and_close_price(today_date, all_nasdaq_100_symbols)
    today_init_position = get_today_init_position(today_date, signature)

    latest_position, latest_action_id = get_latest_position(today_date, signature)
    print(latest_position, latest_action_id)

    yesterday_profit = get_yesterday_profit(today_date, yesterday_buy_prices, yesterday_sell_prices, today_init_position)
    add_no_trade_record(today_date, signature)
|