mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-02 01:27:24 -04:00
Previously, profit calculations compared portfolio value to the previous day's final value. This caused trades to appear as losses since buying stocks decreases cash and increases stock value equally (net zero change). Now profit calculations compare to the start-of-day portfolio value (action_id=0 for current date), which accurately reflects gains/losses from price movements and trading decisions. Changes: - agent_tools/tool_trade.py: Fixed profit calc in _buy_impl() and _sell_impl() - tools/price_tools.py: Fixed profit calc in add_no_trade_record_to_db() Test: test_profit_calculation_accuracy now passes
494 lines
18 KiB
Python
494 lines
18 KiB
Python
import os
|
||
from dotenv import load_dotenv
|
||
load_dotenv()
|
||
import json
|
||
from datetime import datetime, timedelta
|
||
from pathlib import Path
|
||
from typing import Dict, List, Optional, Tuple
|
||
import sys
|
||
|
||
# 将项目根目录加入 Python 路径,便于从子目录直接运行本文件
|
||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||
if project_root not in sys.path:
|
||
sys.path.insert(0, project_root)
|
||
from tools.general_tools import get_config_value
|
||
from api.database import get_db_connection
|
||
|
||
all_nasdaq_100_symbols = [
|
||
"NVDA", "MSFT", "AAPL", "GOOG", "GOOGL", "AMZN", "META", "AVGO", "TSLA",
|
||
"NFLX", "PLTR", "COST", "ASML", "AMD", "CSCO", "AZN", "TMUS", "MU", "LIN",
|
||
"PEP", "SHOP", "APP", "INTU", "AMAT", "LRCX", "PDD", "QCOM", "ARM", "INTC",
|
||
"BKNG", "AMGN", "TXN", "ISRG", "GILD", "KLAC", "PANW", "ADBE", "HON",
|
||
"CRWD", "CEG", "ADI", "ADP", "DASH", "CMCSA", "VRTX", "MELI", "SBUX",
|
||
"CDNS", "ORLY", "SNPS", "MSTR", "MDLZ", "ABNB", "MRVL", "CTAS", "TRI",
|
||
"MAR", "MNST", "CSX", "ADSK", "PYPL", "FTNT", "AEP", "WDAY", "REGN", "ROP",
|
||
"NXPI", "DDOG", "AXON", "ROST", "IDXX", "EA", "PCAR", "FAST", "EXC", "TTWO",
|
||
"XEL", "ZS", "PAYX", "WBD", "BKR", "CPRT", "CCEP", "FANG", "TEAM", "CHTR",
|
||
"KDP", "MCHP", "GEHC", "VRSK", "CTSH", "CSGP", "KHC", "ODFL", "DXCM", "TTD",
|
||
"ON", "BIIB", "LULU", "CDW", "GFS"
|
||
]
|
||
|
||
def get_yesterday_date(today_date: str) -> str:
|
||
"""
|
||
获取昨日日期,考虑休市日。
|
||
Args:
|
||
today_date: 日期字符串,格式 YYYY-MM-DD,代表今天日期。
|
||
|
||
Returns:
|
||
yesterday_date: 昨日日期字符串,格式 YYYY-MM-DD。
|
||
"""
|
||
# 计算昨日日期,考虑休市日
|
||
today_dt = datetime.strptime(today_date, "%Y-%m-%d")
|
||
yesterday_dt = today_dt - timedelta(days=1)
|
||
|
||
# 如果昨日是周末,向前找到最近的交易日
|
||
while yesterday_dt.weekday() >= 5: # 5=Saturday, 6=Sunday
|
||
yesterday_dt -= timedelta(days=1)
|
||
|
||
yesterday_date = yesterday_dt.strftime("%Y-%m-%d")
|
||
return yesterday_date
|
||
|
||
def get_open_prices(today_date: str, symbols: List[str], merged_path: Optional[str] = None, db_path: str = "data/jobs.db") -> Dict[str, Optional[float]]:
|
||
"""从 price_data 数据库表中读取指定日期与标的的开盘价。
|
||
|
||
Args:
|
||
today_date: 日期字符串,格式 YYYY-MM-DD。
|
||
symbols: 需要查询的股票代码列表。
|
||
merged_path: 已废弃,保留用于向后兼容。
|
||
db_path: 数据库路径,默认 data/jobs.db。
|
||
|
||
Returns:
|
||
{symbol_price: open_price 或 None} 的字典;若未找到对应日期或标的,则值为 None。
|
||
"""
|
||
results: Dict[str, Optional[float]] = {}
|
||
|
||
try:
|
||
conn = get_db_connection(db_path)
|
||
cursor = conn.cursor()
|
||
|
||
# Query all requested symbols for the date
|
||
placeholders = ','.join('?' * len(symbols))
|
||
query = f"""
|
||
SELECT symbol, open
|
||
FROM price_data
|
||
WHERE date = ? AND symbol IN ({placeholders})
|
||
"""
|
||
|
||
params = [today_date] + list(symbols)
|
||
cursor.execute(query, params)
|
||
|
||
# Build results dict
|
||
for row in cursor.fetchall():
|
||
symbol = row[0]
|
||
open_price = row[1]
|
||
results[f'{symbol}_price'] = float(open_price) if open_price is not None else None
|
||
|
||
conn.close()
|
||
|
||
except Exception as e:
|
||
# Log error but return empty results to maintain compatibility
|
||
print(f"Error querying price data: {e}")
|
||
|
||
return results
|
||
|
||
def get_yesterday_open_and_close_price(today_date: str, symbols: List[str], merged_path: Optional[str] = None, db_path: str = "data/jobs.db") -> Tuple[Dict[str, Optional[float]], Dict[str, Optional[float]]]:
|
||
"""从 price_data 数据库表中读取指定日期与股票的昨日买入价和卖出价。
|
||
|
||
Args:
|
||
today_date: 日期字符串,格式 YYYY-MM-DD,代表今天日期。
|
||
symbols: 需要查询的股票代码列表。
|
||
merged_path: 已废弃,保留用于向后兼容。
|
||
db_path: 数据库路径,默认 data/jobs.db。
|
||
|
||
Returns:
|
||
(买入价字典, 卖出价字典) 的元组;若未找到对应日期或标的,则值为 None。
|
||
"""
|
||
buy_results: Dict[str, Optional[float]] = {}
|
||
sell_results: Dict[str, Optional[float]] = {}
|
||
|
||
yesterday_date = get_yesterday_date(today_date)
|
||
|
||
try:
|
||
conn = get_db_connection(db_path)
|
||
cursor = conn.cursor()
|
||
|
||
# Query all requested symbols for yesterday's date
|
||
placeholders = ','.join('?' * len(symbols))
|
||
query = f"""
|
||
SELECT symbol, open, close
|
||
FROM price_data
|
||
WHERE date = ? AND symbol IN ({placeholders})
|
||
"""
|
||
|
||
params = [yesterday_date] + list(symbols)
|
||
cursor.execute(query, params)
|
||
|
||
# Build results dicts
|
||
for row in cursor.fetchall():
|
||
symbol = row[0]
|
||
open_price = row[1] # Buy price (open)
|
||
close_price = row[2] # Sell price (close)
|
||
|
||
buy_results[f'{symbol}_price'] = float(open_price) if open_price is not None else None
|
||
sell_results[f'{symbol}_price'] = float(close_price) if close_price is not None else None
|
||
|
||
conn.close()
|
||
|
||
except Exception as e:
|
||
# Log error but return empty results to maintain compatibility
|
||
print(f"Error querying price data: {e}")
|
||
|
||
return buy_results, sell_results
|
||
|
||
def get_yesterday_profit(today_date: str, yesterday_buy_prices: Dict[str, Optional[float]], yesterday_sell_prices: Dict[str, Optional[float]], yesterday_init_position: Dict[str, float]) -> Dict[str, float]:
|
||
"""
|
||
获取今日开盘时持仓的收益,收益计算方式为:(昨日收盘价格 - 昨日开盘价格)*当前持仓。
|
||
Args:
|
||
today_date: 日期字符串,格式 YYYY-MM-DD,代表今天日期。
|
||
yesterday_buy_prices: 昨日开盘价格字典,格式为 {symbol_price: price}
|
||
yesterday_sell_prices: 昨日收盘价格字典,格式为 {symbol_price: price}
|
||
yesterday_init_position: 昨日初始持仓字典,格式为 {symbol: weight}
|
||
|
||
Returns:
|
||
{symbol: profit} 的字典;若未找到对应日期或标的,则值为 0.0。
|
||
"""
|
||
profit_dict = {}
|
||
|
||
# 遍历所有股票代码
|
||
for symbol in all_nasdaq_100_symbols:
|
||
symbol_price_key = f'{symbol}_price'
|
||
|
||
# 获取昨日开盘价和收盘价
|
||
buy_price = yesterday_buy_prices.get(symbol_price_key)
|
||
sell_price = yesterday_sell_prices.get(symbol_price_key)
|
||
|
||
# 获取昨日持仓权重
|
||
position_weight = yesterday_init_position.get(symbol, 0.0)
|
||
|
||
# 计算收益:(收盘价 - 开盘价) * 持仓权重
|
||
if buy_price is not None and sell_price is not None and position_weight > 0:
|
||
profit = (sell_price - buy_price) * position_weight
|
||
profit_dict[symbol] = round(profit, 4) # 保留4位小数
|
||
else:
|
||
profit_dict[symbol] = 0.0
|
||
|
||
return profit_dict
|
||
|
||
def get_today_init_position(today_date: str, modelname: str) -> Dict[str, float]:
|
||
"""
|
||
获取今日开盘时的初始持仓(即文件中上一个交易日代表的持仓)。从../data/agent_data/{modelname}/position/position.jsonl中读取。
|
||
如果同一日期有多条记录,选择id最大的记录作为初始持仓。
|
||
|
||
Args:
|
||
today_date: 日期字符串,格式 YYYY-MM-DD,代表今天日期。
|
||
modelname: 模型名称,用于构建文件路径。
|
||
|
||
Returns:
|
||
{symbol: weight} 的字典;若未找到对应日期,则返回空字典。
|
||
"""
|
||
base_dir = Path(__file__).resolve().parents[1]
|
||
position_file = base_dir / "data" / "agent_data" / modelname / "position" / "position.jsonl"
|
||
|
||
if not position_file.exists():
|
||
print(f"Position file {position_file} does not exist")
|
||
return {}
|
||
|
||
yesterday_date = get_yesterday_date(today_date)
|
||
max_id = -1
|
||
latest_positions = {}
|
||
|
||
with position_file.open("r", encoding="utf-8") as f:
|
||
for line in f:
|
||
if not line.strip():
|
||
continue
|
||
try:
|
||
doc = json.loads(line)
|
||
if doc.get("date") == yesterday_date:
|
||
current_id = doc.get("id", 0)
|
||
if current_id > max_id:
|
||
max_id = current_id
|
||
latest_positions = doc.get("positions", {})
|
||
except Exception:
|
||
continue
|
||
|
||
return latest_positions
|
||
|
||
def get_latest_position(today_date: str, modelname: str) -> Tuple[Dict[str, float], int]:
|
||
"""
|
||
获取最新持仓。从 ../data/agent_data/{modelname}/position/position.jsonl 中读取。
|
||
优先选择当天 (today_date) 中 id 最大的记录;
|
||
若当天无记录,则回退到上一个交易日,选择该日中 id 最大的记录。
|
||
|
||
Args:
|
||
today_date: 日期字符串,格式 YYYY-MM-DD,代表今天日期。
|
||
modelname: 模型名称,用于构建文件路径。
|
||
|
||
Returns:
|
||
(positions, max_id):
|
||
- positions: {symbol: weight} 的字典;若未找到任何记录,则为空字典。
|
||
- max_id: 选中记录的最大 id;若未找到任何记录,则为 -1.
|
||
"""
|
||
base_dir = Path(__file__).resolve().parents[1]
|
||
position_file = base_dir / "data" / "agent_data" / modelname / "position" / "position.jsonl"
|
||
|
||
if not position_file.exists():
|
||
return {}, -1
|
||
|
||
# 先尝试读取当天记录
|
||
max_id_today = -1
|
||
latest_positions_today: Dict[str, float] = {}
|
||
|
||
with position_file.open("r", encoding="utf-8") as f:
|
||
for line in f:
|
||
if not line.strip():
|
||
continue
|
||
try:
|
||
doc = json.loads(line)
|
||
if doc.get("date") == today_date:
|
||
current_id = doc.get("id", -1)
|
||
if current_id > max_id_today:
|
||
max_id_today = current_id
|
||
latest_positions_today = doc.get("positions", {})
|
||
except Exception:
|
||
continue
|
||
|
||
if max_id_today >= 0:
|
||
return latest_positions_today, max_id_today
|
||
|
||
# 当天没有记录,则回退到上一个交易日
|
||
prev_date = get_yesterday_date(today_date)
|
||
max_id_prev = -1
|
||
latest_positions_prev: Dict[str, float] = {}
|
||
|
||
with position_file.open("r", encoding="utf-8") as f:
|
||
for line in f:
|
||
if not line.strip():
|
||
continue
|
||
try:
|
||
doc = json.loads(line)
|
||
if doc.get("date") == prev_date:
|
||
current_id = doc.get("id", -1)
|
||
if current_id > max_id_prev:
|
||
max_id_prev = current_id
|
||
latest_positions_prev = doc.get("positions", {})
|
||
except Exception:
|
||
continue
|
||
|
||
return latest_positions_prev, max_id_prev
|
||
|
||
def add_no_trade_record(today_date: str, modelname: str):
|
||
"""
|
||
添加不交易记录。从 ../data/agent_data/{modelname}/position/position.jsonl 中前一日最后一条持仓,并更新在今日的position.jsonl文件中。
|
||
Args:
|
||
today_date: 日期字符串,格式 YYYY-MM-DD,代表今天日期。
|
||
modelname: 模型名称,用于构建文件路径。
|
||
|
||
Returns:
|
||
None
|
||
"""
|
||
save_item = {}
|
||
current_position, current_action_id = get_latest_position(today_date, modelname)
|
||
print(current_position, current_action_id)
|
||
save_item["date"] = today_date
|
||
save_item["id"] = current_action_id+1
|
||
save_item["this_action"] = {"action":"no_trade","symbol":"","amount":0}
|
||
|
||
save_item["positions"] = current_position
|
||
base_dir = Path(__file__).resolve().parents[1]
|
||
position_file = base_dir / "data" / "agent_data" / modelname / "position" / "position.jsonl"
|
||
|
||
with position_file.open("a", encoding="utf-8") as f:
|
||
f.write(json.dumps(save_item) + "\n")
|
||
return
|
||
|
||
|
||
def get_today_init_position_from_db(
|
||
today_date: str,
|
||
modelname: str,
|
||
job_id: str
|
||
) -> Dict[str, float]:
|
||
"""
|
||
Query yesterday's position from SQLite database.
|
||
|
||
Args:
|
||
today_date: Current trading date (YYYY-MM-DD)
|
||
modelname: Model signature
|
||
job_id: Job UUID
|
||
|
||
Returns:
|
||
Position dict: {"AAPL": 50, "MSFT": 30, "CASH": 5000.0}
|
||
If no position exists: {"CASH": 10000.0} (initial cash)
|
||
"""
|
||
import logging
|
||
from api.database import get_db_connection
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
db_path = "data/jobs.db"
|
||
conn = get_db_connection(db_path)
|
||
cursor = conn.cursor()
|
||
|
||
try:
|
||
# Get most recent position before today
|
||
cursor.execute("""
|
||
SELECT p.id, p.cash
|
||
FROM positions p
|
||
WHERE p.job_id = ? AND p.model = ? AND p.date < ?
|
||
ORDER BY p.date DESC, p.action_id DESC
|
||
LIMIT 1
|
||
""", (job_id, modelname, today_date))
|
||
|
||
row = cursor.fetchone()
|
||
|
||
if not row:
|
||
# First day - return initial cash
|
||
logger.info(f"No previous position found for {modelname}, returning initial cash")
|
||
return {"CASH": 10000.0}
|
||
|
||
position_id, cash = row
|
||
position_dict = {"CASH": cash}
|
||
|
||
# Get holdings for this position
|
||
cursor.execute("""
|
||
SELECT symbol, quantity
|
||
FROM holdings
|
||
WHERE position_id = ?
|
||
""", (position_id,))
|
||
|
||
for symbol, quantity in cursor.fetchall():
|
||
position_dict[symbol] = quantity
|
||
|
||
logger.debug(f"Loaded position for {modelname}: {position_dict}")
|
||
return position_dict
|
||
|
||
except Exception as e:
|
||
logger.error(f"Database error in get_today_init_position_from_db: {e}")
|
||
raise
|
||
finally:
|
||
conn.close()
|
||
|
||
|
||
def add_no_trade_record_to_db(
|
||
today_date: str,
|
||
modelname: str,
|
||
job_id: str,
|
||
session_id: int
|
||
) -> None:
|
||
"""
|
||
Create no-trade position record in SQLite database.
|
||
|
||
Args:
|
||
today_date: Current trading date (YYYY-MM-DD)
|
||
modelname: Model signature
|
||
job_id: Job UUID
|
||
session_id: Trading session ID
|
||
"""
|
||
import logging
|
||
from api.database import get_db_connection
|
||
from agent_tools.tool_trade import get_current_position_from_db
|
||
from datetime import datetime
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
db_path = "data/jobs.db"
|
||
conn = get_db_connection(db_path)
|
||
cursor = conn.cursor()
|
||
|
||
try:
|
||
# Get current position
|
||
current_position, next_action_id = get_current_position_from_db(
|
||
job_id, modelname, today_date
|
||
)
|
||
|
||
# Calculate portfolio value
|
||
cash = current_position.get("CASH", 0.0)
|
||
portfolio_value = cash
|
||
|
||
# Add stock values
|
||
for symbol, qty in current_position.items():
|
||
if symbol != "CASH":
|
||
try:
|
||
price = get_open_prices(today_date, [symbol])[f'{symbol}_price']
|
||
portfolio_value += qty * price
|
||
except KeyError:
|
||
logger.warning(f"Price not found for {symbol} on {today_date}")
|
||
pass
|
||
|
||
# Get start-of-day portfolio value (action_id=0 for today) for P&L calculation
|
||
cursor.execute("""
|
||
SELECT portfolio_value
|
||
FROM positions
|
||
WHERE job_id = ? AND model = ? AND date = ? AND action_id = 0
|
||
LIMIT 1
|
||
""", (job_id, modelname, today_date))
|
||
|
||
row = cursor.fetchone()
|
||
|
||
if row:
|
||
# Compare to start of day (action_id=0)
|
||
start_of_day_value = row[0]
|
||
daily_profit = portfolio_value - start_of_day_value
|
||
daily_return_pct = (daily_profit / start_of_day_value * 100) if start_of_day_value > 0 else 0
|
||
else:
|
||
# First action of first day - no baseline yet
|
||
daily_profit = 0.0
|
||
daily_return_pct = 0.0
|
||
|
||
# Insert position record
|
||
created_at = datetime.utcnow().isoformat() + "Z"
|
||
|
||
cursor.execute("""
|
||
INSERT INTO positions (
|
||
job_id, date, model, action_id, action_type,
|
||
cash, portfolio_value, daily_profit, daily_return_pct,
|
||
session_id, created_at
|
||
)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||
""", (
|
||
job_id, today_date, modelname, next_action_id, "no_trade",
|
||
cash, portfolio_value, daily_profit, daily_return_pct,
|
||
session_id, created_at
|
||
))
|
||
|
||
position_id = cursor.lastrowid
|
||
|
||
# Insert holdings (unchanged from previous position)
|
||
for symbol, qty in current_position.items():
|
||
if symbol != "CASH":
|
||
cursor.execute("""
|
||
INSERT INTO holdings (position_id, symbol, quantity)
|
||
VALUES (?, ?, ?)
|
||
""", (position_id, symbol, qty))
|
||
|
||
conn.commit()
|
||
logger.info(f"Created no-trade record for {modelname} on {today_date}")
|
||
|
||
except Exception as e:
|
||
conn.rollback()
|
||
logger.error(f"Database error in add_no_trade_record_to_db: {e}")
|
||
raise
|
||
finally:
|
||
conn.close()
|
||
|
||
|
||
if __name__ == "__main__":
|
||
today_date = get_config_value("TODAY_DATE")
|
||
signature = get_config_value("SIGNATURE")
|
||
if signature is None:
|
||
raise ValueError("SIGNATURE environment variable is not set")
|
||
print(today_date, signature)
|
||
yesterday_date = get_yesterday_date(today_date)
|
||
# print(yesterday_date)
|
||
today_buy_price = get_open_prices(today_date, all_nasdaq_100_symbols)
|
||
# print(today_buy_price)
|
||
yesterday_buy_prices, yesterday_sell_prices = get_yesterday_open_and_close_price(today_date, all_nasdaq_100_symbols)
|
||
# print(yesterday_buy_prices)
|
||
# print(yesterday_sell_prices)
|
||
today_init_position = get_today_init_position(today_date, signature)
|
||
# print(today_init_position)
|
||
latest_position, latest_action_id = get_latest_position(today_date, signature)
|
||
print(latest_position, latest_action_id)
|
||
yesterday_profit = get_yesterday_profit(today_date, yesterday_buy_prices, yesterday_sell_prices, today_init_position)
|
||
# print(yesterday_profit)
|
||
add_no_trade_record(today_date, signature)
|